data_utils

dreem.datasets.data_utils

Module containing helper functions for datasets.

Classes:

Name Description
LazyTiffStack

Class used for loading tiff frames without reading the whole stack into memory.

NodeDropout

Node dropout augmentation.

Functions:

Name Description
build_augmentations

Get augmentations for dataset.

centroid_bbox

Calculate bbox around instance centroid.

crop_bbox

Crop an image to a bounding box.

get_bbox

Get a square bbox around centroid coordinates.

get_max_padding

Calculate maximum padding dimensions for a given height and width.

get_tight_bbox

Get a tight bbox around an instance.

get_tight_bbox_masks

Get a tight bbox around an instance from its segmentation mask.

load_slp

Read a SLEAP labels file.

pad_bbox

Pad bounding box coordinates.

pad_variable_size_crops

Pad or crop an instance's crop to the target size.

parse_synthetic

Parse .xml labels from synthetic data generated by ICY or the ISBI tracking challenge.

parse_trackmate

Parse trackmate xml or csv labels file.

pose_bbox

Calculate bbox around instance pose.

resize_and_pad

Resize and pad an image to fit a square output size.

sorted_anchors

Sort anchor names from most instances with that node to least.

view_training_batch

Display a grid of images from a batch of training instances.

LazyTiffStack

Class used for loading tiff frames without reading the whole stack into memory.

Methods:

Name Description
__getitem__

Get frame.

__init__

Initialize class.

close

Close tiff stack.

get_section

Get frame as ndarray.

Source code in dreem/datasets/data_utils.py
class LazyTiffStack:
    """Class used for loading tiffs without loading into memory."""

    def __init__(self, filename: str):
        """Initialize class.

        Args:
            filename: name of tif file to be opened
        """
        # expects spatial, channels
        self.image = Image.open(filename)

    def __getitem__(self, section_idx: int) -> Image.Image:
        """Get frame.

        Args:
            section_idx: index of frame or z-slice to get.

        Returns:
            a PIL image of that frame/z-slice.
        """
        self.image.seek(section_idx)
        return self.image

    def get_section(self, section_idx: int) -> np.ndarray:
        """Get frame as ndarray.

        Args:
            section_idx: index of frame or z-slice to get.

        Returns:
            an np.array of that frame/z-slice.
        """
        section = self.__getitem__(section_idx)
        return np.array(section)

    def close(self):
        """Close tiff stack."""
        self.image.close()

__getitem__(section_idx)

Get frame.

Parameters:

Name Type Description Default
section_idx int

index of frame or z-slice to get.

required

Returns:

Type Description
Image.Image

a PIL image of that frame/z-slice.

Source code in dreem/datasets/data_utils.py
def __getitem__(self, section_idx: int) -> Image.Image:
    """Get frame.

    Args:
        section_idx: index of frame or z-slice to get.

    Returns:
        a PIL image of that frame/z-slice.
    """
    self.image.seek(section_idx)
    return self.image

__init__(filename)

Initialize class.

Parameters:

Name Type Description Default
filename str

name of tif file to be opened

required
Source code in dreem/datasets/data_utils.py
def __init__(self, filename: str):
    """Initialize class.

    Args:
        filename: name of tif file to be opened
    """
    # expects spatial, channels
    self.image = Image.open(filename)

close()

Close tiff stack.

Source code in dreem/datasets/data_utils.py
def close(self):
    """Close tiff stack."""
    self.image.close()

get_section(section_idx)

Get frame as ndarray.

Parameters:

Name Type Description Default
section_idx int

index of frame or z-slice to get.

required

Returns:

Type Description
ndarray

an np.array of that frame/z-slice.

Source code in dreem/datasets/data_utils.py
def get_section(self, section_idx: int) -> np.ndarray:
    """Get frame as ndarray.

    Args:
        section_idx: index of frame or z-slice to get.

    Returns:
        an np.array of that frame/z-slice.
    """
    section = self.__getitem__(section_idx)
    return np.array(section)
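
A minimal usage sketch, assuming "stack.tif" stands in for a multi-page TIFF on disk:

from dreem.datasets.data_utils import LazyTiffStack

stack = LazyTiffStack("stack.tif")
frame = stack.get_section(0)  # np.ndarray of the first frame/z-slice
print(frame.shape)
stack.close()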

NodeDropout

Node dropout augmentation.

Drop up to n nodes with probability p.

Methods:

Name Description
__call__

Wrap drop_nodes to enable class call.

__init__

Initialize Node Dropout Augmentation.

forward

Drop up to n random nodes with probability p.

Source code in dreem/datasets/data_utils.py
class NodeDropout:
    """Node dropout augmentation.

    Drop up to `n` nodes with probability `p`.
    """

    def __init__(self, p: float, n: int) -> None:
        """Initialize Node Dropout Augmentation.

        Args:
            p: the probability with which to drop the nodes
            n: the maximum number of nodes to drop
        """
        self.n = n
        self.p = p

    def __call__(self, nodes: list[str]) -> list[str]:
        """Wrap `drop_nodes` to enable class call.

        Args:
            nodes: A list of available node names to drop.

        Returns:
            dropped_nodes: A list of up to `self.n` nodes to drop.
        """
        return self.forward(nodes)

    def forward(self, nodes: list[str]) -> list[str]:
        """Drop up to `n` random nodes with probability p.

        Args:
            nodes: A list of available node names to drop.

        Returns:
            dropped_nodes: A list of up to `self.n` nodes to drop.
        """
        if self.n == 0 or self.p == 0:
            return []

        nodes_to_drop = np.random.permutation(nodes)
        node_dropout_p = np.random.uniform(size=len(nodes_to_drop))

        # Keep only the nodes whose sampled probability falls below `p`,
        # filtering the node array together with its probabilities so the
        # argpartition indices below refer to the same elements.
        dropped_node_inds = np.where(node_dropout_p < self.p)
        nodes_to_drop = nodes_to_drop[dropped_node_inds]
        node_dropout_p = node_dropout_p[dropped_node_inds]

        n_nodes_to_drop = min(self.n, len(node_dropout_p))
        if n_nodes_to_drop == 0:
            return []

        # Of the candidates, drop the (up to `n`) nodes with the highest
        # sampled probabilities.
        dropped_node_inds = np.argpartition(node_dropout_p, -n_nodes_to_drop)[
            -n_nodes_to_drop:
        ]

        dropped_nodes = nodes_to_drop[dropped_node_inds]

        return dropped_nodes.tolist()

__call__(nodes)

Wrap drop_nodes to enable class call.

Parameters:

Name Type Description Default
nodes list[str]

A list of available node names to drop.

required

Returns:

Name Type Description
dropped_nodes list[str]

A list of up to self.n nodes to drop.

Source code in dreem/datasets/data_utils.py
def __call__(self, nodes: list[str]) -> list[str]:
    """Wrap `drop_nodes` to enable class call.

    Args:
        nodes: A list of available node names to drop.

    Returns:
        dropped_nodes: A list of up to `self.n` nodes to drop.
    """
    return self.forward(nodes)

__init__(p, n)

Initialize Node Dropout Augmentation.

Parameters:

Name Type Description Default
p float

the probability with which to drop the nodes

required
n int

the maximum number of nodes to drop

required
Source code in dreem/datasets/data_utils.py
def __init__(self, p: float, n: int) -> None:
    """Initialize Node Dropout Augmentation.

    Args:
        p: the probability with which to drop the nodes
        n: the maximum number of nodes to drop
    """
    self.n = n
    self.p = p

forward(nodes)

Drop up to n random nodes with probability p.

Parameters:

Name Type Description Default
nodes list[str]

A list of available node names to drop.

required

Returns:

Name Type Description
dropped_nodes list[str]

A list of up to self.n nodes to drop.

Source code in dreem/datasets/data_utils.py
def forward(self, nodes: list[str]) -> list[str]:
    """Drop up to `n` random nodes with probability p.

    Args:
        nodes: A list of available node names to drop.

    Returns:
        dropped_nodes: A list of up to `self.n` nodes to drop.
    """
    if self.n == 0 or self.p == 0:
        return []

    nodes_to_drop = np.random.permutation(nodes)
    node_dropout_p = np.random.uniform(size=len(nodes_to_drop))

    # Keep only the nodes whose sampled probability falls below `p`,
    # filtering the node array together with its probabilities so the
    # argpartition indices below refer to the same elements.
    dropped_node_inds = np.where(node_dropout_p < self.p)
    nodes_to_drop = nodes_to_drop[dropped_node_inds]
    node_dropout_p = node_dropout_p[dropped_node_inds]

    n_nodes_to_drop = min(self.n, len(node_dropout_p))
    if n_nodes_to_drop == 0:
        return []

    # Of the candidates, drop the (up to `n`) nodes with the highest
    # sampled probabilities.
    dropped_node_inds = np.argpartition(node_dropout_p, -n_nodes_to_drop)[
        -n_nodes_to_drop:
    ]

    dropped_nodes = nodes_to_drop[dropped_node_inds]

    return dropped_nodes.tolist()
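
A short usage sketch; the node names here are hypothetical placeholders:

from dreem.datasets.data_utils import NodeDropout

node_dropout = NodeDropout(p=0.3, n=2)
# Each node is a drop candidate with probability 0.3; at most 2 are dropped.
dropped = node_dropout(["head", "thorax", "abdomen"])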

build_augmentations(augmentations)

Get augmentations for dataset.

Parameters:

Name Type Description Default
augmentations dict

a dict containing the name of the augmentations and their parameters

required

Returns:

Type Description
Compose

An Albumentations composition of different augmentations.

Source code in dreem/datasets/data_utils.py
def build_augmentations(augmentations: dict) -> A.Compose:
    """Get augmentations for dataset.

    Args:
        augmentations: a dict containing the name of the augmentations
                       and their parameters

    Returns:
        An Albumentations composition of different augmentations.
    """
    aug_list = []
    for aug_name, aug_args in augmentations.items():
        aug_class = getattr(A, aug_name)
        aug = aug_class(**aug_args)
        aug_list.append(aug)

    augs = A.Compose(
        aug_list,
        p=1.0,
        keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
    )

    return augs
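
A sketch of a typical call. The keys must be Albumentations class names; the particular augmentations and parameter values below are illustrative, not dreem defaults:

from dreem.datasets.data_utils import build_augmentations

augs = build_augmentations(
    {
        "Rotate": {"limit": 45, "p": 0.5},
        "GaussianBlur": {"blur_limit": (3, 7), "p": 0.2},
    }
)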

centroid_bbox(points, anchors, crop_size)

Calculate bbox around instance centroid.

This is useful for ensuring that crops are centered around each instance in the case of incorrect pose estimates.

Parameters:

Name Type Description Default
points ArrayLike

2d array of centroid coordinates where each row corresponds to a different anchor point.

required
anchors list

indices of a given anchor point to use as the centroid

required
crop_size int

Integer specifying the crop height and width

required

Returns:

Type Description
Tensor

Bounding box in [y1, x1, y2, x2] format.

Source code in dreem/datasets/data_utils.py
def centroid_bbox(points: ArrayLike, anchors: list, crop_size: int) -> torch.Tensor:
    """Calculate bbox around instance centroid.

    This is useful for ensuring that crops are centered around each instance
    in the case of incorrect pose estimates.

    Args:
        points: 2d array of centroid coordinates where each row corresponds to a
            different anchor point.
        anchors: indices of a given anchor point to use as the centroid
        crop_size: Integer specifying the crop height and width

    Returns:
        Bounding box in [y1, x1, y2, x2] format.
    """
    for anchor in anchors:
        cx, cy = points[anchor][0], points[anchor][1]
        if not np.isnan(cx):
            break

    bbox = torch.Tensor(
        [
            -crop_size / 2 + cy,
            -crop_size / 2 + cx,
            crop_size / 2 + cy,
            crop_size / 2 + cx,
        ]
    )

    return bbox
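
A worked example: the first anchor is missing (NaN), so the second anchor at (x=100, y=50) is used and a 64 x 64 box is centered on it:

import numpy as np
from dreem.datasets.data_utils import centroid_bbox

points = np.array([[np.nan, np.nan], [100.0, 50.0]])
centroid_bbox(points, anchors=[0, 1], crop_size=64)
# tensor([18., 68., 82., 132.])  # [y1, x1, y2, x2]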

crop_bbox(img, bbox)

Crop an image to a bounding box.

Parameters:

Name Type Description Default
img Tensor

Image as a tensor of shape (channels, height, width).

required
bbox ArrayLike

Bounding box in [y1, x1, y2, x2] format.

required

Returns:

Type Description
Tensor

Cropped pixels as tensor of shape (channels, height, width).

Source code in dreem/datasets/data_utils.py
def crop_bbox(img: torch.Tensor, bbox: ArrayLike) -> torch.Tensor:
    """Crop an image to a bounding box.

    Args:
        img: Image as a tensor of shape (channels, height, width).
        bbox: Bounding box in [y1, x1, y2, x2] format.

    Returns:
        Cropped pixels as tensor of shape (channels, height, width).
    """
    # Crop to the bounding box.
    y1, x1, y2, x2 = bbox
    crop = tvf.crop(
        img,
        top=int(y1.round()),
        left=int(x1.round()),
        height=int((y2 - y1).round()),
        width=int((x2 - x1).round()),
    )

    return crop
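
A small sketch with a random image; the box below is 128 px on a side, so the crop comes back as (3, 128, 128):

import torch
from dreem.datasets.data_utils import crop_bbox

img = torch.rand(3, 512, 512)
crop = crop_bbox(img, torch.tensor([100.0, 150.0, 228.0, 278.0]))
print(crop.shape)  # torch.Size([3, 128, 128])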

get_bbox(center, size)

Get a square bbox around centroid coordinates.

Parameters:

Name Type Description Default
center ArrayLike

centroid coordinates in (x,y)

required
size int | tuple[int]

size of the bounding box

required

Returns:

Type Description
Tensor

A torch tensor in form y1, x1, y2, x2

Source code in dreem/datasets/data_utils.py
def get_bbox(center: ArrayLike, size: int | tuple[int]) -> torch.Tensor:
    """Get a square bbox around a centroid coordinates.

    Args:
        center: centroid coordinates in (x,y)
        size: size of the bounding box

    Returns:
        A torch tensor in form y1, x1, y2, x2
    """
    if isinstance(size, int):
        size = (size, size)
    cx, cy = center[0], center[1]

    y1 = max(0, -size[-1] // 2 + cy)
    x1 = max(0, -size[0] // 2 + cx)
    y2 = size[-1] // 2 + cy if y1 != 0 else size[1]
    x2 = size[0] // 2 + cx if x1 != 0 else size[0]
    bbox = torch.Tensor([y1, x1, y2, x2])

    return bbox
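
A worked example away from the image border, where no clamping kicks in:

from dreem.datasets.data_utils import get_bbox

get_bbox(center=(256, 256), size=128)
# tensor([192., 192., 320., 320.])  # a 128 x 128 box centered on (256, 256)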

get_max_padding(height, width)

Calculate maximum padding dimensions for a given height and width.

Useful if padding is required for rotational augmentations, e.g. when centroids lie on the borders of an image.

Parameters:

Name Type Description Default
height int

The original height.

required
width int

The original width.

required

Returns:

Type Description
tuple

A tuple containing the padded height and padded width.

Source code in dreem/datasets/data_utils.py
def get_max_padding(height: int, width: int) -> tuple:
    """Calculate maximum padding dimensions for a given height and width.

    Useful if padding is required for rotational augmentations, e.g. when
    centroids lie on the borders of an image.

    Args:
        height: The original height.
        width: The original width.

    Returns:
        A tuple containing the padded height and padded width.
    """
    diagonal = math.ceil(math.sqrt(height**2 + width**2))

    padded_height = height + (diagonal - height)
    padded_width = width + (diagonal - width)

    return padded_height, padded_width
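
A worked example: a 480 x 640 image has a diagonal of ceil(sqrt(480**2 + 640**2)) = 800, so both dimensions are padded to 800 and the frame can be rotated arbitrarily without clipping:

from dreem.datasets.data_utils import get_max_padding

get_max_padding(480, 640)
# (800, 800)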

get_tight_bbox(pose)

Get a tight bbox around an instance.

Parameters:

Name Type Description Default
pose ArrayLike

array of keypoints around which to create the tight bbox

required

Returns:

Type Description
Tensor

A torch tensor in form y1, x1, y2, x2 representing the tight bbox

Source code in dreem/datasets/data_utils.py
def get_tight_bbox(pose: ArrayLike) -> torch.Tensor:
    """Get a tight bbox around an instance.

    Args:
        pose: array of keypoints around which to create the tight bbox

    Returns:
        A torch tensor in form y1, x1, y2, x2 representing the tight bbox
    """
    x_coords = pose[:, 0]
    y_coords = pose[:, 1]
    x1 = np.min(x_coords)
    x2 = np.max(x_coords)
    y1 = np.min(y_coords)
    y2 = np.max(y_coords)
    bbox = torch.Tensor([y1, x1, y2, x2])

    return bbox
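
A small worked example with a three-node pose:

import numpy as np
from dreem.datasets.data_utils import get_tight_bbox

pose = np.array([[10.0, 5.0], [30.0, 25.0], [20.0, 15.0]])
get_tight_bbox(pose)
# tensor([5., 10., 25., 30.])  # [y_min, x_min, y_max, x_max]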

get_tight_bbox_masks(mask)

Get a tight bbox around an instance from its segmentation mask.

Parameters:

Name Type Description Default
mask ArrayLike

mask of the instance

required

Returns:

Type Description
Tensor

A torch tensor in form y1, x1, y2, x2 representing the tight bbox

Source code in dreem/datasets/data_utils.py
def get_tight_bbox_masks(mask: ArrayLike) -> torch.Tensor:
    """Get a tight bbox around an instance.

    Args:
        mask: mask of the instance

    Returns:
        A torch tensor in form y1, x1, y2, x2 representing the tight bbox
    """
    # Find the foreground (nonzero) pixel coordinates once, then take extrema.
    ys, xs = np.asarray(mask != 0).nonzero()
    bbox = torch.Tensor([ys.min(), xs.min(), ys.max(), xs.max()])

    return bbox
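
A small sketch with a toy binary mask:

import numpy as np
from dreem.datasets.data_utils import get_tight_bbox_masks

mask = np.zeros((8, 8), dtype=np.uint8)
mask[2:5, 3:7] = 1  # foreground at rows 2-4, cols 3-6
get_tight_bbox_masks(mask)
# tensor([2., 3., 4., 6.])  # [y_min, x_min, y_max, x_max]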

load_slp(labels_path, open_videos=True)

Read a SLEAP labels file.

Parameters:

Name Type Description Default
labels_path str

A string path to the SLEAP labels file.

required
open_videos bool

If True (the default), attempt to open the video backend for I/O. If False, the backend will not be opened (useful for reading metadata when the video files are not available).

True

Returns:

Type Description
tuple[Labels, list[tuple[int, int]]]

The processed Labels object and a list of (start, end) frame-index tuples covering the contiguous annotated segments.

Source code in dreem/datasets/data_utils.py
def load_slp(labels_path: str, open_videos: bool = True) -> tuple[Labels, list[tuple[int, int]]]:
    """Read a SLEAP labels file.

    Args:
        labels_path: A string path to the SLEAP labels file.
        open_videos: If `True` (the default), attempt to open the video backend for
            I/O. If `False`, the backend will not be opened (useful for reading metadata
            when the video files are not available).

    Returns:
        The processed `Labels` object and a list of (start, end) frame-index
        tuples covering the contiguous annotated segments.
    """
    tracks = read_tracks(labels_path)
    videos = read_videos(labels_path, open_backend=open_videos)
    skeletons = read_skeletons(labels_path)
    points = read_points(labels_path)
    pred_points = read_pred_points(labels_path)
    format_id = read_hdf5_attrs(labels_path, "metadata", "format_id")
    instances = read_instances(
        labels_path, skeletons, tracks, points, pred_points, format_id
    )
    metadata = read_metadata(labels_path)
    provenance = metadata.get("provenance", dict())

    frames = read_hdf5_dataset(labels_path, "frames")
    labeled_frames = []
    annotated_segments = []
    curr_segment_start = frames[0][2]
    curr_frame = curr_segment_start
    # note that frames only contains frames with labelled instances, not all frames
    for i, video_id, frame_idx, instance_id_start, instance_id_end in frames:
        # if no instances, don't add this frame to the labeled frames
        if len(instances[instance_id_start:instance_id_end]) == 0:
            continue

        labeled_frames.append(
            LabeledFrame(
                video=videos[video_id],
                frame_idx=int(frame_idx),
                instances=instances[instance_id_start:instance_id_end],
            )
        )
        if frame_idx == curr_frame:
            pass
        elif frame_idx == curr_frame + 1:
            curr_frame = frame_idx
        elif frame_idx > curr_frame + 1:
            annotated_segments.append((curr_segment_start, curr_frame))
            curr_segment_start = frame_idx
            curr_frame = frame_idx

    # add last segment
    annotated_segments.append((curr_segment_start, curr_frame))

    labels = Labels(
        labeled_frames=labeled_frames,
        videos=videos,
        skeletons=skeletons,
        tracks=tracks,
        provenance=provenance,
    )
    labels.provenance["filename"] = labels_path

    return labels, annotated_segments
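
A usage sketch, with "labels.slp" standing in for a real SLEAP labels file:

from dreem.datasets.data_utils import load_slp

labels, annotated_segments = load_slp("labels.slp")
# annotated_segments is a list of (start_frame, end_frame) tuples covering
# the contiguous runs of labeled frames.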

pad_bbox(bbox, padding=16)

Pad bounding box coordinates.

Parameters:

Name Type Description Default
bbox ArrayLike

Bounding box in [y1, x1, y2, x2] format.

required
padding int

Padding to add to each side in pixels.

16

Returns:

Type Description
Tensor

Padded bounding box in [y1, x1, y2, x2] format.

Source code in dreem/datasets/data_utils.py
def pad_bbox(bbox: ArrayLike, padding: int = 16) -> torch.Tensor:
    """Pad bounding box coordinates.

    Args:
        bbox: Bounding box in [y1, x1, y2, x2] format.
        padding: Padding to add to each side in pixels.

    Returns:
        Padded bounding box in [y1, x1, y2, x2] format.
    """
    y1, x1, y2, x2 = bbox
    y1, x1 = y1 - padding, x1 - padding
    y2, x2 = y2 + padding, x2 + padding
    return torch.Tensor([y1, x1, y2, x2])
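
A worked example: the default 16 px of padding grows the box by 32 px per dimension (coordinates can go negative near the border):

import torch
from dreem.datasets.data_utils import pad_bbox

pad_bbox(torch.tensor([10.0, 10.0, 50.0, 50.0]))
# tensor([-6., -6., 66., 66.])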

pad_variable_size_crops(instance, target_size)

Pad or crop an instance's crop to the target size.

Parameters:

Name Type Description Default
instance

Instance object with a crop attribute

required
target_size

Tuple of (height, width) for the target size

required

Returns:

Type Description

The instance with modified crop

Source code in dreem/datasets/data_utils.py
def pad_variable_size_crops(instance, target_size):
    """Pad or crop an instance's crop to the target size.

    Args:
        instance: Instance object with a crop attribute
        target_size: Tuple of (height, width) for the target size

    Returns:
        The instance with modified crop
    """
    _, c, h, w = instance.crop.shape
    target_h, target_w = target_size

    # Crop the image further if target_size is smaller than current crop size
    if h > target_h or w > target_w:
        instance.crop = tvf.center_crop(
            instance.crop, (min(h, target_h), min(w, target_w))
        )

    _, c, h, w = instance.crop.shape

    if h < target_h or w < target_w:
        # If height or width is smaller than target size, pad the image to target_size
        pad_w = max(0, target_w - w)
        pad_h = max(0, target_h - h)

        pad_w_left = pad_w // 2
        pad_w_right = pad_w - pad_w_left

        pad_h_top = pad_h // 2
        pad_h_bottom = pad_h - pad_h_top

        # Apply padding
        instance.crop = tvf.pad(
            instance.crop,
            (pad_w_left, pad_h_top, pad_w_right, pad_h_bottom),
            0,
            "constant",
        )

    return instance
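
A sketch using types.SimpleNamespace as a stand-in for a dreem Instance, since all the function needs is a crop attribute:

import torch
from types import SimpleNamespace
from dreem.datasets.data_utils import pad_variable_size_crops

inst = SimpleNamespace(crop=torch.rand(1, 3, 100, 120))
inst = pad_variable_size_crops(inst, (128, 128))
print(inst.crop.shape)  # torch.Size([1, 3, 128, 128])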

parse_synthetic(xml_path, source='icy')

Parse .xml labels from synthetic data generated by ICY or the ISBI tracking challenge.

Logic adapted from https://github.com/sylvainprigent/napari-tracks-reader/blob/main/napari_tracks_reader

Parameters:

Name Type Description Default
xml_path str

path to .xml file containing ICY or ISBI gt trajectory labels

required
source str

synthetic dataset type. Should be either icy or isbi

'icy'

Returns:

Type Description
DataFrame

pandas DataFrame containing frame idx, gt track id and centroid x,y coordinates in pixels

Source code in dreem/datasets/data_utils.py
def parse_synthetic(xml_path: str, source: str = "icy") -> pd.DataFrame:
    """Parse .xml labels from synthetic data generated by ICY or ISBI tracking challenge.

    Logic adapted from https://github.com/sylvainprigent/napari-tracks-reader/blob/main/napari_tracks_reader

    Args:
        xml_path: path to .xml file containing ICY or ISBI gt trajectory labels
        source: synthetic dataset type. Should be either icy or isbi

    Returns:
        pandas DataFrame containing frame idx, gt track id
        and centroid x,y coordinates in pixels
    """
    if source.lower() == "icy":
        root_tag = "trackgroup"
    elif source.lower() == "isbi":
        root_tag = "TrackContestISBI2012"
    else:
        raise ValueError(f"{source} source mode not supported")

    tree = et.parse(xml_path)

    root = tree.getroot()
    tracks = np.empty((0, 4))

    # get the trackgroup element
    idx_trackgroup = 0
    for i in range(len(root)):
        if root[i].tag == root_tag:
            idx_trackgroup = i
            break

    ids_map = {}
    track_id = -1
    for track_element in root[idx_trackgroup]:
        track_id += 1

        try:
            ids_map[track_element.attrib["id"]] = track_id
        except KeyError:
            # Some exports (e.g. ISBI) have no "id" attribute on the track
            # element; fall back to the sequential track_id alone.
            pass

        for detection_element in track_element:
            row = [
                float(track_id),
                float(detection_element.attrib["t"]),
                float(detection_element.attrib["y"]),
                float(detection_element.attrib["x"]),
            ]
            tracks = np.concatenate((tracks, [row]), axis=0)

    tracks_df = pd.DataFrame(
        tracks, columns=["TRACK_ID", "FRAME", "POSITION_Y", "POSITION_X"]
    )

    tracks_df = tracks_df.apply(pd.to_numeric, errors="coerce", downcast="integer")

    return tracks_df
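
A usage sketch; "tracks.xml" is a placeholder for an ICY- or ISBI-format ground-truth file:

from dreem.datasets.data_utils import parse_synthetic

tracks_df = parse_synthetic("tracks.xml", source="isbi")
# columns: TRACK_ID, FRAME, POSITION_Y, POSITION_X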

parse_trackmate(data_path)

Parse trackmate xml or csv labels file.

Logic adapted from https://github.com/hadim/pytrackmate.

Parameters:

Name Type Description Default
data_path str

string path to xml or csv file storing trackmate trajectory labels

required

Returns:

Type Description
DataFrame

pandas DataFrame containing frame number, track_ids, and centroid x,y coordinates in pixels

Source code in dreem/datasets/data_utils.py
def parse_trackmate(data_path: str) -> pd.DataFrame:
    """Parse trackmate xml or csv labels file.

    Logic adapted from https://github.com/hadim/pytrackmate.

    Args:
        data_path: string path to xml or csv file storing trackmate trajectory labels

    Returns:
        `pandas DataFrame` containing frame number, track_ids,
        and centroid x,y coordinates in pixels
    """
    if data_path.endswith(".xml"):
        root = et.parse(data_path).getroot()

        objects = []
        features = root.find("Model").find("FeatureDeclarations").find("SpotFeatures")
        features = [c.get("feature") for c in list(features)] + ["ID"]

        spots = root.find("Model").find("AllSpots")

        for frame in spots.findall("SpotsInFrame"):
            for spot in frame.findall("Spot"):
                single_object = []
                for label in features:
                    single_object.append(spot.get(label))
                objects.append(single_object)

        tracks_df = pd.DataFrame(objects, columns=features)
        tracks_df = tracks_df.astype(float)  # np.float was removed in NumPy 1.20+

        filtered_track_ids = [
            int(track.get("TRACK_ID"))
            for track in root.find("Model").find("FilteredTracks").findall("TrackID")
        ]

        label_id = 0
        tracks_df["label"] = np.nan

        tracks = root.find("Model").find("AllTracks")
        for track in tracks.findall("Track"):
            track_id = int(track.get("TRACK_ID"))
            if track_id in filtered_track_ids:
                spot_ids = [
                    (
                        edge.get("SPOT_SOURCE_ID"),
                        edge.get("SPOT_TARGET_ID"),
                        edge.get("EDGE_TIME"),
                    )
                    for edge in track.findall("Edge")
                ]
                spot_ids = np.array(spot_ids).astype("float")[:, :2]
                spot_ids = set(spot_ids.flatten())

                tracks_df.loc[tracks_df["ID"].isin(spot_ids), "TRACK_ID"] = label_id
                label_id += 1

    elif data_path.endswith(".csv"):
        tracks_df = pd.read_csv(data_path, encoding="ISO-8859-1")

    else:
        raise ValueError(f"Unsupported trackmate file extension: {data_path}")

    tracks_df = tracks_df.apply(pd.to_numeric, errors="coerce", downcast="integer")

    posx_key = "POSITION_X"
    posy_key = "POSITION_Y"
    frame_key = "FRAME"
    track_key = "TRACK_ID"

    mapper = {
        "X": posx_key,
        "Y": posy_key,
        "x": posx_key,
        "y": posy_key,
        "Slice n°": frame_key,
        "Track n°": track_key,
    }

    if "t" in tracks_df:
        mapper.update({"t": frame_key})

    tracks_df = tracks_df.rename(mapper=mapper, axis=1)

    if data_path.endswith(".csv"):
        # 0 index track and frame ids
        if min(tracks_df[frame_key]) == 1:
            tracks_df[frame_key] = tracks_df[frame_key] - 1

        if min(tracks_df[track_key]) == 1:
            tracks_df[track_key] = tracks_df[track_key] - 1

    return tracks_df
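
A usage sketch; the same call handles either export format. "spots.csv" is a placeholder path, and the column selection assumes a standard TrackMate export:

from dreem.datasets.data_utils import parse_trackmate

tracks_df = parse_trackmate("spots.csv")
print(tracks_df[["TRACK_ID", "FRAME", "POSITION_X", "POSITION_Y"]].head())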

pose_bbox(points, bbox_size)

Calculate bbox around instance pose.

Parameters:

Name Type Description Default
points ndarray

an np array of shape (n_nodes, 2)

required
bbox_size tuple[int] | int

size of the bbox, either an int for a square bbox or an (x, y) tuple

required

Returns:

Type Description
Tensor

Bounding box in [y1, x1, y2, x2] format.

Source code in dreem/datasets/data_utils.py
def pose_bbox(points: np.ndarray, bbox_size: tuple[int] | int) -> torch.Tensor:
    """Calculate bbox around instance pose.

    Args:
        points: an np array of shape (n_nodes, 2)
        bbox_size: size of the bbox, either an int for a square bbox or an (x, y) tuple

    Returns:
        Bounding box in [y1, x1, y2, x2] format.
    """
    if isinstance(bbox_size, int):
        bbox_size = (bbox_size, bbox_size)

    c = np.nanmean(points, axis=0)
    bbox = torch.Tensor(
        [
            c[-1] - bbox_size[-1] / 2,
            c[0] - bbox_size[0] / 2,
            c[-1] + bbox_size[-1] / 2,
            c[0] + bbox_size[0] / 2,
        ]
    )
    return bbox
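
A worked example with a two-node pose whose centroid is (x=20, y=30):

import numpy as np
from dreem.datasets.data_utils import pose_bbox

pose = np.array([[10.0, 20.0], [30.0, 40.0]])
pose_bbox(pose, bbox_size=64)
# tensor([-2., -12., 62., 52.])  # [y1, x1, y2, x2]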

resize_and_pad(img, output_size)

Resize and pad an image to fit a square output size.

Parameters:

Name Type Description Default
img Tensor

Image as a tensor of shape (channels, height, width).

required
output_size int

Integer size of height and width of output.

required

Returns:

Type Description
Tensor

The image zero padded to be of shape (channels, output_size, output_size).

Source code in dreem/datasets/data_utils.py
def resize_and_pad(img: torch.Tensor, output_size: int) -> torch.Tensor:
    """Resize and pad an image to fit a square output size.

    Args:
        img: Image as a tensor of shape (channels, height, width).
        output_size: Integer size of height and width of output.

    Returns:
        The image zero padded to be of shape (channels, output_size, output_size).
    """
    # Figure out how to scale without breaking aspect ratio.
    img_height, img_width = img.shape[-2:]
    if img_width < img_height:  # taller
        crop_height = output_size
        scale = crop_height / img_height
        crop_width = int(img_width * scale)
    else:  # wider
        crop_width = output_size
        scale = crop_width / img_width
        crop_height = int(img_height * scale)

    # Scale without breaking aspect ratio.
    img = tvf.resize(img, size=[crop_height, crop_width])

    # Pad to square.
    img_height, img_width = img.shape[-2:]
    hp1 = int((output_size - img_width) / 2)
    vp1 = int((output_size - img_height) / 2)
    hp2 = output_size - (img_width + hp1)
    vp2 = output_size - (img_height + vp1)
    padding = (hp1, vp1, hp2, vp2)
    return tvf.pad(img, padding, 0, "constant")
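
A small sketch: a 480 x 640 image is scaled so its longer side fits 256 px (giving 192 x 256), then zero-padded to a 256 x 256 square:

import torch
from dreem.datasets.data_utils import resize_and_pad

img = torch.rand(3, 480, 640)
print(resize_and_pad(img, 256).shape)  # torch.Size([3, 256, 256])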

sorted_anchors(labels)

Sort anchor names from most instances with that node to least.

Parameters:

Name Type Description Default
labels Labels

a sleap_io.Labels object containing all the labels for that video

required

Returns:

Type Description
list[str]

A list of anchor names sorted from most frequently labeled to least

Source code in dreem/datasets/data_utils.py
def sorted_anchors(labels: sio.Labels) -> list[str]:
    """Sort anchor names from most instances with that node to least.

    Args:
        labels: a sleap_io.Labels object containing all the labels for that video

    Returns:
        A list of anchor names sorted from most frequently labeled to least
    """
    all_anchors = labels.skeletons[0].node_names

    anchor_counts = {anchor: 0 for anchor in all_anchors}

    # Count how many times each anchor is missing (NaN); sorting ascending by
    # missing count puts the most reliably labeled anchors first.
    for i in range(len(labels)):
        lf = labels[i]
        for instance in lf:
            for anchor in all_anchors:
                x, y = instance[anchor].x, instance[anchor].y
                if np.isnan(x) or np.isnan(y):
                    anchor_counts[anchor] += 1

    sorted_anchors = sorted(anchor_counts.keys(), key=lambda k: anchor_counts[k])

    return sorted_anchors

view_training_batch(instances, num_frames=1, cmap=None)

Display a grid of images from a batch of training instances.

Parameters:

Name Type Description Default
instances list[dict[str, list[ndarray]]]

A list of training instances, where each instance is a dictionary containing the object crops.

required
num_frames int

The number of frames to display per instance.

1
cmap

An optional matplotlib colormap passed to imshow.

None

Returns:

Type Description
None

None

Source code in dreem/datasets/data_utils.py
def view_training_batch(
    instances: list[dict[str, list[np.ndarray]]], num_frames: int = 1, cmap=None
) -> None:
    """Display a grid of images from a batch of training instances.

    Args:
        instances: A list of training instances, where each instance is a
            dictionary containing the object crops.
        num_frames: The number of frames to display per instance.
        cmap: An optional matplotlib colormap passed to `imshow`.

    Returns:
        None
    """
    num_crops = len(instances[0]["crops"])
    num_columns = num_crops
    num_rows = num_frames

    base_size = 2
    fig_size = (base_size * num_columns, base_size * num_rows)

    fig, axes = plt.subplots(num_rows, num_columns, figsize=fig_size)

    for i in range(num_frames):
        for j, data in enumerate(instances[i]["crops"]):
            try:
                ax = (
                    axes[j]
                    if num_frames == 1
                    else (axes[i] if num_crops == 1 else axes[i, j])
                )

                (ax.imshow(data.T) if cmap is None else ax.imshow(data.T, cmap=cmap))
                ax.axis("off")

            except Exception as e:
                print(e)
                pass

    plt.tight_layout()
    plt.show()