Module ilpyt.utils.net_utils

Expand source code
import logging
from typing import Callable

import gym
import torch
from gym import spaces

import ilpyt.nets.net1d as net1d
import ilpyt.nets.net2d as net2d


import logging
from typing import Callable, Tuple, Union

import gym
import torch
from gym import spaces

import ilpyt.nets.base_net as BaseNetwork
import ilpyt.nets.net1d as net1d
import ilpyt.nets.net2d as net2d


def choose_net(
    env: gym.Env,
    input_shape: tuple = None,
    output_shape: int = None,
    activation: str = 'relu',
    with_action_shape: bool = False,
) -> torch.nn.Module:
    """
    Choose and instantiate the network best suited for a Gym environment.

    The network class is selected from the environment's action space
    (discrete or continuous) and the rank of its observation space
    (1 or 3 dimensions).

    Parameters
    ----------
    env: gym.Env
        gym environment. Must expose `action_space` and `observation_space`;
        `observation_shape` and `num_actions` are read when the matching
        arguments below are left unspecified
    input_shape: tuple, default=None
        input dimensions for network. If not specified, set to
        `env.observation_shape`
    output_shape: int, default=None
        output dimensions for network. If not specified, set to
        `env.num_actions`
    activation: str, default='relu'
        activation layer to use in the network, choose from [relu or tanh]
    with_action_shape: bool, default=False
        whether or not to include action in the network input. When true,
        `env.num_actions` is forwarded to the network as the action size;
        otherwise 0 is forwarded

    Returns
    -------
    torch.nn.Module:
        instance of the available network that best suits given env

    Raises
    ------
    ValueError
        if the action space is neither `spaces.Discrete` nor `spaces.Box`,
        or the observation space does not have 1 or 3 dimensions
    """
    # Get action space type
    if isinstance(env.action_space, spaces.Discrete):
        action_kind = 'discrete'
    elif isinstance(env.action_space, spaces.Box):
        action_kind = 'continuous'
    else:
        action_kind = None
        logging.error('Action space not supported.')

    # Get observation space type. A missing or None shape is reported as
    # unsupported instead of failing inside len().
    obs_shape = getattr(env.observation_space, 'shape', None)
    obs_rank = None if obs_shape is None else len(obs_shape)
    if obs_rank == 1:
        obs_kind = '1d'
    elif obs_rank == 3:
        obs_kind = '3d'
    else:
        obs_kind = None
        logging.error('Observation space not supported.')

    # Select network
    if action_kind == 'discrete' and obs_kind == '1d':
        net = net1d.DiscreteNetwork1D  # type: Callable
    elif action_kind == 'continuous' and obs_kind == '1d':
        net = net1d.ContinuousNetwork1D
    elif action_kind == 'discrete' and obs_kind == '3d':
        net = net2d.DiscreteNetwork2D
    elif action_kind == 'continuous' and obs_kind == '3d':
        net = net2d.ContinuousNetwork2D
    else:
        # Fail here with a meaningful error; previously execution continued
        # and crashed later because `net` was never assigned.
        message = 'Invalid combination of action and observation spaces.'
        logging.error(message)
        raise ValueError(message)

    if input_shape is None:
        input_shape = env.observation_shape
    if output_shape is None:
        output_shape = env.num_actions

    # The networks take the action size as an int; 0 means "no action input".
    action_size = int(env.num_actions) if with_action_shape else 0

    return net(
        input_shape=input_shape,
        output_shape=output_shape,
        activation=activation,
        with_action_shape=action_size,
    )

Functions

def choose_net(env: gym.core.Env, input_shape: tuple = None, output_shape: int = None, activation: str = 'relu', with_action_shape: bool = False) ‑> torch.nn.modules.module.Module

From the available networks, choose the network class best suited for the Gym environment according to the environment action space and observation space.

Parameters

env : gym.Env
gym environment
input_shape : tuple, default=None
input dimensions for network. If not specified, set to env.observation_shape
output_shape : int, default=None
output dimensions for network. If not specified, set to env.num_actions
activation : str, default='relu'
activation layer to use in the network, choose from [relu or tanh]
with_action_shape : bool, default=False
whether or not to include action in the network input

Returns

torch.nn.Module:
available network that best suits given env
Expand source code
def choose_net(
    env: gym.Env,
    input_shape: tuple = None,
    output_shape: int = None,
    activation: str = 'relu',
    with_action_shape: bool = False,
) -> torch.nn.Module:
    """
    Choose and instantiate the network best suited for a Gym environment.

    The network class is selected from the environment's action space
    (discrete or continuous) and the rank of its observation space
    (1 or 3 dimensions).

    Parameters
    ----------
    env: gym.Env
        gym environment. Must expose `action_space` and `observation_space`;
        `observation_shape` and `num_actions` are read when the matching
        arguments below are left unspecified
    input_shape: tuple, default=None
        input dimensions for network. If not specified, set to
        `env.observation_shape`
    output_shape: int, default=None
        output dimensions for network. If not specified, set to
        `env.num_actions`
    activation: str, default='relu'
        activation layer to use in the network, choose from [relu or tanh]
    with_action_shape: bool, default=False
        whether or not to include action in the network input. When true,
        `env.num_actions` is forwarded to the network as the action size;
        otherwise 0 is forwarded

    Returns
    -------
    torch.nn.Module:
        instance of the available network that best suits given env

    Raises
    ------
    ValueError
        if the action space is neither `spaces.Discrete` nor `spaces.Box`,
        or the observation space does not have 1 or 3 dimensions
    """
    # Get action space type
    if isinstance(env.action_space, spaces.Discrete):
        action_kind = 'discrete'
    elif isinstance(env.action_space, spaces.Box):
        action_kind = 'continuous'
    else:
        action_kind = None
        logging.error('Action space not supported.')

    # Get observation space type. A missing or None shape is reported as
    # unsupported instead of failing inside len().
    obs_shape = getattr(env.observation_space, 'shape', None)
    obs_rank = None if obs_shape is None else len(obs_shape)
    if obs_rank == 1:
        obs_kind = '1d'
    elif obs_rank == 3:
        obs_kind = '3d'
    else:
        obs_kind = None
        logging.error('Observation space not supported.')

    # Select network
    if action_kind == 'discrete' and obs_kind == '1d':
        net = net1d.DiscreteNetwork1D  # type: Callable
    elif action_kind == 'continuous' and obs_kind == '1d':
        net = net1d.ContinuousNetwork1D
    elif action_kind == 'discrete' and obs_kind == '3d':
        net = net2d.DiscreteNetwork2D
    elif action_kind == 'continuous' and obs_kind == '3d':
        net = net2d.ContinuousNetwork2D
    else:
        # Fail here with a meaningful error; previously execution continued
        # and crashed later because `net` was never assigned.
        message = 'Invalid combination of action and observation spaces.'
        logging.error(message)
        raise ValueError(message)

    if input_shape is None:
        input_shape = env.observation_shape
    if output_shape is None:
        output_shape = env.num_actions

    # The networks take the action size as an int; 0 means "no action input".
    action_size = int(env.num_actions) if with_action_shape else 0

    return net(
        input_shape=input_shape,
        output_shape=output_shape,
        activation=activation,
        with_action_shape=action_size,
    )