Model Parts

dreem.models.VisualEncoder

Bases: Module

Class wrapping around a visual feature extractor backbone.

Currently CNN only.

Source code in dreem/models/visual_encoder.py
class VisualEncoder(torch.nn.Module):
    """Class wrapping around a visual feature extractor backbone.

    Currently CNN only.
    """

    def __init__(
        self,
        model_name: str = "resnet18",
        d_model: int = 512,
        in_chans: int = 3,
        backend: str = "timm",
        **kwargs: Optional[Any],
    ):
        """Initialize Visual Encoder.

        Args:
            model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
            d_model (int): Output embedding dimension.
            in_chans: the number of input channels of the image.
            backend: Which model backend to use. One of {"timm", "torchvision"}
            kwargs: see `timm.create_model` and `torchvision.models.resnetX` for kwargs.
        """
        super().__init__()

        self.model_name = model_name.lower()
        self.d_model = d_model
        self.backend = backend
        if in_chans == 1:
            self.in_chans = 3
        else:
            self.in_chans = in_chans

        self.feature_extractor = self.select_feature_extractor(
            model_name=self.model_name,
            in_chans=self.in_chans,
            backend=self.backend,
            **kwargs,
        )

        self.out_layer = torch.nn.Linear(
            self.encoder_dim(self.feature_extractor), self.d_model
        )

    def select_feature_extractor(
        self, model_name: str, in_chans: int, backend: str, **kwargs: Optional[Any]
    ) -> torch.nn.Module:
        """Select the appropriate feature extractor based on config.

        Args:
            model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
            in_chans: the number of input channels of the image.
            backend: Which model backend to use. One of {"timm", "torchvision"}
            kwargs: see `timm.create_model` and `torchvision.models.resnetX` for kwargs.

        Returns:
            a CNN encoder based on the config and backend selected.
        """
        if "timm" in backend.lower():
            feature_extractor = timm.create_model(
                model_name=self.model_name,
                in_chans=self.in_chans,
                num_classes=0,
                **kwargs,
            )
        elif "torch" in backend.lower():
            if model_name.lower() == "resnet18":
                feature_extractor = torchvision.models.resnet18(**kwargs)

            elif model_name.lower() == "resnet50":
                feature_extractor = torchvision.models.resnet50(**kwargs)

            else:
                raise ValueError(
                    f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}"
                )
            feature_extractor = torch.nn.Sequential(
                *list(feature_extractor.children())[:-1]
            )
            input_layer = feature_extractor[0]
            if in_chans != 3:
                feature_extractor[0] = torch.nn.Conv2d(
                    in_channels=in_chans,
                    out_channels=input_layer.out_channels,
                    kernel_size=input_layer.kernel_size,
                    stride=input_layer.stride,
                    padding=input_layer.padding,
                    dilation=input_layer.dilation,
                    groups=input_layer.groups,
                    bias=input_layer.bias,
                    padding_mode=input_layer.padding_mode,
                )

        else:
            raise ValueError(
                f"Only ['timm', 'torch'] backends are available! Found {backend}."
            )
        return feature_extractor

    def encoder_dim(self, model: torch.nn.Module) -> int:
        """Compute dummy forward pass of encoder model and get embedding dimension.

        Args:
            model: a vision encoder model.

        Returns:
            The embedding dimension size.
        """
        _ = model.eval()
        dummy_output = model(torch.randn(1, self.in_chans, 224, 224)).squeeze()
        _ = model.train()  # to be safe
        return dummy_output.shape[-1]

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward pass of feature extractor to get feature vector.

        Args:
            img: Input image tensor of shape (B, C, H, W).

        Returns:
            feats: Normalized output tensor of shape (B, d_model).
        """
        # If grayscale, tile the image to 3 channels.
        if img.shape[1] == 1:
            img = img.repeat([1, 3, 1, 1])  # (B, nc=3, H, W)
        # Extract image features
        feats = self.feature_extractor(
            img
        )  # (B, out_dim, 1, 1) if using resnet18 backbone.

        # Reshape feature vectors
        feats = feats.reshape([img.shape[0], -1])  # (B, out_dim)
        # Map feature vectors to output dimension using linear layer.
        feats = self.out_layer(feats)  # (B, d_model)
        # Normalize output feature vectors.
        feats = F.normalize(feats)  # (B, d_model)
        return feats
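
A minimal usage sketch (the batch size and the 128x128 crop size below are illustrative, not part of the API):

```python
import torch
from dreem.models import VisualEncoder

# ResNet-18 backbone via timm, projecting each crop to a 512-d feature vector.
encoder = VisualEncoder(model_name="resnet18", d_model=512, in_chans=3, backend="timm")

crops = torch.randn(8, 3, 128, 128)  # (B, C, H, W) batch of instance crops
feats = encoder(crops)               # (8, 512), L2-normalized
```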

__init__(model_name='resnet18', d_model=512, in_chans=3, backend='timm', **kwargs)

Initialize Visual Encoder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_name` | `str` | Name of the CNN architecture to use (e.g. "resnet18", "resnet50"). | `'resnet18'` |
| `d_model` | `int` | Output embedding dimension. | `512` |
| `in_chans` | `int` | The number of input channels of the image. | `3` |
| `backend` | `str` | Which model backend to use. One of {"timm", "torchvision"}. | `'timm'` |
| `kwargs` | `Optional[Any]` | See `timm.create_model` and `torchvision.models.resnetX` for kwargs. | `{}` |
Source code in dreem/models/visual_encoder.py
def __init__(
    self,
    model_name: str = "resnet18",
    d_model: int = 512,
    in_chans: int = 3,
    backend: str = "timm",
    **kwargs: Optional[Any],
):
    """Initialize Visual Encoder.

    Args:
        model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
        d_model (int): Output embedding dimension.
        in_chans: the number of input channels of the image.
        backend: Which model backend to use. One of {"timm", "torchvision"}
        kwargs: see `timm.create_model` and `torchvision.models.resnetX` for kwargs.
    """
    super().__init__()

    self.model_name = model_name.lower()
    self.d_model = d_model
    self.backend = backend
    if in_chans == 1:
        self.in_chans = 3
    else:
        self.in_chans = in_chans

    self.feature_extractor = self.select_feature_extractor(
        model_name=self.model_name,
        in_chans=self.in_chans,
        backend=self.backend,
        **kwargs,
    )

    self.out_layer = torch.nn.Linear(
        self.encoder_dim(self.feature_extractor), self.d_model
    )

encoder_dim(model)

Compute dummy forward pass of encoder model and get embedding dimension.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | A vision encoder model. | required |

Returns:

| Type | Description |
| --- | --- |
| `int` | The embedding dimension size. |

Source code in dreem/models/visual_encoder.py
def encoder_dim(self, model: torch.nn.Module) -> int:
    """Compute dummy forward pass of encoder model and get embedding dimension.

    Args:
        model: a vision encoder model.

    Returns:
        The embedding dimension size.
    """
    _ = model.eval()
    dummy_output = model(torch.randn(1, self.in_chans, 224, 224)).squeeze()
    _ = model.train()  # to be safe
    return dummy_output.shape[-1]
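
The dimension probe above can be reproduced standalone with `timm` (a sketch; `num_classes=0` strips the classification head so the pooled feature vector is returned, matching how the extractor is built in `select_feature_extractor`):

```python
import timm
import torch

backbone = timm.create_model("resnet18", in_chans=3, num_classes=0)
backbone.eval()
with torch.no_grad():
    out_dim = backbone(torch.randn(1, 3, 224, 224)).squeeze().shape[-1]
# out_dim == 512 for resnet18; this is the in_features of `out_layer`
```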

forward(img)

Forward pass of feature extractor to get feature vector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `img` | `Tensor` | Input image tensor of shape (B, C, H, W). | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `feats` | `Tensor` | Normalized output tensor of shape (B, d_model). |

Source code in dreem/models/visual_encoder.py
def forward(self, img: torch.Tensor) -> torch.Tensor:
    """Forward pass of feature extractor to get feature vector.

    Args:
        img: Input image tensor of shape (B, C, H, W).

    Returns:
        feats: Normalized output tensor of shape (B, d_model).
    """
    # If grayscale, tile the image to 3 channels.
    if img.shape[1] == 1:
        img = img.repeat([1, 3, 1, 1])  # (B, nc=3, H, W)
    # Extract image features
    feats = self.feature_extractor(
        img
    )  # (B, out_dim, 1, 1) if using resnet18 backbone.

    # Reshape feature vectors
    feats = feats.reshape([img.shape[0], -1])  # (B, out_dim)
    # Map feature vectors to output dimension using linear layer.
    feats = self.out_layer(feats)  # (B, d_model)
    # Normalize output feature vectors.
    feats = F.normalize(feats)  # (B, d_model)
    return feats
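
Since single-channel input is tiled to three channels inside `forward`, grayscale crops can be passed without modification (a sketch, reusing the `encoder` from the earlier example):

```python
gray_crops = torch.randn(4, 1, 128, 128)  # (B, 1, H, W) grayscale crops
feats = encoder(gray_crops)               # tiled to 3 channels internally -> (4, 512)
```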

select_feature_extractor(model_name, in_chans, backend, **kwargs)

Select the appropriate feature extractor based on config.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_name` | `str` | Name of the CNN architecture to use (e.g. "resnet18", "resnet50"). | required |
| `in_chans` | `int` | The number of input channels of the image. | required |
| `backend` | `str` | Which model backend to use. One of {"timm", "torchvision"}. | required |
| `kwargs` | `Optional[Any]` | See `timm.create_model` and `torchvision.models.resnetX` for kwargs. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Module` | A CNN encoder based on the config and backend selected. |

Source code in dreem/models/visual_encoder.py
def select_feature_extractor(
    self, model_name: str, in_chans: int, backend: str, **kwargs: Optional[Any]
) -> torch.nn.Module:
    """Select the appropriate feature extractor based on config.

    Args:
        model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
        in_chans: the number of input channels of the image.
        backend: Which model backend to use. One of {"timm", "torchvision"}
        kwargs: see `timm.create_model` and `torchvision.models.resnetX` for kwargs.

    Returns:
        a CNN encoder based on the config and backend selected.
    """
    if "timm" in backend.lower():
        feature_extractor = timm.create_model(
            model_name=self.model_name,
            in_chans=self.in_chans,
            num_classes=0,
            **kwargs,
        )
    elif "torch" in backend.lower():
        if model_name.lower() == "resnet18":
            feature_extractor = torchvision.models.resnet18(**kwargs)

        elif model_name.lower() == "resnet50":
            feature_extractor = torchvision.models.resnet50(**kwargs)

        else:
            raise ValueError(
                f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}"
            )
        feature_extractor = torch.nn.Sequential(
            *list(feature_extractor.children())[:-1]
        )
        input_layer = feature_extractor[0]
        if in_chans != 3:
            feature_extractor[0] = torch.nn.Conv2d(
                in_channels=in_chans,
                out_channels=input_layer.out_channels,
                kernel_size=input_layer.kernel_size,
                stride=input_layer.stride,
                padding=input_layer.padding,
                dilation=input_layer.dilation,
                groups=input_layer.groups,
                bias=input_layer.bias,
                padding_mode=input_layer.padding_mode,
            )

    else:
        raise ValueError(
            f"Only ['timm', 'torch'] backends are available! Found {backend}."
        )
    return feature_extractor
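
With the torchvision backend, a non-RGB channel count swaps out the first convolution as shown above. A sketch with a hypothetical 4-channel input (e.g. RGB plus an extra mask channel):

```python
from dreem.models import VisualEncoder

# ResNet-50 backbone from torchvision with a 4-channel stem conv.
encoder = VisualEncoder(
    model_name="resnet50",
    d_model=512,
    in_chans=4,
    backend="torchvision",
)
```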

dreem.models.Transformer

Bases: Module

Transformer class.

Source code in dreem/models/transformer.py
class Transformer(torch.nn.Module):
    """Transformer class."""

    def __init__(
        self,
        d_model: int = 1024,
        nhead: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dropout: float = 0.1,
        activation: str = "relu",
        return_intermediate_dec: bool = False,
        norm: bool = False,
        num_layers_attn_head: int = 2,
        dropout_attn_head: float = 0.1,
        embedding_meta: dict = None,
        return_embedding: bool = False,
        decoder_self_attn: bool = False,
    ) -> None:
        """Initialize Transformer.

        Args:
            d_model: The number of features in the encoder/decoder inputs.
            nhead: The number of heads in the transformer encoder/decoder.
            num_encoder_layers: The number of encoder-layers in the encoder.
            num_decoder_layers: The number of decoder-layers in the decoder.
            dropout: Dropout value applied to the output of transformer layers.
            activation: Activation function to use.
            return_intermediate_dec: Return intermediate layers from decoder.
            norm: If True, normalize output of encoder and decoder.
            num_layers_attn_head: The number of layers in the attention head.
            dropout_attn_head: Dropout value for the attention_head.
            embedding_meta: Metadata for positional embeddings. See below.
            return_embedding: Whether to return the positional embeddings
            decoder_self_attn: If True, use decoder self attention.

                More details on `embedding_meta`:
                    By default this will be an empty dict and indicate
                    that no positional embeddings should be used. To use the positional embeddings
                    pass in a dictionary containing a "pos" and "temp" key with subdictionaries for the correct parameters, i.e.:
                    {"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True},
                    "temp": {'mode': 'learned', 'emb_num': 16}}. (see `dreem.models.embeddings.Embedding.EMB_TYPES`
                    and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters).
        """
        super().__init__()

        self.d_model = dim_feedforward = feature_dim_attn_head = d_model

        self.embedding_meta = embedding_meta
        self.return_embedding = return_embedding

        self.pos_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
        self.temp_emb = Embedding(emb_type="off", mode="off", features=self.d_model)

        if self.embedding_meta:
            if "pos" in self.embedding_meta:
                pos_emb_cfg = self.embedding_meta["pos"]
                if pos_emb_cfg:
                    self.pos_emb = Embedding(
                        emb_type="pos", features=self.d_model, **pos_emb_cfg
                    )
            if "temp" in self.embedding_meta:
                temp_emb_cfg = self.embedding_meta["temp"]
                if temp_emb_cfg:
                    self.temp_emb = Embedding(
                        emb_type="temp", features=self.d_model, **temp_emb_cfg
                    )

        # Transformer Encoder
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, norm
        )

        encoder_norm = nn.LayerNorm(d_model) if (norm) else None

        self.encoder = TransformerEncoder(
            encoder_layer, num_encoder_layers, encoder_norm
        )

        # Transformer Decoder
        decoder_layer = TransformerDecoderLayer(
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            norm,
            decoder_self_attn,
        )

        decoder_norm = nn.LayerNorm(d_model) if (norm) else None

        self.decoder = TransformerDecoder(
            decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm
        )

        # Transformer attention head
        self.attn_head = ATTWeightHead(
            feature_dim=feature_dim_attn_head,
            num_layers=num_layers_attn_head,
            dropout=dropout_attn_head,
        )

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize model weights from xavier distribution."""
        for p in self.parameters():
            if not torch.nn.parameter.is_lazy(p) and p.dim() > 1:
                try:
                    nn.init.xavier_uniform_(p)
                except ValueError as e:
                    print(f"Failed Trying to initialize {p}")
                    raise (e)

    def forward(
        self,
        ref_instances: list["dreem.io.Instance"],
        query_instances: list["dreem.io.Instance"] = None,
    ) -> list[AssociationMatrix]:
        """Execute a forward pass through the transformer and attention head.

        Args:
            ref_instances: A list of instance objects (See `dreem.io.Instance` for more info.)
            query_instances: A set of instances to be used as decoder queries.

        Returns:
            asso_output: A list of torch.Tensors of shape (L, n_query, total_instances) where:
                L: number of decoder blocks
                n_query: number of instances in current query/frame
                total_instances: number of instances in window
        """
        ref_features = torch.cat(
            [instance.features for instance in ref_instances], dim=0
        ).unsqueeze(0)

        # window_length = len(frames)
        # instances_per_frame = [frame.num_detected for frame in frames]
        total_instances = len(ref_instances)
        embed_dim = ref_features.shape[-1]
        # print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}')
        ref_boxes = get_boxes(ref_instances)  # total_instances, 4
        ref_boxes = torch.nan_to_num(ref_boxes, -1.0)
        ref_times, query_times = get_times(ref_instances, query_instances)

        window_length = len(ref_times.unique())

        ref_temp_emb = self.temp_emb(ref_times / window_length)

        ref_pos_emb = self.pos_emb(ref_boxes)

        if self.return_embedding:
            for i, instance in enumerate(ref_instances):
                instance.add_embedding("pos", ref_pos_emb[i])
                instance.add_embedding("temp", ref_temp_emb[i])

        ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0

        ref_emb = ref_emb.view(1, total_instances, embed_dim)

        ref_emb = ref_emb.permute(1, 0, 2)  # (total_instances, batch_size, embed_dim)

        batch_size, total_instances, embed_dim = ref_features.shape

        ref_features = ref_features.permute(
            1, 0, 2
        )  # (total_instances, batch_size, embed_dim)

        encoder_queries = ref_features

        encoder_features = self.encoder(
            encoder_queries, pos_emb=ref_emb
        )  # (total_instances, batch_size, embed_dim)

        n_query = total_instances

        query_features = ref_features
        query_pos_emb = ref_pos_emb
        query_temp_emb = ref_temp_emb
        query_emb = ref_emb

        if query_instances is not None:
            n_query = len(query_instances)

            query_features = torch.cat(
                [instance.features for instance in query_instances], dim=0
            ).unsqueeze(0)

            query_features = query_features.permute(
                1, 0, 2
            )  # (n_query, batch_size, embed_dim)

            query_boxes = get_boxes(query_instances)
            query_boxes = torch.nan_to_num(query_boxes, -1.0)
            query_temp_emb = self.temp_emb(query_times / window_length)

            query_pos_emb = self.pos_emb(query_boxes)

            query_emb = (query_pos_emb + query_temp_emb) / 2.0
            query_emb = query_emb.view(1, n_query, embed_dim)
            query_emb = query_emb.permute(1, 0, 2)  # (n_query, batch_size, embed_dim)

        else:
            query_instances = ref_instances

        if self.return_embedding:
            for i, instance in enumerate(query_instances):
                instance.add_embedding("pos", query_pos_emb[i])
                instance.add_embedding("temp", query_temp_emb[i])

        decoder_features = self.decoder(
            query_features,
            encoder_features,
            ref_pos_emb=ref_emb,
            query_pos_emb=query_emb,
        )  # (L, n_query, batch_size, embed_dim)

        decoder_features = decoder_features.transpose(
            1, 2
        )  # (L, batch_size, n_query, embed_dim)
        encoder_features = encoder_features.permute(1, 0, 2).view(
            batch_size, total_instances, embed_dim
        )  # (batch_size, total_instances, embed_dim)

        asso_output = []
        for frame_features in decoder_features:
            asso_matrix = self.attn_head(frame_features, encoder_features).view(
                n_query, total_instances
            )
            asso_matrix = AssociationMatrix(asso_matrix, ref_instances, query_instances)

            asso_output.append(asso_matrix)

        # (L=1, n_query, total_instances)
        return asso_output
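
`d_model` must match the feature dimension fed into the transformer (for example, a `VisualEncoder` constructed with `d_model=512`). A minimal construction sketch with embeddings left off, which is the behavior when `embedding_meta` is None:

```python
from dreem.models import Transformer

transformer = Transformer(
    d_model=512,            # must equal the upstream feature dimension
    nhead=8,
    num_encoder_layers=1,
    num_decoder_layers=1,
)
```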

__init__(d_model=1024, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dropout=0.1, activation='relu', return_intermediate_dec=False, norm=False, num_layers_attn_head=2, dropout_attn_head=0.1, embedding_meta=None, return_embedding=False, decoder_self_attn=False)

Initialize Transformer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `d_model` | `int` | The number of features in the encoder/decoder inputs. | `1024` |
| `nhead` | `int` | The number of heads in the transformer encoder/decoder. | `8` |
| `num_encoder_layers` | `int` | The number of encoder layers in the encoder. | `6` |
| `num_decoder_layers` | `int` | The number of decoder layers in the decoder. | `6` |
| `dropout` | `float` | Dropout value applied to the output of transformer layers. | `0.1` |
| `activation` | `str` | Activation function to use. | `'relu'` |
| `return_intermediate_dec` | `bool` | Return intermediate layers from decoder. | `False` |
| `norm` | `bool` | If True, normalize output of encoder and decoder. | `False` |
| `num_layers_attn_head` | `int` | The number of layers in the attention head. | `2` |
| `dropout_attn_head` | `float` | Dropout value for the attention head. | `0.1` |
| `embedding_meta` | `dict` | Metadata for positional embeddings. See below. | `None` |
| `return_embedding` | `bool` | Whether to return the positional embeddings. | `False` |
| `decoder_self_attn` | `bool` | If True, use decoder self attention. | `False` |

More details on `embedding_meta`: by default this is None, indicating that no positional embeddings should be used. To use positional embeddings, pass in a dictionary containing "pos" and "temp" keys with subdictionaries for the correct parameters, e.g. `{"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True}, "temp": {'mode': 'learned', 'emb_num': 16}}` (see `dreem.models.embeddings.Embedding.EMB_TYPES` and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters).
Source code in dreem/models/transformer.py
def __init__(
    self,
    d_model: int = 1024,
    nhead: int = 8,
    num_encoder_layers: int = 6,
    num_decoder_layers: int = 6,
    dropout: float = 0.1,
    activation: str = "relu",
    return_intermediate_dec: bool = False,
    norm: bool = False,
    num_layers_attn_head: int = 2,
    dropout_attn_head: float = 0.1,
    embedding_meta: dict = None,
    return_embedding: bool = False,
    decoder_self_attn: bool = False,
) -> None:
    """Initialize Transformer.

    Args:
        d_model: The number of features in the encoder/decoder inputs.
        nhead: The number of heads in the transformer encoder/decoder.
        num_encoder_layers: The number of encoder-layers in the encoder.
        num_decoder_layers: The number of decoder-layers in the decoder.
        dropout: Dropout value applied to the output of transformer layers.
        activation: Activation function to use.
        return_intermediate_dec: Return intermediate layers from decoder.
        norm: If True, normalize output of encoder and decoder.
        num_layers_attn_head: The number of layers in the attention head.
        dropout_attn_head: Dropout value for the attention_head.
        embedding_meta: Metadata for positional embeddings. See below.
        return_embedding: Whether to return the positional embeddings
        decoder_self_attn: If True, use decoder self attention.

            More details on `embedding_meta`:
                By default this will be an empty dict and indicate
                that no positional embeddings should be used. To use the positional embeddings
                pass in a dictionary containing a "pos" and "temp" key with subdictionaries for the correct parameters, i.e.:
                {"pos": {'mode': 'learned', 'emb_num': 16, 'over_boxes': True},
                "temp": {'mode': 'learned', 'emb_num': 16}}. (see `dreem.models.embeddings.Embedding.EMB_TYPES`
                and `dreem.models.embeddings.Embedding.EMB_MODES` for embedding parameters).
    """
    super().__init__()

    self.d_model = dim_feedforward = feature_dim_attn_head = d_model

    self.embedding_meta = embedding_meta
    self.return_embedding = return_embedding

    self.pos_emb = Embedding(emb_type="off", mode="off", features=self.d_model)
    self.temp_emb = Embedding(emb_type="off", mode="off", features=self.d_model)

    if self.embedding_meta:
        if "pos" in self.embedding_meta:
            pos_emb_cfg = self.embedding_meta["pos"]
            if pos_emb_cfg:
                self.pos_emb = Embedding(
                    emb_type="pos", features=self.d_model, **pos_emb_cfg
                )
        if "temp" in self.embedding_meta:
            temp_emb_cfg = self.embedding_meta["temp"]
            if temp_emb_cfg:
                self.temp_emb = Embedding(
                    emb_type="temp", features=self.d_model, **temp_emb_cfg
                )

    # Transformer Encoder
    encoder_layer = TransformerEncoderLayer(
        d_model, nhead, dim_feedforward, dropout, activation, norm
    )

    encoder_norm = nn.LayerNorm(d_model) if (norm) else None

    self.encoder = TransformerEncoder(
        encoder_layer, num_encoder_layers, encoder_norm
    )

    # Transformer Decoder
    decoder_layer = TransformerDecoderLayer(
        d_model,
        nhead,
        dim_feedforward,
        dropout,
        activation,
        norm,
        decoder_self_attn,
    )

    decoder_norm = nn.LayerNorm(d_model) if (norm) else None

    self.decoder = TransformerDecoder(
        decoder_layer, num_decoder_layers, return_intermediate_dec, decoder_norm
    )

    # Transformer attention head
    self.attn_head = ATTWeightHead(
        feature_dim=feature_dim_attn_head,
        num_layers=num_layers_attn_head,
        dropout=dropout_attn_head,
    )

    self._reset_parameters()
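
A sketch of the `embedding_meta` layout described in the docstring, turning on learned positional and temporal embeddings (the values are illustrative; see `Embedding.EMB_TYPES` and `Embedding.EMB_MODES` for the accepted keys):

```python
from dreem.models import Transformer

embedding_meta = {
    "pos": {"mode": "learned", "emb_num": 16, "over_boxes": True},
    "temp": {"mode": "learned", "emb_num": 16},
}

transformer = Transformer(d_model=512, embedding_meta=embedding_meta)
```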

forward(ref_instances, query_instances=None)

Execute a forward pass through the transformer and attention head.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ref_instances` | `list[Instance]` | A list of instance objects (see `dreem.io.Instance` for more info). | required |
| `query_instances` | `list[Instance]` | A set of instances to be used as decoder queries. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `asso_output` | `list[AssociationMatrix]` | A list of torch.Tensors of shape (L, n_query, total_instances), where L is the number of decoder blocks, n_query is the number of instances in the current query/frame, and total_instances is the number of instances in the window. |

Source code in dreem/models/transformer.py
def forward(
    self,
    ref_instances: list["dreem.io.Instance"],
    query_instances: list["dreem.io.Instance"] = None,
) -> list[AssociationMatrix]:
    """Execute a forward pass through the transformer and attention head.

    Args:
        ref_instances: A list of instance objects (See `dreem.io.Instance` for more info.)
        query_instances: A set of instances to be used as decoder queries.

    Returns:
        asso_output: A list of torch.Tensors of shape (L, n_query, total_instances) where:
            L: number of decoder blocks
            n_query: number of instances in current query/frame
            total_instances: number of instances in window
    """
    ref_features = torch.cat(
        [instance.features for instance in ref_instances], dim=0
    ).unsqueeze(0)

    # window_length = len(frames)
    # instances_per_frame = [frame.num_detected for frame in frames]
    total_instances = len(ref_instances)
    embed_dim = ref_features.shape[-1]
    # print(f'T: {window_length}; N: {total_instances}; N_t: {instances_per_frame} n_reid: {reid_features.shape}')
    ref_boxes = get_boxes(ref_instances)  # total_instances, 4
    ref_boxes = torch.nan_to_num(ref_boxes, -1.0)
    ref_times, query_times = get_times(ref_instances, query_instances)

    window_length = len(ref_times.unique())

    ref_temp_emb = self.temp_emb(ref_times / window_length)

    ref_pos_emb = self.pos_emb(ref_boxes)

    if self.return_embedding:
        for i, instance in enumerate(ref_instances):
            instance.add_embedding("pos", ref_pos_emb[i])
            instance.add_embedding("temp", ref_temp_emb[i])

    ref_emb = (ref_pos_emb + ref_temp_emb) / 2.0

    ref_emb = ref_emb.view(1, total_instances, embed_dim)

    ref_emb = ref_emb.permute(1, 0, 2)  # (total_instances, batch_size, embed_dim)

    batch_size, total_instances, embed_dim = ref_features.shape

    ref_features = ref_features.permute(
        1, 0, 2
    )  # (total_instances, batch_size, embed_dim)

    encoder_queries = ref_features

    encoder_features = self.encoder(
        encoder_queries, pos_emb=ref_emb
    )  # (total_instances, batch_size, embed_dim)

    n_query = total_instances

    query_features = ref_features
    query_pos_emb = ref_pos_emb
    query_temp_emb = ref_temp_emb
    query_emb = ref_emb

    if query_instances is not None:
        n_query = len(query_instances)

        query_features = torch.cat(
            [instance.features for instance in query_instances], dim=0
        ).unsqueeze(0)

        query_features = query_features.permute(
            1, 0, 2
        )  # (n_query, batch_size, embed_dim)

        query_boxes = get_boxes(query_instances)
        query_boxes = torch.nan_to_num(query_boxes, -1.0)
        query_temp_emb = self.temp_emb(query_times / window_length)

        query_pos_emb = self.pos_emb(query_boxes)

        query_emb = (query_pos_emb + query_temp_emb) / 2.0
        query_emb = query_emb.view(1, n_query, embed_dim)
        query_emb = query_emb.permute(1, 0, 2)  # (n_query, batch_size, embed_dim)

    else:
        query_instances = ref_instances

    if self.return_embedding:
        for i, instance in enumerate(query_instances):
            instance.add_embedding("pos", query_pos_emb[i])
            instance.add_embedding("temp", query_temp_emb[i])

    decoder_features = self.decoder(
        query_features,
        encoder_features,
        ref_pos_emb=ref_emb,
        query_pos_emb=query_emb,
    )  # (L, n_query, batch_size, embed_dim)

    decoder_features = decoder_features.transpose(
        1, 2
    )  # (L, batch_size, n_query, embed_dim)
    encoder_features = encoder_features.permute(1, 0, 2).view(
        batch_size, total_instances, embed_dim
    )  # (batch_size, total_instances, embed_dim)

    asso_output = []
    for frame_features in decoder_features:
        asso_matrix = self.attn_head(frame_features, encoder_features).view(
            n_query, total_instances
        )
        asso_matrix = AssociationMatrix(asso_matrix, ref_instances, query_instances)

        asso_output.append(asso_matrix)

    # (L=1, n_query, total_instances)
    return asso_output
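
The permutes in `forward` put everything in the sequence-first layout `(seq_len, batch_size, embed_dim)` that `torch.nn.MultiheadAttention` uses with its default `batch_first=False`. A standalone illustration of that bookkeeping with dummy tensors:

```python
import torch

batch_size, total_instances, embed_dim = 1, 20, 512
ref_features = torch.randn(batch_size, total_instances, embed_dim)  # (B, N, D)

encoder_queries = ref_features.permute(1, 0, 2)  # (N, B, D), sequence-first
assert encoder_queries.shape == (total_instances, batch_size, embed_dim)
```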

dreem.models.transformer.TransformerEncoder

Bases: Module

A transformer encoder block composed of encoder layers.

Source code in dreem/models/transformer.py
class TransformerEncoder(nn.Module):
    """A transformer encoder block composed of encoder layers."""

    def __init__(
        self,
        encoder_layer: TransformerEncoderLayer,
        num_layers: int,
        norm: nn.Module = None,
    ) -> None:
        """Initialize transformer encoder.

        Args:
            encoder_layer: An instance of the TransformerEncoderLayer.
            num_layers: The number of encoder layers to be stacked.
            norm: The normalization layer to be applied.
        """
        super().__init__()

        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm if norm is not None else nn.Identity()

    def forward(
        self, queries: torch.Tensor, pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        """Execute a forward pass of encoder layer.

        Args:
            queries: The input tensor of shape (n_query, batch_size, embed_dim).
            pos_emb: The positional embedding tensor of shape (n_query, embed_dim).

        Returns:
            The output tensor of shape (n_query, batch_size, embed_dim).
        """
        for layer in self.layers:
            queries = layer(queries, pos_emb=pos_emb)

        encoder_features = self.norm(queries)
        return encoder_features

__init__(encoder_layer, num_layers, norm=None)

Initialize transformer encoder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `encoder_layer` | `TransformerEncoderLayer` | An instance of the TransformerEncoderLayer. | required |
| `num_layers` | `int` | The number of encoder layers to be stacked. | required |
| `norm` | `Module` | The normalization layer to be applied. | `None` |
Source code in dreem/models/transformer.py
def __init__(
    self,
    encoder_layer: TransformerEncoderLayer,
    num_layers: int,
    norm: nn.Module = None,
) -> None:
    """Initialize transformer encoder.

    Args:
        encoder_layer: An instance of the TransformerEncoderLayer.
        num_layers: The number of encoder layers to be stacked.
        norm: The normalization layer to be applied.
    """
    super().__init__()

    self.layers = _get_clones(encoder_layer, num_layers)
    self.num_layers = num_layers
    self.norm = norm if norm is not None else nn.Identity()

forward(queries, pos_emb=None)

Execute a forward pass of encoder layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `queries` | `Tensor` | The input tensor of shape (n_query, batch_size, embed_dim). | required |
| `pos_emb` | `Tensor` | The positional embedding tensor of shape (n_query, embed_dim). | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The output tensor of shape (n_query, batch_size, embed_dim). |

Source code in dreem/models/transformer.py
def forward(
    self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
    """Execute a forward pass of encoder layer.

    Args:
        queries: The input tensor of shape (n_query, batch_size, embed_dim).
        pos_emb: The positional embedding tensor of shape (n_query, embed_dim).

    Returns:
        The output tensor of shape (n_query, batch_size, embed_dim).
    """
    for layer in self.layers:
        queries = layer(queries, pos_emb=pos_emb)

    encoder_features = self.norm(queries)
    return encoder_features
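
A minimal sketch stacking two encoder layers (tensors follow the `(n_query, batch_size, embed_dim)` convention documented above; the import path matches this page's headings):

```python
import torch
from torch import nn
from dreem.models.transformer import TransformerEncoder, TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=512, dropout=0.1)
encoder = TransformerEncoder(encoder_layer=layer, num_layers=2, norm=nn.LayerNorm(512))

queries = torch.randn(20, 1, 512)  # (n_query, batch_size, embed_dim)
out = encoder(queries)             # same shape; pos_emb defaults to None
```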

dreem.models.transformer.TransformerEncoderLayer

Bases: Module

A single transformer encoder layer.

Source code in dreem/models/transformer.py
class TransformerEncoderLayer(nn.Module):
    """A single transformer encoder layer."""

    def __init__(
        self,
        d_model: int = 1024,
        nhead: int = 6,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        norm: bool = False,
    ) -> None:
        """Initialize a transformer encoder layer.

        Args:
            d_model: The number of features in the encoder inputs.
            nhead: The number of heads for the encoder.
            dim_feedforward: Dimension of the feedforward layers of encoder.
            dropout: Dropout value applied to the output of encoder.
            activation: Activation function to use.
            norm: If True, normalize output of encoder.
        """
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(
        self, queries: torch.Tensor, pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        """Execute a forward pass of the encoder layer.

        Args:
            queries: Input sequence for encoder (n_query, batch_size, embed_dim).
            pos_emb: Position embedding; if provided, it is added to the queries.

        Returns:
            The output tensor of shape (n_query, batch_size, embed_dim).
        """
        if pos_emb is None:
            pos_emb = torch.zeros_like(queries)

        queries = queries + pos_emb

        # q = k = src

        attn_features = self.self_attn(
            query=queries,
            key=queries,
            value=queries,
        )[0]

        queries = queries + self.dropout1(attn_features)
        queries = self.norm1(queries)
        projection = self.linear2(self.dropout(self.activation(self.linear1(queries))))
        queries = queries + self.dropout2(projection)
        encoder_features = self.norm2(queries)

        return encoder_features
__init__(d_model=1024, nhead=6, dim_feedforward=1024, dropout=0.1, activation='relu', norm=False)

Initialize a transformer encoder layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `d_model` | `int` | The number of features in the encoder inputs. | `1024` |
| `nhead` | `int` | The number of heads for the encoder. | `6` |
| `dim_feedforward` | `int` | Dimension of the feedforward layers of encoder. | `1024` |
| `dropout` | `float` | Dropout value applied to the output of encoder. | `0.1` |
| `activation` | `str` | Activation function to use. | `'relu'` |
| `norm` | `bool` | If True, normalize output of encoder. | `False` |
Source code in dreem/models/transformer.py
def __init__(
    self,
    d_model: int = 1024,
    nhead: int = 6,
    dim_feedforward: int = 1024,
    dropout: float = 0.1,
    activation: str = "relu",
    norm: bool = False,
) -> None:
    """Initialize a transformer encoder layer.

    Args:
        d_model: The number of features in the encoder inputs.
        nhead: The number of heads for the encoder.
        dim_feedforward: Dimension of the feedforward layers of encoder.
        dropout: Dropout value applied to the output of encoder.
        activation: Activation function to use.
        norm: If True, normalize output of encoder.
    """
    super().__init__()
    self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
    self.linear1 = nn.Linear(d_model, dim_feedforward)
    self.dropout = nn.Dropout(dropout)
    self.linear2 = nn.Linear(dim_feedforward, d_model)

    self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
    self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()

    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)

    self.activation = _get_activation_fn(activation)
forward(queries, pos_emb=None)

Execute a forward pass of the encoder layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `queries` | `Tensor` | Input sequence for encoder (n_query, batch_size, embed_dim). | required |
| `pos_emb` | `Tensor` | Position embedding; if provided, it is added to the queries. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The output tensor of shape (n_query, batch_size, embed_dim). |

Source code in dreem/models/transformer.py
def forward(
    self, queries: torch.Tensor, pos_emb: torch.Tensor = None
) -> torch.Tensor:
    """Execute a forward pass of the encoder layer.

    Args:
        queries: Input sequence for encoder (n_query, batch_size, embed_dim).
        pos_emb: Position embedding; if provided, it is added to the queries.

    Returns:
        The output tensor of shape (n_query, batch_size, embed_dim).
    """
    if pos_emb is None:
        pos_emb = torch.zeros_like(queries)

    queries = queries + pos_emb

    # q = k = src

    attn_features = self.self_attn(
        query=queries,
        key=queries,
        value=queries,
    )[0]

    queries = queries + self.dropout1(attn_features)
    queries = self.norm1(queries)
    projection = self.linear2(self.dropout(self.activation(self.linear1(queries))))
    queries = queries + self.dropout2(projection)
    encoder_features = self.norm2(queries)

    return encoder_features
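
A single-layer sketch showing a positional embedding being added to the queries before self-attention (the random `pos_emb` is shape-matched to `queries`, as in `Transformer.forward`):

```python
import torch
from dreem.models.transformer import TransformerEncoderLayer

layer = TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=512)
queries = torch.randn(20, 1, 512)      # (n_query, batch_size, embed_dim)
pos_emb = torch.randn(20, 1, 512)      # added to queries inside forward
out = layer(queries, pos_emb=pos_emb)  # (20, 1, 512)
```
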
dreem.models.Embedding

Bases: Module

Class that wraps around different embedding types.

Used for both learned and fixed embeddings.

Source code in dreem/models/embedding.py
class Embedding(torch.nn.Module):
    """Class that wraps around different embedding types.

    Used for both learned and fixed embeddings.
    """

    EMB_TYPES = {
        "temp": {},
        "pos": {"over_boxes"},
        "off": {},
        None: {},
    }  # dict of valid args:keyword params
    EMB_MODES = {
        "fixed": {"temperature", "scale", "normalize"},
        "learned": {"emb_num"},
        "off": {},
    }  # dict of valid args:keyword params

    def __init__(
        self,
        emb_type: str,
        mode: str,
        features: int,
        n_points: Optional[int] = 1,
        emb_num: Optional[int] = 16,
        over_boxes: Optional[bool] = True,
        temperature: Optional[int] = 10000,
        normalize: Optional[bool] = False,
        scale: Optional[float] = None,
        mlp_cfg: dict = None,
    ):
        """Initialize embeddings.

        Args:
            emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`
            mode: The mode or function used to map positions to vector embeddings.
                  Must be one of `{"fixed", "learned", "off"}`
            features: The embedding dimensions. Must match the dimension of the
                      input vectors for the transformer model.
            n_points: the number of points that will be embedded.
            emb_num: the number of embeddings in the `self.lookup` table (Only used in learned embeddings).
            over_boxes: Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh).
            temperature: the temperature constant to be used when computing the sinusoidal position embedding
            normalize: whether or not to normalize the positions (Only used in fixed embeddings).
            scale: factor by which to scale the positions after normalizing (Only used in fixed embeddings).
            mlp_cfg: A dictionary of mlp hyperparameters for projecting embedding to correct space.
                    Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3}
        """
        self._check_init_args(emb_type, mode)

        super().__init__()

        self.emb_type = emb_type
        self.mode = mode
        self.features = features
        self.emb_num = emb_num
        self.over_boxes = over_boxes
        self.temperature = temperature
        self.normalize = normalize
        self.scale = scale
        self.n_points = n_points

        if self.normalize and self.scale is None:
            self.scale = 2 * math.pi

        if self.emb_type == "pos" and mlp_cfg is not None and mlp_cfg["num_layers"] > 0:
            if self.mode == "fixed":
                self.mlp = MLP(
                    input_dim=n_points * self.features,
                    output_dim=self.features,
                    **mlp_cfg,
                )
            else:
                in_dim = (self.features // (4 * n_points)) * (4 * n_points)
                self.mlp = MLP(
                    input_dim=in_dim,
                    output_dim=self.features,
                    **mlp_cfg,
                )
        else:
            self.mlp = torch.nn.Identity()

        self._emb_func = lambda tensor: torch.zeros(
            (tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device
        )  # turn off embedding by returning zeros

        self.lookup = None

        if self.mode == "learned":
            if self.emb_type == "pos":
                self.lookup = torch.nn.Embedding(
                    self.emb_num * 4 * self.n_points, self.features // (4 * n_points)
                )
                self._emb_func = self._learned_pos_embedding
            elif self.emb_type == "temp":
                self.lookup = torch.nn.Embedding(self.emb_num, self.features)
                self._emb_func = self._learned_temp_embedding

        elif self.mode == "fixed":
            if self.emb_type == "pos":
                self._emb_func = self._sine_box_embedding
            elif self.emb_type == "temp":
                self._emb_func = self._sine_temp_embedding

    def _check_init_args(self, emb_type: str, mode: str):
        """Check whether the correct arguments were passed to initialization.

        Args:
            emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`
            mode: The mode or function used to map positions to vector embeddings.
                Must be one of `{"fixed", "learned", "off"}`

        Raises:
            ValueError: if an invalid `emb_type` or `mode` string is passed.
        """
        if emb_type.lower() not in self.EMB_TYPES:
            raise ValueError(
                f"Embedding `emb_type` must be one of {self.EMB_TYPES} not {emb_type}"
            )

        if mode.lower() not in self.EMB_MODES:
            raise ValueError(
                f"Embedding `mode` must be one of {self.EMB_MODES} not {mode}"
            )

    def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
        """Get the sequence positional embeddings.

        Args:
            seq_positions:
                * An (`N`, 1) tensor where seq_positions[i] represents the temporal position of instance_i in the sequence.
                * An (`N`, n_anchors x 4) tensor where seq_positions[i, j, :] represents the [y1, x1, y2, x2] spatial locations of jth point of instance_i in the sequence.

        Returns:
            An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding.
        """
        emb = self._emb_func(seq_positions)

        if emb.shape[-1] != self.features:
            raise RuntimeError(
                (
                    f"Output embedding dimension is {emb.shape[-1]} but requested {self.features} dimensions! \n"
                    f"hint: Try turning the MLP on by passing `mlp_cfg` to the constructor to project to the correct embedding dimensions."
                )
            )
        return emb

    def _torch_int_div(
        self, tensor1: torch.Tensor, tensor2: torch.Tensor
    ) -> torch.Tensor:
        """Perform integer division of two tensors.

        Args:
            tensor1: dividend tensor.
            tensor2: divisor tensor.

        Returns:
            torch.Tensor, resulting tensor.
        """
        return torch.div(tensor1, tensor2, rounding_mode="floor")

    def _sine_box_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
        """Compute sine positional embeddings for boxes using given parameters.

         Args:
             boxes: the input boxes of shape N, n_anchors, 4 or B, N, n_anchors, 4
                    where the last dimension is the bbox coords in [y1, x1, y2, x2].
                    (Note currently `B=batch_size=1`).

        Returns:
             torch.Tensor, the sine positional embeddings
             (embedding[:, 4i] = sin(x)
              embedding[:, 4i+1] = cos(x)
              embedding[:, 4i+2] = sin(y)
              embedding[:, 4i+3] = cos(y)
              )
        """
        if self.scale is not None and self.normalize is False:
            raise ValueError("normalize should be True if scale is passed")

        if len(boxes.size()) == 3:
            boxes = boxes.unsqueeze(0)

        if self.normalize:
            boxes = boxes / (boxes[:, :, -1:] + 1e-6) * self.scale

        dim_t = torch.arange(self.features // 4, dtype=torch.float32)

        dim_t = self.temperature ** (
            2 * self._torch_int_div(dim_t, 2) / (self.features // 4)
        )

        # (b, n_t, n_anchors, 4, D//4)
        pos_emb = boxes[:, :, :, :, None] / dim_t.to(boxes.device)

        pos_emb = torch.stack(
            (pos_emb[:, :, :, :, 0::2].sin(), pos_emb[:, :, :, :, 1::2].cos()), dim=4
        )
        pos_emb = pos_emb.flatten(2).squeeze(0)  # (N_t, n_anchors * D)

        pos_emb = self.mlp(pos_emb)

        pos_emb = pos_emb.view(boxes.shape[1], self.features)

        return pos_emb

    def _sine_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
        """Compute fixed sine temporal embeddings.

        Args:
            times: the input times of shape (N,) or (N,1) where N = (sum(instances_per_frame))
            which is the frame index of the instance relative
            to the batch size
            (e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2,..., B, B, ...B])`).

        Returns:
            an n_instances x D embedding representing the temporal embedding.
        """
        T = times.int().max().item() + 1
        d = self.features
        n = self.temperature

        positions = torch.arange(0, T).unsqueeze(1)
        temp_lookup = torch.zeros(T, d, device=times.device)

        denominators = torch.pow(
            n, 2 * torch.arange(0, d // 2) / d
        )  # 10000^(2i/d_model), i is the index of embedding
        temp_lookup[:, 0::2] = torch.sin(
            positions / denominators
        )  # sin(pos/10000^(2i/d_model))
        temp_lookup[:, 1::2] = torch.cos(
            positions / denominators
        )  # cos(pos/10000^(2i/d_model))

        temp_emb = temp_lookup[times.int()]
        return temp_emb  # .view(len(times), self.features)

    def _learned_pos_embedding(self, boxes: torch.Tensor) -> torch.Tensor:
        """Compute learned positional embeddings for boxes using given parameters.

        Args:
            boxes: the input boxes of shape N x 4 or B x N x 4
                   where the last dimension is the bbox coords in [y1, x1, y2, x2].
                   (Note currently `B=batch_size=1`).

        Returns:
            torch.Tensor, the learned positional embeddings.
        """
        pos_lookup = self.lookup

        N, n_anchors, _ = boxes.shape
        boxes = boxes.view(N, n_anchors, 4)

        if self.over_boxes:
            xywh = boxes
        else:
            xywh = torch.cat(
                [
                    (boxes[:, :, 2:] + boxes[:, :, :2]) / 2,
                    (boxes[:, :, 2:] - boxes[:, :, :2]),
                ],
                dim=1,
            )

        left_ind, right_ind, left_weight, right_weight = self._compute_weights(xywh)
        f = pos_lookup.weight.shape[1]  # self.features // 4

        try:
            pos_emb_table = pos_lookup.weight.view(
                self.emb_num, n_anchors, 4, f
            )  # T x 4 x (D * 4)
        except RuntimeError as e:
            print(f"Hint: `n_points` ({self.n_points}) may be set incorrectly!")
            raise (e)

        left_emb = pos_emb_table.gather(
            0,
            left_ind[:, :, :, None].to(pos_emb_table.device).expand(N, n_anchors, 4, f),
        )  # N x 4 x d
        right_emb = pos_emb_table.gather(
            0,
            right_ind[:, :, :, None]
            .to(pos_emb_table.device)
            .expand(N, n_anchors, 4, f),
        )  # N x 4 x d
        pos_emb = left_weight[:, :, :, None] * right_emb.to(
            left_weight.device
        ) + right_weight[:, :, :, None] * left_emb.to(right_weight.device)

        pos_emb = pos_emb.flatten(1)
        pos_emb = self.mlp(pos_emb)

        return pos_emb.view(N, self.features)

    def _learned_temp_embedding(self, times: torch.Tensor) -> torch.Tensor:
        """Compute learned temporal embeddings for times using given parameters.

        Args:
            times: the input times of shape (N,) or (N,1) where N = (sum(instances_per_frame))
            which is the frame index of the instance relative
            to the batch size
            (e.g. `torch.tensor([0, 0, ..., 0, 1, 1, ..., 1, 2, 2, ..., 2,..., B, B, ...B])`).

        Returns:
            torch.Tensor, the learned temporal embeddings.
        """
        temp_lookup = self.lookup
        N = times.shape[0]

        left_ind, right_ind, left_weight, right_weight = self._compute_weights(times)

        left_emb = temp_lookup.weight[
            left_ind.to(temp_lookup.weight.device)
        ]  # T x D --> N x D
        right_emb = temp_lookup.weight[right_ind.to(temp_lookup.weight.device)]

        temp_emb = left_weight[:, None] * right_emb.to(
            left_weight.device
        ) + right_weight[:, None] * left_emb.to(right_weight.device)

        return temp_emb.view(N, self.features)

    def _compute_weights(self, data: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Compute left and right learned embedding weights.

        Args:
            data: the input data (e.g boxes or times).

        Returns:
            A torch.Tensor for each of the left/right indices and weights, respectively
        """
        data = data * self.emb_num

        left_ind = data.clamp(min=0, max=self.emb_num - 1).long()  # N x 4
        right_ind = (left_ind + 1).clamp(min=0, max=self.emb_num - 1).long()  # N x 4

        left_weight = data - left_ind.float()  # N x 4

        right_weight = 1.0 - left_weight

        return left_ind, right_ind, left_weight, right_weight
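
Two fixed-mode sketches: a sinusoidal temporal embedding over per-instance frame indices, and a sinusoidal box embedding over `(N, n_anchors, 4)` boxes (values are arbitrary; `features=512` is used so the flattened box embedding matches the requested dimension):

```python
import torch
from dreem.models import Embedding

temp_emb = Embedding(emb_type="temp", mode="fixed", features=512)
times = torch.tensor([0.0, 0.0, 1.0, 1.0, 2.0])  # frame index per instance
print(temp_emb(times).shape)   # torch.Size([5, 512])

pos_emb = Embedding(emb_type="pos", mode="fixed", features=512)
boxes = torch.rand(5, 1, 4)    # (N, n_anchors, 4) boxes as [y1, x1, y2, x2]
print(pos_emb(boxes).shape)    # torch.Size([5, 512])
```
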
__init__(emb_type, mode, features, n_points=1, emb_num=16, over_boxes=True, temperature=10000, normalize=False, scale=None, mlp_cfg=None)

Initialize embeddings.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `emb_type` | `str` | The type of embedding to compute. Must be one of {"temp", "pos", "off"}. | required |
| `mode` | `str` | The mode or function used to map positions to vector embeddings. Must be one of {"fixed", "learned", "off"}. | required |
| `features` | `int` | The embedding dimensions. Must match the dimension of the input vectors for the transformer model. | required |
| `n_points` | `Optional[int]` | The number of points that will be embedded. | `1` |
| `emb_num` | `Optional[int]` | The number of embeddings in the `self.lookup` table (only used in learned embeddings). | `16` |
| `over_boxes` | `Optional[bool]` | Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh). | `True` |
| `temperature` | `Optional[int]` | The temperature constant to be used when computing the sinusoidal position embedding. | `10000` |
| `normalize` | `Optional[bool]` | Whether or not to normalize the positions (only used in fixed embeddings). | `False` |
| `scale` | `Optional[float]` | Factor by which to scale the positions after normalizing (only used in fixed embeddings). | `None` |
| `mlp_cfg` | `dict` | A dictionary of MLP hyperparameters for projecting the embedding to the correct space. Example: {"hidden_dims": 256, "num_layers": 3, "dropout": 0.3} | `None` |
Source code in dreem/models/embedding.py
def __init__(
    self,
    emb_type: str,
    mode: str,
    features: int,
    n_points: Optional[int] = 1,
    emb_num: Optional[int] = 16,
    over_boxes: Optional[bool] = True,
    temperature: Optional[int] = 10000,
    normalize: Optional[bool] = False,
    scale: Optional[float] = None,
    mlp_cfg: dict = None,
):
    """Initialize embeddings.

    Args:
        emb_type: The type of embedding to compute. Must be one of `{"temp", "pos", "off"}`
        mode: The mode or function used to map positions to vector embeddings.
              Must be one of `{"fixed", "learned", "off"}`
        features: The embedding dimensions. Must match the dimension of the
                  input vectors for the transformer model.
        n_points: the number of points that will be embedded.
        emb_num: the number of embeddings in the `self.lookup` table (Only used in learned embeddings).
        over_boxes: Whether to compute the position embedding for each bbox coordinate (y1x1y2x2) or the centroid + bbox size (yxwh).
        temperature: the temperature constant to be used when computing the sinusoidal position embedding
        normalize: whether or not to normalize the positions (Only used in fixed embeddings).
        scale: factor by which to scale the positions after normalizing (Only used in fixed embeddings).
        mlp_cfg: A dictionary of mlp hyperparameters for projecting embedding to correct space.
                Example: {"hidden_dims": 256, "num_layers":3, "dropout": 0.3}
    """
    self._check_init_args(emb_type, mode)

    super().__init__()

    self.emb_type = emb_type
    self.mode = mode
    self.features = features
    self.emb_num = emb_num
    self.over_boxes = over_boxes
    self.temperature = temperature
    self.normalize = normalize
    self.scale = scale
    self.n_points = n_points

    if self.normalize and self.scale is None:
        self.scale = 2 * math.pi

    if self.emb_type == "pos" and mlp_cfg is not None and mlp_cfg["num_layers"] > 0:
        if self.mode == "fixed":
            self.mlp = MLP(
                input_dim=n_points * self.features,
                output_dim=self.features,
                **mlp_cfg,
            )
        else:
            in_dim = (self.features // (4 * n_points)) * (4 * n_points)
            self.mlp = MLP(
                input_dim=in_dim,
                output_dim=self.features,
                **mlp_cfg,
            )
    else:
        self.mlp = torch.nn.Identity()

    self._emb_func = lambda tensor: torch.zeros(
        (tensor.shape[0], self.features), dtype=tensor.dtype, device=tensor.device
    )  # turn off embedding by returning zeros

    self.lookup = None

    if self.mode == "learned":
        if self.emb_type == "pos":
            self.lookup = torch.nn.Embedding(
                self.emb_num * 4 * self.n_points, self.features // (4 * n_points)
            )
            self._emb_func = self._learned_pos_embedding
        elif self.emb_type == "temp":
            self.lookup = torch.nn.Embedding(self.emb_num, self.features)
            self._emb_func = self._learned_temp_embedding

    elif self.mode == "fixed":
        if self.emb_type == "pos":
            self._emb_func = self._sine_box_embedding
        elif self.emb_type == "temp":
            self._emb_func = self._sine_temp_embedding
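
A minimal construction sketch follows, assuming the class documented above is importable as Embedding from dreem.models.embedding (the import path is inferred from the source path above and should be treated as an assumption):

from dreem.models.embedding import Embedding  # assumed import path

# Fixed sinusoidal positional embedding over bounding boxes.
pos_emb = Embedding(
    emb_type="pos",
    mode="fixed",
    features=256,
    normalize=True,  # scale falls back to 2 * pi when left as None
)

# Learned temporal embedding backed by a 16-entry lookup table.
temp_emb = Embedding(
    emb_type="temp",
    mode="learned",
    features=256,
    emb_num=16,
)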

forward(seq_positions)

Get the sequence positional embeddings.

Parameters:

seq_positions (Tensor, required): Either an (N, 1) tensor where seq_positions[i] is the temporal position of instance_i in the sequence, or an (N, n_anchors x 4) tensor where seq_positions[i, j, :] is the [y1, x1, y2, x2] spatial location of the jth point of instance_i in the sequence.

Returns:

Tensor: An N x self.features tensor representing the corresponding spatial or temporal embedding.
Source code in dreem/models/embedding.py
def forward(self, seq_positions: torch.Tensor) -> torch.Tensor:
    """Get the sequence positional embeddings.

    Args:
        seq_positions:
            * An (`N`, 1) tensor where seq_positions[i] represents the temporal position of instance_i in the sequence.
            * An (`N`, n_anchors x 4) tensor where seq_positions[i, j, :] represents the [y1, x1, y2, x2] spatial locations of jth point of instance_i in the sequence.

    Returns:
        An `N` x `self.features` tensor representing the corresponding spatial or temporal embedding.
    """
    emb = self._emb_func(seq_positions)

    if emb.shape[-1] != self.features:
        raise RuntimeError(
            (
                f"Output embedding dimension is {emb.shape[-1]} but requested {self.features} dimensions! \n"
                f"hint: Try turning the MLP on by passing `mlp_cfg` to the constructor to project to the correct embedding dimensions."
            )
        )
    return emb
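
As an illustrative call (a sketch only; the input layouts follow the docstring above, and the Embedding import path is an assumption):

import torch
from dreem.models.embedding import Embedding  # assumed import path

temp_emb = Embedding(emb_type="temp", mode="fixed", features=256)
frame_ids = torch.arange(8, dtype=torch.float32).unsqueeze(-1)  # (N, 1) temporal positions
e_t = temp_emb(frame_ids)  # expected: (8, 256)

pos_emb = Embedding(emb_type="pos", mode="fixed", features=256, normalize=True)
boxes = torch.rand(8, 1, 4)  # (N, n_anchors, 4) y1x1y2x2 boxes, per the docstring
e_p = pos_emb(boxes)  # expected: (8, 256)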

dreem.models.transformer.TransformerDecoder

Bases: Module

Transformer Decoder Block composed of Transformer Decoder Layers.

Source code in dreem/models/transformer.py
class TransformerDecoder(nn.Module):
    """Transformer Decoder Block composed of Transformer Decoder Layers."""

    def __init__(
        self,
        decoder_layer: TransformerDecoderLayer,
        num_layers: int,
        return_intermediate: bool = False,
        norm: nn.Module = None,
    ) -> None:
        """Initialize transformer decoder block.

        Args:
            decoder_layer: An instance of TransformerDecoderLayer.
            num_layers: The number of decoder layers to be stacked.
            return_intermediate: Return intermediate layers from decoder.
            norm: The normalization layer to be applied.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        self.norm = norm if norm is not None else nn.Identity()

    def forward(
        self,
        decoder_queries: torch.Tensor,
        encoder_features: torch.Tensor,
        ref_pos_emb: torch.Tensor = None,
        query_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Execute a forward pass of the decoder block.

        Args:
            decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim).
            encoder_features: Output from encoder, that decoder uses to attend to relevant
                parts of input sequence (total_instances, batch_size, embed_dim)
            ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
            query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim)

        Returns:
            The output tensor of shape (L, n_query, batch_size, embed_dim).
        """
        decoder_features = decoder_queries

        intermediate = []

        for layer in self.layers:
            decoder_features = layer(
                decoder_features,
                encoder_features,
                ref_pos_emb=ref_pos_emb,
                query_pos_emb=query_pos_emb,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(decoder_features))

        decoder_features = self.norm(decoder_features)
        if self.return_intermediate:
            intermediate.pop()
            intermediate.append(decoder_features)

            return torch.stack(intermediate)

        return decoder_features.unsqueeze(0)
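
The block below is a small wiring sketch (not from the library itself) that stacks a few decoder layers; both classes are assumed to be importable from dreem.models.transformer, as the headings above indicate:

import torch.nn as nn
from dreem.models.transformer import TransformerDecoder, TransformerDecoderLayer  # assumed imports

layer = TransformerDecoderLayer(d_model=256, nhead=8, dim_feedforward=1024, dropout=0.1)
decoder = TransformerDecoder(
    decoder_layer=layer,
    num_layers=4,
    return_intermediate=True,   # keep the (normalized) output of every layer
    norm=nn.LayerNorm(256),
)

With return_intermediate=True the forward pass stacks one output per layer (replacing the last entry with the final normalized features), so the leading dimension of the result equals num_layers; otherwise it is 1.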

__init__(decoder_layer, num_layers, return_intermediate=False, norm=None)

Initialize transformer decoder block.

Parameters:

decoder_layer (TransformerDecoderLayer, required): An instance of TransformerDecoderLayer.
num_layers (int, required): The number of decoder layers to be stacked.
return_intermediate (bool, default False): Return intermediate layers from decoder.
norm (Module, default None): The normalization layer to be applied.
Source code in dreem/models/transformer.py
def __init__(
    self,
    decoder_layer: TransformerDecoderLayer,
    num_layers: int,
    return_intermediate: bool = False,
    norm: nn.Module = None,
) -> None:
    """Initialize transformer decoder block.

    Args:
        decoder_layer: An instance of TransformerDecoderLayer.
        num_layers: The number of decoder layers to be stacked.
        return_intermediate: Return intermediate layers from decoder.
        norm: The normalization layer to be applied.
    """
    super().__init__()
    self.layers = _get_clones(decoder_layer, num_layers)
    self.num_layers = num_layers
    self.return_intermediate = return_intermediate
    self.norm = norm if norm is not None else nn.Identity()

forward(decoder_queries, encoder_features, ref_pos_emb=None, query_pos_emb=None)

Execute a forward pass of the decoder block.

Parameters:

decoder_queries (Tensor, required): Query sequence for the decoder to generate (n_query, batch_size, embed_dim).
encoder_features (Tensor, required): Output from the encoder that the decoder uses to attend to relevant parts of the input sequence (total_instances, batch_size, embed_dim).
ref_pos_emb (Tensor, default None): The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
query_pos_emb (Tensor, default None): The query positional embedding of shape (n_query, batch_size, embed_dim).

Returns:

Tensor: The output tensor of shape (L, n_query, batch_size, embed_dim).

Source code in dreem/models/transformer.py
def forward(
    self,
    decoder_queries: torch.Tensor,
    encoder_features: torch.Tensor,
    ref_pos_emb: torch.Tensor = None,
    query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
    """Execute a forward pass of the decoder block.

    Args:
        decoder_queries: Query sequence for decoder to generate (n_query, batch_size, embed_dim).
        encoder_features: Output from encoder, that decoder uses to attend to relevant
            parts of input sequence (total_instances, batch_size, embed_dim)
        ref_pos_emb: The input positional embedding tensor of shape (total_instances, batch_size, embed_dim).
        query_pos_emb: The query positional embedding of shape (n_query, batch_size, embed_dim)

    Returns:
        The output tensor of shape (L, n_query, batch_size, embed_dim).
    """
    decoder_features = decoder_queries

    intermediate = []

    for layer in self.layers:
        decoder_features = layer(
            decoder_features,
            encoder_features,
            ref_pos_emb=ref_pos_emb,
            query_pos_emb=query_pos_emb,
        )
        if self.return_intermediate:
            intermediate.append(self.norm(decoder_features))

    decoder_features = self.norm(decoder_features)
    if self.return_intermediate:
        intermediate.pop()
        intermediate.append(decoder_features)

        return torch.stack(intermediate)

    return decoder_features.unsqueeze(0)
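
A self-contained usage sketch of the forward pass with explicit positional embeddings (dummy tensors; the imports are assumed from the module path above):

import torch
from dreem.models.transformer import TransformerDecoder, TransformerDecoderLayer  # assumed imports

decoder = TransformerDecoder(TransformerDecoderLayer(d_model=128, nhead=4), num_layers=2)

n_query, total_instances, batch_size, d = 5, 20, 1, 128
queries = torch.zeros(n_query, batch_size, d)
memory = torch.rand(total_instances, batch_size, d)
ref_pos = torch.rand(total_instances, batch_size, d)  # added to encoder_features inside every layer
query_pos = torch.rand(n_query, batch_size, d)        # added to decoder_queries inside every layer

out = decoder(queries, memory, ref_pos_emb=ref_pos, query_pos_emb=query_pos)
print(out.shape)  # torch.Size([1, 5, 1, 128]) since return_intermediate defaults to False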

dreem.models.transformer.TransformerDecoderLayer

Bases: Module

A single transformer decoder layer.

Source code in dreem/models/transformer.py
class TransformerDecoderLayer(nn.Module):
    """A single transformer decoder layer."""

    def __init__(
        self,
        d_model: int = 1024,
        nhead: int = 6,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        norm: bool = False,
        decoder_self_attn: bool = False,
    ) -> None:
        """Initialize transformer decoder layer.

        Args:
            d_model: The number of features in the decoder inputs.
            nhead: The number of heads for the decoder.
            dim_feedforward: Dimension of the feedforward layers of decoder.
            dropout: Dropout value applied to the output of decoder.
            activation: Activation function to use.
            norm: If True, normalize output of decoder.
            decoder_self_attn: If True, use decoder self attention
        """
        super().__init__()

        self.decoder_self_attn = decoder_self_attn

        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        if self.decoder_self_attn:
            self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
        self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
        self.norm3 = nn.LayerNorm(d_model) if norm else nn.Identity()

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(
        self,
        decoder_queries: torch.Tensor,
        encoder_features: torch.Tensor,
        ref_pos_emb: torch.Tensor = None,
        query_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        """Execute forward pass of decoder layer.

        Args:
            decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim).
            encoder_features: Output from encoder, that decoder uses to attend to relevant
                parts of input sequence (total_instances, batch_size, embed_dim)
            ref_pos_emb: The input positional embedding tensor of shape (n_query, embed_dim).
            query_pos_emb: The target positional embedding of shape (n_query, embed_dim)

        Returns:
            The output tensor of shape (n_query, batch_size, embed_dim).
        """
        if query_pos_emb is None:
            query_pos_emb = torch.zeros_like(decoder_queries)
        if ref_pos_emb is None:
            ref_pos_emb = torch.zeros_like(encoder_features)

        decoder_queries = decoder_queries + query_pos_emb
        encoder_features = encoder_features + ref_pos_emb

        if self.decoder_self_attn:
            self_attn_features = self.self_attn(
                query=decoder_queries, key=decoder_queries, value=decoder_queries
            )[0]
            decoder_queries = decoder_queries + self.dropout1(self_attn_features)
            decoder_queries = self.norm1(decoder_queries)

        x_attn_features = self.multihead_attn(
            query=decoder_queries,  # (n_query, batch_size, embed_dim)
            key=encoder_features,  # (total_instances, batch_size, embed_dim)
            value=encoder_features,  # (total_instances, batch_size, embed_dim)
        )[
            0
        ]  # (n_query, batch_size, embed_dim)

        decoder_queries = decoder_queries + self.dropout2(
            x_attn_features
        )  # (n_query, batch_size, embed_dim)
        decoder_queries = self.norm2(
            decoder_queries
        )  # (n_query, batch_size, embed_dim)
        projection = self.linear2(
            self.dropout(self.activation(self.linear1(decoder_queries)))
        )  # (n_query, batch_size, embed_dim)
        decoder_queries = decoder_queries + self.dropout3(
            projection
        )  # (n_query, batch_size, embed_dim)
        decoder_features = self.norm3(decoder_queries)

        return decoder_features  # (n_query, batch_size, embed_dim)

__init__(d_model=1024, nhead=6, dim_feedforward=1024, dropout=0.1, activation='relu', norm=False, decoder_self_attn=False)

Initialize transformer decoder layer.

Parameters:

d_model (int, default 1024): The number of features in the decoder inputs.
nhead (int, default 6): The number of heads for the decoder.
dim_feedforward (int, default 1024): Dimension of the feedforward layers of the decoder.
dropout (float, default 0.1): Dropout value applied to the output of the decoder.
activation (str, default 'relu'): Activation function to use.
norm (bool, default False): If True, normalize the output of the decoder.
decoder_self_attn (bool, default False): If True, use decoder self-attention.
Source code in dreem/models/transformer.py
def __init__(
    self,
    d_model: int = 1024,
    nhead: int = 6,
    dim_feedforward: int = 1024,
    dropout: float = 0.1,
    activation: str = "relu",
    norm: bool = False,
    decoder_self_attn: bool = False,
) -> None:
    """Initialize transformer decoder layer.

    Args:
        d_model: The number of features in the decoder inputs.
        nhead: The number of heads for the decoder.
        dim_feedforward: Dimension of the feedforward layers of decoder.
        dropout: Dropout value applied to the output of decoder.
        activation: Activation function to use.
        norm: If True, normalize output of decoder.
        decoder_self_attn: If True, use decoder self attention
    """
    super().__init__()

    self.decoder_self_attn = decoder_self_attn

    self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
    self.linear1 = nn.Linear(d_model, dim_feedforward)
    self.dropout = nn.Dropout(dropout)
    self.linear2 = nn.Linear(dim_feedforward, d_model)

    if self.decoder_self_attn:
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    self.norm1 = nn.LayerNorm(d_model) if norm else nn.Identity()
    self.norm2 = nn.LayerNorm(d_model) if norm else nn.Identity()
    self.norm3 = nn.LayerNorm(d_model) if norm else nn.Identity()

    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)

    self.activation = _get_activation_fn(activation)
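
As a configuration sketch (an illustration, not library code; the import is assumed from the module path above), the norm and decoder_self_attn flags control whether LayerNorm and a pre-cross-attention self-attention block are used:

from dreem.models.transformer import TransformerDecoderLayer  # assumed import

plain_layer = TransformerDecoderLayer(d_model=256, nhead=8)  # norm1-3 are Identity, no self-attention
full_layer = TransformerDecoderLayer(
    d_model=256,
    nhead=8,
    norm=True,               # LayerNorm after each residual connection
    decoder_self_attn=True,  # self-attention over the queries before cross-attention
)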

forward(decoder_queries, encoder_features, ref_pos_emb=None, query_pos_emb=None)

Execute forward pass of decoder layer.

Parameters:

decoder_queries (Tensor, required): Target sequence for the decoder to generate (n_query, batch_size, embed_dim).
encoder_features (Tensor, required): Output from the encoder that the decoder uses to attend to relevant parts of the input sequence (total_instances, batch_size, embed_dim).
ref_pos_emb (Tensor, default None): The input positional embedding tensor of shape (n_query, embed_dim).
query_pos_emb (Tensor, default None): The target positional embedding of shape (n_query, embed_dim).

Returns:

Tensor: The output tensor of shape (n_query, batch_size, embed_dim).

Source code in dreem/models/transformer.py
def forward(
    self,
    decoder_queries: torch.Tensor,
    encoder_features: torch.Tensor,
    ref_pos_emb: torch.Tensor = None,
    query_pos_emb: torch.Tensor = None,
) -> torch.Tensor:
    """Execute forward pass of decoder layer.

    Args:
        decoder_queries: Target sequence for decoder to generate (n_query, batch_size, embed_dim).
        encoder_features: Output from encoder, that decoder uses to attend to relevant
            parts of input sequence (total_instances, batch_size, embed_dim)
        ref_pos_emb: The input positional embedding tensor of shape (n_query, embed_dim).
        query_pos_emb: The target positional embedding of shape (n_query, embed_dim)

    Returns:
        The output tensor of shape (n_query, batch_size, embed_dim).
    """
    if query_pos_emb is None:
        query_pos_emb = torch.zeros_like(decoder_queries)
    if ref_pos_emb is None:
        ref_pos_emb = torch.zeros_like(encoder_features)

    decoder_queries = decoder_queries + query_pos_emb
    encoder_features = encoder_features + ref_pos_emb

    if self.decoder_self_attn:
        self_attn_features = self.self_attn(
            query=decoder_queries, key=decoder_queries, value=decoder_queries
        )[0]
        decoder_queries = decoder_queries + self.dropout1(self_attn_features)
        decoder_queries = self.norm1(decoder_queries)

    x_attn_features = self.multihead_attn(
        query=decoder_queries,  # (n_query, batch_size, embed_dim)
        key=encoder_features,  # (total_instances, batch_size, embed_dim)
        value=encoder_features,  # (total_instances, batch_size, embed_dim)
    )[
        0
    ]  # (n_query, batch_size, embed_dim)

    decoder_queries = decoder_queries + self.dropout2(
        x_attn_features
    )  # (n_query, batch_size, embed_dim)
    decoder_queries = self.norm2(
        decoder_queries
    )  # (n_query, batch_size, embed_dim)
    projection = self.linear2(
        self.dropout(self.activation(self.linear1(decoder_queries)))
    )  # (n_query, batch_size, embed_dim)
    decoder_queries = decoder_queries + self.dropout3(
        projection
    )  # (n_query, batch_size, embed_dim)
    decoder_features = self.norm3(decoder_queries)

    return decoder_features  # (n_query, batch_size, embed_dim)
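
A dummy-tensor sketch of a single layer's forward pass (assumed import; shapes follow the docstring above):

import torch
from dreem.models.transformer import TransformerDecoderLayer  # assumed import

layer = TransformerDecoderLayer(d_model=128, nhead=4)
queries = torch.rand(6, 1, 128)   # (n_query, batch_size, embed_dim)
memory = torch.rand(24, 1, 128)   # (total_instances, batch_size, embed_dim)
out = layer(queries, memory)      # (6, 1, 128): one refined feature vector per query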

dreem.models.attention_head.ATTWeightHead

Bases: Module

Single attention head.

Source code in dreem/models/attention_head.py
class ATTWeightHead(torch.nn.Module):
    """Single attention head."""

    def __init__(
        self,
        feature_dim: int,
        num_layers: int,
        dropout: float,
    ):
        """Initialize an instance of ATTWeightHead.

        Args:
            feature_dim: The dimensionality of input features.
            num_layers: The number of hidden layers in the MLP.
            dropout: Dropout probability.
        """
        super().__init__()

        self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
        self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the attention weights of a query tensor using the key tensor.

        Args:
            query: Input tensor of shape (batch_size, num_frame_instances, feature_dim).
            key: Input tensor of shape (batch_size, num_window_instances, feature_dim).

        Returns:
            Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
        """
        k = self.k_proj(key)
        q = self.q_proj(query)
        attn_weights = torch.bmm(q, k.transpose(1, 2))

        return attn_weights  # (B, N_t, N)

__init__(feature_dim, num_layers, dropout)

Initialize an instance of ATTWeightHead.

Parameters:

feature_dim (int, required): The dimensionality of input features.
num_layers (int, required): The number of hidden layers in the MLP.
dropout (float, required): Dropout probability.
Source code in dreem/models/attention_head.py
def __init__(
    self,
    feature_dim: int,
    num_layers: int,
    dropout: float,
):
    """Initialize an instance of ATTWeightHead.

    Args:
        feature_dim: The dimensionality of input features.
        num_layers: The number of hidden layers in the MLP.
        dropout: Dropout probability.
    """
    super().__init__()

    self.q_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)
    self.k_proj = MLP(feature_dim, feature_dim, feature_dim, num_layers, dropout)

forward(query, key)

Compute the attention weights of a query tensor using the key tensor.

Parameters:

query (Tensor, required): Input tensor of shape (batch_size, num_frame_instances, feature_dim).
key (Tensor, required): Input tensor of shape (batch_size, num_window_instances, feature_dim).

Returns:

Tensor: Output tensor of shape (batch_size, num_frame_instances, num_window_instances).

Source code in dreem/models/attention_head.py
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
) -> torch.Tensor:
    """Compute the attention weights of a query tensor using the key tensor.

    Args:
        query: Input tensor of shape (batch_size, num_frame_instances, feature_dim).
        key: Input tensor of shape (batch_size, num_window_instances, feature_dim).

    Returns:
        Output tensor of shape (batch_size, num_frame_instances, num_window_instances).
    """
    k = self.k_proj(key)
    q = self.q_proj(query)
    attn_weights = torch.bmm(q, k.transpose(1, 2))

    return attn_weights  # (B, N_t, N)
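
A usage sketch (assumed import path) in which the head scores each instance in the current frame against every instance in the window:

import torch
from dreem.models.attention_head import ATTWeightHead  # assumed import path

head = ATTWeightHead(feature_dim=512, num_layers=2, dropout=0.1)
query_feats = torch.rand(1, 8, 512)    # (batch_size, num_frame_instances, feature_dim)
key_feats = torch.rand(1, 32, 512)     # (batch_size, num_window_instances, feature_dim)
scores = head(query_feats, key_feats)  # (1, 8, 32) raw attention weights, one per (query, key) pair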

dreem.models.mlp.MLP

Bases: Module

Multi-Layer Perceptron (MLP) module.

Source code in dreem/models/mlp.py
class MLP(torch.nn.Module):
    """Multi-Layer Perceptron (MLP) module."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float = 0.0,
    ):
        """Initialize MLP.

        Args:
            input_dim: Dimensionality of the input features.
            hidden_dim: Number of units in the hidden layers.
            output_dim: Dimensionality of the output features.
            num_layers: Number of hidden layers.
            dropout: Dropout probability.
        """
        super().__init__()

        self.num_layers = num_layers
        self.dropout = dropout

        if self.num_layers > 0:
            h = [hidden_dim] * (num_layers - 1)
            self.layers = torch.nn.ModuleList(
                [
                    torch.nn.Linear(n, k)
                    for n, k in zip([input_dim] + h, h + [output_dim])
                ]
            )
            if self.dropout > 0.0:
                self.dropouts = torch.nn.ModuleList(
                    [torch.nn.Dropout(dropout) for _ in range(self.num_layers - 1)]
                )
        else:
            self.layers = []

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MLP.

        Args:
            x: Input tensor of shape (batch_size, num_instances, input_dim).

        Returns:
            Output tensor of shape (batch_size, num_instances, output_dim).
        """
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
            if i < self.num_layers - 1 and self.dropout > 0.0:
                x = self.dropouts[i](x)

        return x

__init__(input_dim, hidden_dim, output_dim, num_layers, dropout=0.0)

Initialize MLP.

Parameters:

input_dim (int, required): Dimensionality of the input features.
hidden_dim (int, required): Number of units in the hidden layers.
output_dim (int, required): Dimensionality of the output features.
num_layers (int, required): Number of hidden layers.
dropout (float, default 0.0): Dropout probability.
Source code in dreem/models/mlp.py
def __init__(
    self,
    input_dim: int,
    hidden_dim: int,
    output_dim: int,
    num_layers: int,
    dropout: float = 0.0,
):
    """Initialize MLP.

    Args:
        input_dim: Dimensionality of the input features.
        hidden_dim: Number of units in the hidden layers.
        output_dim: Dimensionality of the output features.
        num_layers: Number of hidden layers.
        dropout: Dropout probability.
    """
    super().__init__()

    self.num_layers = num_layers
    self.dropout = dropout

    if self.num_layers > 0:
        h = [hidden_dim] * (num_layers - 1)
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(n, k)
                for n, k in zip([input_dim] + h, h + [output_dim])
            ]
        )
        if self.dropout > 0.0:
            self.dropouts = torch.nn.ModuleList(
                [torch.nn.Dropout(dropout) for _ in range(self.num_layers - 1)]
            )
    else:
        self.layers = []

forward(x)

Forward pass of the MLP.

Parameters:

x (Tensor, required): Input tensor of shape (batch_size, num_instances, input_dim).

Returns:

Tensor: Output tensor of shape (batch_size, num_instances, output_dim).

Source code in dreem/models/mlp.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass of the MLP.

    Args:
        x: Input tensor of shape (batch_size, num_instances, input_dim).

    Returns:
        Output tensor of shape (batch_size, num_instances, output_dim).
    """
    for i, layer in enumerate(self.layers):
        x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if i < self.num_layers - 1 and self.dropout > 0.0:
            x = self.dropouts[i](x)

    return x
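
A usage sketch (assumed import path); note from the constructor above that num_layers=0 builds no layers, so the forward pass returns its input unchanged:

import torch
from dreem.models.mlp import MLP  # assumed import path

mlp = MLP(input_dim=512, hidden_dim=256, output_dim=128, num_layers=3, dropout=0.1)
x = torch.rand(1, 10, 512)  # (batch_size, num_instances, input_dim)
y = mlp(x)                  # (1, 10, 128)

identity = MLP(input_dim=512, hidden_dim=256, output_dim=512, num_layers=0)
assert torch.equal(identity(x), x)  # no layers: the input passes through untouched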