
dreem.models.visual_encoder

Module for different visual feature extractors.

VisualEncoder

Bases: Module

Class wrapping around a visual feature extractor backbone.

Currently CNN only.

Source code in dreem/models/visual_encoder.py
from typing import Any, Optional

import timm
import torch
import torch.nn.functional as F
import torchvision


class VisualEncoder(torch.nn.Module):
    """Class wrapping around a visual feature extractor backbone.

    Currently CNN only.
    """

    def __init__(
        self,
        model_name: str = "resnet18",
        d_model: int = 512,
        in_chans: int = 3,
        backend: str = "timm",
        **kwargs: Optional[Any],
    ):
        """Initialize Visual Encoder.

        Args:
            model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
            d_model (int): Output embedding dimension.
            in_chans (int): Number of input channels of the image.
            backend (str): Which model backend to use. One of {"timm", "torchvision"}.
            kwargs: See `timm.create_model` and `torchvision.models.resnetX` for kwargs.
        """
        super().__init__()

        self.model_name = model_name.lower()
        self.d_model = d_model
        self.backend = backend
        # Grayscale inputs are tiled to 3 channels in `forward`, so build the
        # backbone with 3 input channels when in_chans == 1.
        if in_chans == 1:
            self.in_chans = 3
        else:
            self.in_chans = in_chans

        self.feature_extractor = self.select_feature_extractor(
            model_name=self.model_name,
            in_chans=self.in_chans,
            backend=self.backend,
            **kwargs,
        )

        self.out_layer = torch.nn.Linear(
            self.encoder_dim(self.feature_extractor), self.d_model
        )

    def select_feature_extractor(
        self, model_name: str, in_chans: int, backend: str, **kwargs: Optional[Any]
    ) -> torch.nn.Module:
        """Select the appropriate feature extractor based on config.

        Args:
            model_name (str): Name of the CNN architecture to use (e.g. "resnet18", "resnet50").
            in_chans (int): Number of input channels of the image.
            backend (str): Which model backend to use. One of {"timm", "torchvision"}.
            kwargs: See `timm.create_model` and `torchvision.models.resnetX` for kwargs.

        Returns:
            A CNN encoder based on the config and backend selected.
        """
        if "timm" in backend.lower():
            feature_extractor = timm.create_model(
                model_name=self.model_name,
                in_chans=self.in_chans,
                num_classes=0,
                **kwargs,
            )
        elif "torch" in backend.lower():
            if model_name.lower() == "resnet18":
                feature_extractor = torchvision.models.resnet18(**kwargs)

            elif model_name.lower() == "resnet50":
                feature_extractor = torchvision.models.resnet50(**kwargs)

            else:
                raise ValueError(
                    f"Only `[resnet18, resnet50]` are available when backend is {backend}. Found {model_name}"
                )
            # Drop the final fully-connected layer; the remaining modules end
            # with global average pooling.
            feature_extractor = torch.nn.Sequential(
                *list(feature_extractor.children())[:-1]
            )
            input_layer = feature_extractor[0]
            if in_chans != 3:
                # Rebuild the first conv layer to accept `in_chans` channels,
                # copying its remaining hyperparameters. Note `bias` must be a
                # bool, not the original layer's bias parameter.
                feature_extractor[0] = torch.nn.Conv2d(
                    in_channels=in_chans,
                    out_channels=input_layer.out_channels,
                    kernel_size=input_layer.kernel_size,
                    stride=input_layer.stride,
                    padding=input_layer.padding,
                    dilation=input_layer.dilation,
                    groups=input_layer.groups,
                    bias=input_layer.bias is not None,
                    padding_mode=input_layer.padding_mode,
                )

        else:
            raise ValueError(
                f"Only ['timm', 'torch'] backends are available! Found {backend}."
            )
        return feature_extractor

    def encoder_dim(self, model: torch.nn.Module) -> int:
        """Compute dummy forward pass of encoder model and get embedding dimension.

        Args:
            model: a vision encoder model.

        Returns:
            The embedding dimension size.
        """
        model.eval()
        with torch.no_grad():
            dummy_output = model(torch.randn(1, self.in_chans, 224, 224)).squeeze()
        model.train()  # restore training mode, to be safe
        return dummy_output.shape[-1]

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Forward pass of feature extractor to get feature vector.

        Args:
            img: Input image tensor of shape (B, C, H, W).

        Returns:
            feats: Normalized output tensor of shape (B, d_model).
        """
        # If grayscale, tile the image to 3 channels.
        if img.shape[1] == 1:
            img = img.repeat([1, 3, 1, 1])  # (B, nc=3, H, W)
        # Extract image features
        feats = self.feature_extractor(
            img
        )  # (B, out_dim, 1, 1) for the torchvision backbones; (B, out_dim) for timm.

        # Reshape feature vectors
        feats = feats.reshape([img.shape[0], -1])  # (B, out_dim)
        # Map feature vectors to output dimension using linear layer.
        feats = self.out_layer(feats)  # (B, d_model)
        # Normalize output feature vectors.
        feats = F.normalize(feats)  # (B, d_model)
        return feats
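
A minimal usage sketch (assuming `dreem` is installed; the batch and crop sizes here are arbitrary):

import torch

from dreem.models.visual_encoder import VisualEncoder

# Default timm ResNet-18 backbone mapped to a 512-d embedding.
encoder = VisualEncoder(model_name="resnet18", d_model=512, backend="timm")

imgs = torch.randn(8, 3, 128, 128)  # batch of 8 RGB crops
feats = encoder(imgs)               # (8, 512), rows are L2-normalized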

__init__(model_name='resnet18', d_model=512, in_chans=3, backend='timm', **kwargs)

Initialize Visual Encoder.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `model_name` | `str` | Name of the CNN architecture to use (e.g. "resnet18", "resnet50"). | `'resnet18'` |
| `d_model` | `int` | Output embedding dimension. | `512` |
| `in_chans` | `int` | Number of input channels of the image. | `3` |
| `backend` | `str` | Which model backend to use. One of {"timm", "torchvision"}. | `'timm'` |
| `kwargs` | `Optional[Any]` | See `timm.create_model` and `torchvision.models.resnetX` for kwargs. | `{}` |
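
Extra keyword arguments are forwarded to the backend constructor. A hedged sketch (e.g. `pretrained` is a `timm.create_model` kwarg, not a parameter of this class):

# timm backend: kwargs are forwarded to `timm.create_model`.
enc_timm = VisualEncoder(model_name="resnet50", d_model=256, backend="timm", pretrained=True)

# torchvision backend: kwargs are forwarded to `torchvision.models.resnet18`/`resnet50`.
enc_tv = VisualEncoder(model_name="resnet18", d_model=256, backend="torchvision")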

encoder_dim(model)

Compute dummy forward pass of encoder model and get embedding dimension.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `model` | `Module` | A vision encoder model. | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `int` | The embedding dimension size. |
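
For illustration, the inferred dimension for the two bundled ResNet backbones (these widths are standard ResNet properties, not values this method hard-codes):

enc = VisualEncoder(model_name="resnet18", backend="timm")
dim = enc.encoder_dim(enc.feature_extractor)
print(dim)  # 512 for resnet18 (2048 for resnet50)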


forward(img)

Forward pass of feature extractor to get feature vector.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `img` | `Tensor` | Input image tensor of shape (B, C, H, W). | *required* |

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| `feats` | `Tensor` | Normalized output tensor of shape (B, d_model). |
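
A quick sanity check of the output contract (batch and crop sizes here are arbitrary):

enc = VisualEncoder(d_model=512)
gray = torch.randn(4, 1, 96, 96)   # grayscale batch; tiled to 3 channels internally
feats = enc(gray)
print(feats.shape)                 # torch.Size([4, 512])
print(feats.norm(dim=-1))          # ~1.0 per row, since outputs are L2-normalized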


select_feature_extractor(model_name, in_chans, backend, **kwargs)

Select the appropriate feature extractor based on config.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `model_name` | `str` | Name of the CNN architecture to use (e.g. "resnet18", "resnet50"). | *required* |
| `in_chans` | `int` | Number of input channels of the image. | *required* |
| `backend` | `str` | Which model backend to use. One of {"timm", "torchvision"}. | *required* |
| `kwargs` | `Optional[Any]` | See `timm.create_model` and `torchvision.models.resnetX` for kwargs. | `{}` |

Returns:

| Type | Description |
| ---- | ----------- |
| `Module` | A CNN encoder based on the config and backend selected. |
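
For instance, with the torchvision backend and a channel count other than 3, the first conv layer is rebuilt to match; a sketch (`in_chans=2` is an arbitrary example):

enc = VisualEncoder(in_chans=2, backend="torchvision")
print(enc.feature_extractor[0])
# Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)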
