Skip to content

mlp

dreem.models.mlp

Multi-Layer Perceptron (MLP) module.

MLP

Bases: Module

Multi-Layer Perceptron (MLP) module.

Source code in dreem/models/mlp.py
class MLP(torch.nn.Module):
    """Multi-Layer Perceptron (MLP) module.

    A stack of ``num_layers`` linear layers. Every layer except the last is
    followed by a ReLU and, when ``dropout > 0``, by a dropout layer. The
    last layer is a plain linear projection to ``output_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float = 0.0,
    ):
        """Initialize MLP.

        Args:
            input_dim: Dimensionality of the input features.
            hidden_dim: Number of units in the hidden layers.
            output_dim: Dimensionality of the output features.
            num_layers: Total number of linear layers, i.e. ``num_layers - 1``
                hidden layers followed by one output layer. With
                ``num_layers == 1`` the module is a single
                ``Linear(input_dim, output_dim)``; with ``num_layers <= 0``
                no layers are created and ``forward`` returns its input
                unchanged.
            dropout: Dropout probability applied after each hidden layer's
                ReLU. Dropout modules are only created when this is > 0.
        """
        super().__init__()

        self.num_layers = num_layers
        self.dropout = dropout

        if self.num_layers > 0:
            # Widths of the hidden layers; empty when num_layers == 1.
            hidden_dims = [hidden_dim] * (num_layers - 1)
            in_dims = [input_dim] + hidden_dims
            out_dims = hidden_dims + [output_dim]
            self.layers = torch.nn.ModuleList(
                [torch.nn.Linear(n, k) for n, k in zip(in_dims, out_dims)]
            )
            if self.dropout > 0.0:
                # One dropout per hidden layer; the output layer gets none.
                self.dropouts = torch.nn.ModuleList(
                    [torch.nn.Dropout(dropout) for _ in range(self.num_layers - 1)]
                )
        else:
            # An empty ModuleList (rather than a plain list) keeps the type of
            # `layers` the same regardless of `num_layers`.
            self.layers = torch.nn.ModuleList()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP.

        Args:
            x: Input tensor of shape (batch_size, num_instances, input_dim).

        Returns:
            Output tensor of shape (batch_size, num_instances, output_dim).
            If the module has no layers, ``x`` is returned as-is.
        """
        last_idx = self.num_layers - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last_idx:
                # Hidden layers get the non-linearity (and optional dropout);
                # the final layer stays a pure linear projection.
                x = F.relu(x)
                if self.dropout > 0.0:
                    x = self.dropouts[i](x)

        return x

__init__(input_dim, hidden_dim, output_dim, num_layers, dropout=0.0)

Initialize MLP.

Parameters:

Name Type Description Default
input_dim int

Dimensionality of the input features.

required
hidden_dim int

Number of units in the hidden layers.

required
output_dim int

Dimensionality of the output features.

required
num_layers int

Total number of linear layers (num_layers - 1 hidden layers plus the output layer).

required
dropout float

Dropout probability.

0.0
Source code in dreem/models/mlp.py
def __init__(
    self,
    input_dim: int,
    hidden_dim: int,
    output_dim: int,
    num_layers: int,
    dropout: float = 0.0,
):
    """Initialize MLP.

    Args:
        input_dim: Dimensionality of the input features.
        hidden_dim: Number of units in the hidden layers.
        output_dim: Dimensionality of the output features.
        num_layers: Total number of linear layers, i.e. ``num_layers - 1``
            hidden layers followed by one output layer. With
            ``num_layers <= 0`` no layers are created.
        dropout: Dropout probability. Dropout modules are only created
            when this is > 0.
    """
    super().__init__()

    self.num_layers = num_layers
    self.dropout = dropout

    if self.num_layers > 0:
        # Widths of the hidden layers; empty when num_layers == 1.
        hidden_dims = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden_dims
        out_dims = hidden_dims + [output_dim]
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(n, k) for n, k in zip(in_dims, out_dims)]
        )
        if self.dropout > 0.0:
            # One dropout per hidden layer; the output layer gets none.
            self.dropouts = torch.nn.ModuleList(
                [torch.nn.Dropout(dropout) for _ in range(self.num_layers - 1)]
            )
    else:
        # An empty ModuleList (rather than a plain list) keeps the type of
        # `layers` the same regardless of `num_layers`.
        self.layers = torch.nn.ModuleList()

forward(x)

Forward pass of the MLP.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (batch_size, num_instances, input_dim).

required

Returns:

Type Description
Tensor

Output tensor of shape (batch_size, num_instances, output_dim).

Source code in dreem/models/mlp.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the input through every layer of the MLP.

    All layers except the last are followed by a ReLU and, if the module
    was configured with a positive dropout probability, by the matching
    dropout layer. The last layer's output is returned without activation.

    Args:
        x: Input tensor of shape (batch_size, num_instances, input_dim).

    Returns:
        Output tensor of shape (batch_size, num_instances, output_dim).
    """
    final_idx = self.num_layers - 1
    for idx, linear in enumerate(self.layers):
        x = linear(x)
        if idx >= final_idx:
            # Final layer: no activation, no dropout.
            continue
        x = F.relu(x)
        if self.dropout > 0.0:
            x = self.dropouts[idx](x)

    return x