Module ilpyt.nets.base_net

BaseNetwork is the abstract class for a network. Networks parameterize important functions during learning - most often, the agent policy.

To create a custom network, simply extend BaseNetwork. The BaseNetwork API requires you to override the initialize, get_action, and forward methods.

  • initialize sets network class variables, such as the network layers
  • get_action draws from a torch distribution to perform an action
  • forward computes a forward pass of the network
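
The three overrides listed above can be sketched as a small discrete-action network. This is a minimal illustration rather than a network shipped with ilpyt: `TinyDiscreteNet` and its `hidden` argument are hypothetical, and a stand-in `BaseNetwork` that mirrors the constructor in this module is defined so the snippet runs without ilpyt installed.

```python
import torch
from torch.distributions import Categorical


# Stand-in mirroring ilpyt.nets.base_net.BaseNetwork so the sketch is
# self-contained; with ilpyt installed, import the real class instead.
class BaseNetwork(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.initialize(**kwargs)


class TinyDiscreteNet(BaseNetwork):  # hypothetical example class
    def initialize(self, input_shape, output_shape, hidden=32):
        # Build the layers here; __init__ forwards its kwargs to this method.
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_shape[0], hidden),
            torch.nn.Tanh(),
            torch.nn.Linear(hidden, output_shape),
        )

    def forward(self, x):
        # Return unnormalized action logits, one row per input row.
        return self.layers(x)

    def get_action(self, x):
        # Wrap the logits in a distribution and sample one action per row.
        dist = Categorical(logits=self.forward(x))
        return dist, dist.sample()


net = TinyDiscreteNet(input_shape=(4,), output_shape=2)
dist, action = net.get_action(torch.zeros(3, 4))
```

Here `forward` returns raw logits and `get_action` turns them into a `Categorical` distribution; other designs (for example a continuous policy) would return a different distribution type from `get_action`.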

Source code
"""
`BaseNetwork` is the abstract class for a network. Networks parameterize 
important functions during learning - most often, the agent policy.

To create a custom network, simply extend `BaseNetwork`. The `BaseNetwork` API 
requires you to override the `initialize`, `get_action`, and `forward` methods.

- `initialize` sets `network` class variables, such as the network layers
- `get_action` draws from a torch distribution to perform an action
- `forward` computes a forward pass of the network
"""

from abc import abstractmethod
from typing import Any, Tuple

import torch
from torch.distributions import Distribution


class BaseNetwork(torch.nn.Module):
    def __init__(self, **kwargs: Any) -> None:
        """
        Parameters
        ----------
        **kwargs:
            arbitrary keyword arguments, forwarded to the `initialize` method
        """
        super(BaseNetwork, self).__init__()
        self.initialize(**kwargs)

    @abstractmethod
    def initialize(self, input_shape: tuple, output_shape: int) -> None:
        """
        Perform network initialization. Build the network layers here. 
        Override this method when extending the `BaseNetwork` class.

        Parameters
        ----------
        input_shape: tuple
            shape of input to network
        output_shape: int
            shape of output of network
        """
        pass

    @abstractmethod
    def get_action(self, x: torch.Tensor) -> Tuple[Distribution, torch.Tensor]:
        """
        Some algorithms will require us to draw from a distribution to perform 
        an action. Override this method when extending the `BaseNetwork` class.

        Parameters
        ----------
        x: torch.Tensor
            input tensor to network

        Returns
        -------
        torch.distributions.Distribution:
            distribution to sample actions from
        torch.Tensor:
            action tensor, sampled from distribution
        """
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the network. Override this method when extending the 
        `BaseNetwork` class.

        Parameters
        ----------
        x: torch.Tensor
            input tensor to network

        Returns
        -------
        torch.Tensor:
            output of network
        """
        pass


def get_activation_layer(name: str) -> torch.nn.Module:
    """
    Get an activation layer with the given name.

    Parameters
    ----------
    name: str
        activation layer name, choose from 'relu' or 'tanh'

    Returns
    -------
    torch.nn.Module:
        activation layer

    Raises
    ------
    ValueError:
        if an unsupported activation layer is specified
    """
    if name == 'relu':
        return torch.nn.ReLU()
    elif name == 'tanh':
        return torch.nn.Tanh()
    else:
        raise ValueError('Unsupported activation layer chosen.')

Functions

def get_activation_layer(name: str) ‑> torch.nn.modules.module.Module

Get an activation layer with the given name.

Parameters

name : str
activation layer name, choose from 'relu' or 'tanh'

Returns

torch.nn.Module:
activation layer

Raises

ValueError

if an unsupported activation layer is specified
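
As a rough usage sketch, the function can be used to pick the nonlinearity of a small MLP from a configuration string. `build_mlp` is a hypothetical helper, not part of ilpyt, and `get_activation_layer` is copied from the module source so the snippet runs on its own.

```python
import torch


# Copied from the module source so the sketch runs without ilpyt installed.
def get_activation_layer(name: str) -> torch.nn.Module:
    if name == 'relu':
        return torch.nn.ReLU()
    elif name == 'tanh':
        return torch.nn.Tanh()
    else:
        raise ValueError('Unsupported activation layer chosen.')


def build_mlp(sizes, activation):  # hypothetical helper for illustration
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        layers += [torch.nn.Linear(n_in, n_out), get_activation_layer(activation)]
    # Drop the trailing activation so the last layer stays linear.
    return torch.nn.Sequential(*layers[:-1])


mlp = build_mlp([4, 16, 2], 'tanh')

# Any name other than 'relu' or 'tanh' raises ValueError.
try:
    get_activation_layer('gelu')
except ValueError as err:
    message = str(err)
```

The comparison is an exact string match, so names such as 'ReLU' or 'Tanh' with different capitalization are rejected as well.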


Classes

class BaseNetwork (**kwargs: Any)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

training : bool
whether this module is in training or evaluation mode

Parameters

**kwargs: arbitrary keyword arguments, forwarded to the initialize method
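
The forwarding of keyword arguments can be illustrated with a short sketch. `RecordingNet` is a hypothetical subclass, and the stand-in `BaseNetwork` below only mirrors the constructor shown in this module so the snippet is self-contained.

```python
import torch


# Stand-in mirroring the constructor of ilpyt.nets.base_net.BaseNetwork.
class BaseNetwork(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.initialize(**kwargs)


class RecordingNet(BaseNetwork):  # hypothetical example class
    def initialize(self, input_shape, output_shape):
        # Whatever is passed to the constructor by name arrives here.
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.linear = torch.nn.Linear(input_shape[0], output_shape)


net = RecordingNet(input_shape=(8,), output_shape=3)

# Positional arguments are rejected because __init__ only accepts **kwargs.
try:
    RecordingNet((8,), 3)
    positional_ok = True
except TypeError:
    positional_ok = False

# Omitting a name that initialize requires fails inside initialize.
try:
    RecordingNet(input_shape=(8,))
    missing_ok = True
except TypeError:
    missing_ok = False
```

In practice this means constructor arguments must be passed by name, and their names must match the parameters of the subclass's `initialize`.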


Ancestors

  • torch.nn.modules.module.Module

Subclasses

Class variables

var dump_patches : bool
var training : bool

Methods

def forward(self, x: torch.Tensor) ‑> torch.Tensor

Forward pass of the network. Override this method when extending the BaseNetwork class.

Parameters

x : torch.Tensor
input tensor to network

Returns

torch.Tensor:
output of network
def get_action(self, x: torch.Tensor) ‑> Tuple[torch.distributions.distribution.Distribution, torch.Tensor]

Some algorithms will require us to draw from a distribution to perform an action. Override this method when extending the BaseNetwork class.

Parameters

x : torch.Tensor
input tensor to network

Returns

torch.distributions.Distribution:
distribution to sample actions from
torch.Tensor:
action tensor, sampled from distribution
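
For continuous action spaces, one possible shape for this method is sketched below. `GaussianHead` is hypothetical and subclasses `torch.nn.Module` directly to stay self-contained; the state-independent `log_std` parameter is an assumption of this sketch, not something the module prescribes.

```python
import torch
from torch.distributions import Normal


class GaussianHead(torch.nn.Module):  # hypothetical example class
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mean = torch.nn.Linear(obs_dim, act_dim)
        # Assumed design: one learned log standard deviation per action dim.
        self.log_std = torch.nn.Parameter(torch.zeros(act_dim))

    def forward(self, x):
        # The forward pass yields the mean of the action distribution.
        return self.mean(x)

    def get_action(self, x):
        # Return both the distribution and a sample, matching the
        # Tuple[Distribution, Tensor] return type documented above.
        dist = Normal(self.forward(x), self.log_std.exp())
        return dist, dist.sample()


head = GaussianHead(obs_dim=4, act_dim=2)
dist, action = head.get_action(torch.zeros(5, 4))
# Per-sample log-probability, summed over independent action dimensions.
log_prob = dist.log_prob(action).sum(dim=-1)
```

Returning the distribution alongside the sampled action lets a caller compute quantities such as `log_prob` without a second forward pass.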
def initialize(self, input_shape: tuple, output_shape: int) ‑> NoneType

Perform network initialization. Build the network layers here. Override this method when extending the BaseNetwork class.

Parameters

input_shape : tuple
shape of input to network
output_shape : int
shape of output of network
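
A sketch of an `initialize` override for multi-dimensional inputs follows. It assumes that `input_shape` excludes the batch dimension, which the docstring does not state explicitly; `FlattenNet` is hypothetical, and the stand-in `BaseNetwork` mirrors the constructor shown in this module.

```python
import math

import torch


# Stand-in mirroring the constructor of ilpyt.nets.base_net.BaseNetwork.
class BaseNetwork(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.initialize(**kwargs)


class FlattenNet(BaseNetwork):  # hypothetical example class
    def initialize(self, input_shape, output_shape):
        # Assumption: input_shape is the per-sample shape, e.g. (C, H, W).
        n_inputs = math.prod(input_shape)
        self.layers = torch.nn.Sequential(
            torch.nn.Flatten(),  # keeps the batch dim, flattens the rest
            torch.nn.Linear(n_inputs, output_shape),
        )

    def forward(self, x):
        return self.layers(x)


net = FlattenNet(input_shape=(3, 8, 8), output_shape=5)
out = net(torch.zeros(2, 3, 8, 8))
```

Because `output_shape` is typed as an int, it is used here directly as the width of the final layer.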