data_utils

dreem.datasets.data_utils

Module containing helper functions for datasets.

LazyTiffStack

Class for reading frames from a TIFF stack without loading the whole file into memory.

Source code in dreem/datasets/data_utils.py
class LazyTiffStack:
    """Class used for loading tiffs without loading into memory."""

    def __init__(self, filename: str):
        """Initialize class.

        Args:
            filename: name of tif file to be opened
        """
        # expects spatial, channels
        self.image = Image.open(filename)

    def __getitem__(self, section_idx: int) -> Image:
        """Get frame.

        Args:
            section_idx: index of frame or z-slice to get.

        Returns:
            a PIL image of that frame/z-slice.
        """
        self.image.seek(section_idx)
        return self.image

    def get_section(self, section_idx: int) -> np.array:
        """Get frame as ndarray.

        Args:
            section_idx: index of frame or z-slice to get.

        Returns:
            an np.array of that frame/z-slice.
        """
        section = self.__getitem__(section_idx)
        return np.array(section)

    def close(self):
        """Close tiff stack."""
        self.image.close()

__getitem__(section_idx)

Get frame.

Parameters:

    section_idx (int): Index of frame or z-slice to get. Required.

Returns:

    Image: A PIL image of that frame/z-slice.

Source code in dreem/datasets/data_utils.py
def __getitem__(self, section_idx: int) -> Image:
    """Get frame.

    Args:
        section_idx: index of frame or z-slice to get.

    Returns:
        a PIL image of that frame/z-slice.
    """
    self.image.seek(section_idx)
    return self.image

__init__(filename)

Initialize class.

Parameters:

    filename (str): Name of the TIFF file to open. Required.

Source code in dreem/datasets/data_utils.py
def __init__(self, filename: str):
    """Initialize class.

    Args:
        filename: name of tif file to be opened
    """
    # expects spatial, channels
    self.image = Image.open(filename)

close()

Close tiff stack.

Source code in dreem/datasets/data_utils.py
def close(self):
    """Close tiff stack."""
    self.image.close()

get_section(section_idx)

Get frame as ndarray.

Parameters:

    section_idx (int): Index of frame or z-slice to get. Required.

Returns:

    np.ndarray: An array of that frame/z-slice.

Source code in dreem/datasets/data_utils.py
def get_section(self, section_idx: int) -> np.array:
    """Get frame as ndarray.

    Args:
        section_idx: index of frame or z-slice to get.

    Returns:
        an np.array of that frame/z-slice.
    """
    section = self.__getitem__(section_idx)
    return np.array(section)
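
Putting the class together, a minimal usage sketch (the file path and page indices are hypothetical):

from dreem.datasets.data_utils import LazyTiffStack

stack = LazyTiffStack("stack.tif")  # hypothetical multi-page TIFF
frame = stack[0]  # PIL Image seeked to the first page
section = stack.get_section(5)  # sixth page materialized as an np.ndarray
stack.close()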

NodeDropout

Node dropout augmentation.

Drop up to n nodes with probability p.

Source code in dreem/datasets/data_utils.py
class NodeDropout:
    """Node dropout augmentation.

    Drop up to `n` nodes with probability `p`.
    """

    def __init__(self, p: float, n: int) -> None:
        """Initialize Node Dropout Augmentation.

        Args:
            p: the probability with which to drop the nodes
            n: the maximum number of nodes to drop
        """
        self.n = n
        self.p = p

    def __call__(self, nodes: list[str]) -> list[str]:
        """Wrap `drop_nodes` to enable class call.

        Args:
            nodes: A list of available node names to drop.

        Returns:
            dropped_nodes: A list of up to `self.n` nodes to drop.
        """
        return self.forward(nodes)

    def forward(self, nodes: list[str]) -> list[str]:
        """Drop up to `n` random nodes with probability p.

        Args:
            nodes: A list of available node names to drop.

        Returns:
            dropped_nodes: A list of up to `self.n` nodes to drop.
        """
        if self.n == 0 or self.p == 0:
            return []

        nodes_to_drop = np.random.permutation(nodes)
        node_dropout_p = np.random.uniform(size=len(nodes_to_drop))

        dropped_node_inds = np.where(node_dropout_p < self.p)
        node_dropout_p = node_dropout_p[dropped_node_inds]

        n_nodes_to_drop = min(self.n, len(node_dropout_p))
        if n_nodes_to_drop == 0:
            # No draw fell below p, so nothing to drop; avoids calling
            # argpartition on an empty array.
            return []

        dropped_node_inds = np.argpartition(node_dropout_p, -n_nodes_to_drop)[
            -n_nodes_to_drop:
        ]

        dropped_nodes = nodes_to_drop[dropped_node_inds]

        return dropped_nodes

__call__(nodes)

Wrap drop_nodes to enable class call.

Parameters:

    nodes (list[str]): A list of available node names to drop. Required.

Returns:

    dropped_nodes (list[str]): A list of up to `self.n` nodes to drop.

Source code in dreem/datasets/data_utils.py
def __call__(self, nodes: list[str]) -> list[str]:
    """Wrap `drop_nodes` to enable class call.

    Args:
        nodes: A list of available node names to drop.

    Returns:
        dropped_nodes: A list of up to `self.n` nodes to drop.
    """
    return self.forward(nodes)

__init__(p, n)

Initialize Node Dropout Augmentation.

Parameters:

    p (float): The probability with which to drop each node. Required.
    n (int): The maximum number of nodes to drop. Required.

Source code in dreem/datasets/data_utils.py
def __init__(self, p: float, n: int) -> None:
    """Initialize Node Dropout Augmentation.

    Args:
        p: the probability with which to drop the nodes
        n: the maximum number of nodes to drop
    """
    self.n = n
    self.p = p

forward(nodes)

Drop up to n random nodes with probability p.

Parameters:

    nodes (list[str]): A list of available node names to drop. Required.

Returns:

    dropped_nodes (list[str]): A list of up to `self.n` nodes to drop.

Source code in dreem/datasets/data_utils.py
def forward(self, nodes: list[str]) -> list[str]:
    """Drop up to `n` random nodes with probability p.

    Args:
        nodes: A list of available node names to drop.

    Returns:
        dropped_nodes: A list of up to `self.n` nodes to drop.
    """
    if self.n == 0 or self.p == 0:
        return []

    nodes_to_drop = np.random.permutation(nodes)
    node_dropout_p = np.random.uniform(size=len(nodes_to_drop))

    dropped_node_inds = np.where(node_dropout_p < self.p)
    node_dropout_p = node_dropout_p[dropped_node_inds]

    n_nodes_to_drop = min(self.n, len(node_dropout_p))
    if n_nodes_to_drop == 0:
        # No draw fell below p, so nothing to drop; avoids calling
        # argpartition on an empty array.
        return []

    dropped_node_inds = np.argpartition(node_dropout_p, -n_nodes_to_drop)[
        -n_nodes_to_drop:
    ]

    dropped_nodes = nodes_to_drop[dropped_node_inds]

    return dropped_nodes
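
A short usage sketch of the augmentation; the node names are hypothetical and the seed is only for reproducibility:

import numpy as np
from dreem.datasets.data_utils import NodeDropout

np.random.seed(0)
dropout = NodeDropout(p=0.5, n=2)
dropped = dropout(["head", "thorax", "abdomen", "tail"])
print(dropped)  # an array of up to 2 of the listed node names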

build_augmentations(augmentations)

Get augmentations for dataset.

Parameters:

    augmentations (dict): A dict mapping augmentation names to their parameters. Required.

Returns:

    Compose: An Albumentations composition of the augmentations.

Source code in dreem/datasets/data_utils.py
def build_augmentations(augmentations: dict) -> A.Compose:
    """Get augmentations for dataset.

    Args:
        augmentations: a dict containing the name of the augmentations
                       and their parameters

    Returns:
        An Albumentations composition of different augmentations.
    """
    aug_list = []
    for aug_name, aug_args in augmentations.items():
        aug_class = getattr(A, aug_name)
        aug = aug_class(**aug_args)
        aug_list.append(aug)

    augs = A.Compose(
        aug_list,
        p=1.0,
        keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
    )

    return augs
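
For illustration, a hypothetical config whose keys are Albumentations class names (as the getattr(A, aug_name) lookup requires) and whose values are constructor kwargs:

from dreem.datasets.data_utils import build_augmentations

augmentations = {
    "Rotate": {"limit": 45, "p": 0.5},
    "GaussianBlur": {"blur_limit": (3, 7), "p": 0.2},
}

augs = build_augmentations(augmentations)
# Apply as augs(image=image, keypoints=keypoints); keypoints are in "xy" format.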

centroid_bbox(points, anchors, crop_size)

Calculate bbox around instance centroid.

This is useful for ensuring that crops are centered around each instance in the case of incorrect pose estimates.

Parameters:

    points (ArrayLike): 2D array of centroid coordinates where each row corresponds to a different anchor point. Required.
    anchors (list): Indices of anchor points to try, in order, as the centroid. Required.
    crop_size (int): The crop height and width. Required.

Returns:

    Tensor: Bounding box in [y1, x1, y2, x2] format.

Source code in dreem/datasets/data_utils.py
def centroid_bbox(points: ArrayLike, anchors: list, crop_size: int) -> torch.Tensor:
    """Calculate bbox around instance centroid.

    This is useful for ensuring that crops are centered around each instance
    in the case of incorrect pose estimates.

    Args:
        points: 2d array of centroid coordinates where each row corresponds to a
            different anchor point.
        anchors: indices of a given anchor point to use as the centroid
        crop_size: Integer specifying the crop height and width

    Returns:
        Bounding box in [y1, x1, y2, x2] format.
    """
    for anchor in anchors:
        cx, cy = points[anchor][0], points[anchor][1]
        if not np.isnan(cx):
            break

    bbox = torch.Tensor(
        [
            -crop_size / 2 + cy,
            -crop_size / 2 + cx,
            crop_size / 2 + cy,
            crop_size / 2 + cx,
        ]
    )

    return bbox
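
A worked example: the first anchor is NaN, so the second anchor is used as the centroid of a 64-pixel crop.

import numpy as np
from dreem.datasets.data_utils import centroid_bbox

points = np.array([[np.nan, np.nan], [100.0, 50.0]])  # rows are (x, y) anchors
bbox = centroid_bbox(points, anchors=[0, 1], crop_size=64)
print(bbox)  # tensor([ 18.,  68.,  82., 132.]) in [y1, x1, y2, x2]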

crop_bbox(img, bbox)

Crop an image to a bounding box.

Parameters:

    img (Tensor): Image as a tensor of shape (channels, height, width). Required.
    bbox (ArrayLike): Bounding box in [y1, x1, y2, x2] format. Required.

Returns:

    Tensor: Cropped pixels as a tensor of shape (channels, height, width).

Source code in dreem/datasets/data_utils.py
def crop_bbox(img: torch.Tensor, bbox: ArrayLike) -> torch.Tensor:
    """Crop an image to a bounding box.

    Args:
        img: Image as a tensor of shape (channels, height, width).
        bbox: Bounding box in [y1, x1, y2, x2] format.

    Returns:
        Cropped pixels as tensor of shape (channels, height, width).
    """
    # Crop to the bounding box.
    y1, x1, y2, x2 = bbox
    crop = tvf.crop(
        img,
        top=int(y1.round()),
        left=int(x1.round()),
        height=int((y2 - y1).round()),
        width=int((x2 - x1).round()),
    )

    return crop
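
A quick sanity check of the output shape:

import torch
from dreem.datasets.data_utils import crop_bbox

img = torch.zeros(3, 256, 256)  # (channels, height, width)
bbox = torch.tensor([18.0, 68.0, 82.0, 132.0])  # [y1, x1, y2, x2]
crop = crop_bbox(img, bbox)
print(crop.shape)  # torch.Size([3, 64, 64])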

get_bbox(center, size)

Get a square bbox around centroid coordinates.

Parameters:

    center (ArrayLike): Centroid coordinates in (x, y). Required.
    size (Union[int, tuple[int]]): Size of the bounding box. Required.

Returns:

    Tensor: A torch tensor in [y1, x1, y2, x2] format.

Source code in dreem/datasets/data_utils.py
def get_bbox(center: ArrayLike, size: Union[int, tuple[int]]) -> torch.Tensor:
    """Get a square bbox around a centroid coordinates.

    Args:
        center: centroid coordinates in (x,y)
        size: size of the bounding box

    Returns:
        A torch tensor in form y1, x1, y2, x2
    """
    if isinstance(size, int):
        size = (size, size)
    cx, cy = center[0], center[1]

    y1 = max(0, -size[-1] // 2 + cy)
    x1 = max(0, -size[0] // 2 + cx)
    y2 = size[-1] // 2 + cy if y1 != 0 else size[1]
    x2 = size[0] // 2 + cx if x1 != 0 else size[0]
    bbox = torch.Tensor([y1, x1, y2, x2])

    return bbox
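
For example, a 64-pixel square box around the point (100, 50):

import numpy as np
from dreem.datasets.data_utils import get_bbox

bbox = get_bbox(center=np.array([100, 50]), size=64)
print(bbox)  # tensor([ 18.,  68.,  82., 132.]) in [y1, x1, y2, x2]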

get_max_padding(height, width)

Calculate maximum padding dimensions for a given height and width.

Useful if padding is required for rotational augmentations, e.g. when centroids lie on the borders of an image.

Parameters:

    height (int): The original height. Required.
    width (int): The original width. Required.

Returns:

    tuple: A tuple containing the padded height and padded width.

Source code in dreem/datasets/data_utils.py
def get_max_padding(height: int, width: int) -> tuple:
    """Calculate maximum padding dimensions for a given height and width.

    Useful if padding is required for rotational augmentations, e.g. when
    centroids lie on the borders of an image.

    Args:
        height: The original height.
        width: The original width.

    Returns:
        A tuple containing the padded height and padded width.
    """
    diagonal = math.ceil(math.sqrt(height**2 + width**2))

    padded_height = height + (diagonal - height)
    padded_width = width + (diagonal - width)

    return padded_height, padded_width
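
A worked example using a 3-4-5 triangle: a 300 x 400 image has a 500-pixel diagonal, so both dimensions are padded out to 500.

from dreem.datasets.data_utils import get_max_padding

print(get_max_padding(height=300, width=400))  # (500, 500)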

pad_bbox(bbox, padding=16)

Pad bounding box coordinates.

Parameters:

    bbox (ArrayLike): Bounding box in [y1, x1, y2, x2] format. Required.
    padding (int): Padding to add to each side in pixels. Default: 16.

Returns:

    Tensor: Padded bounding box in [y1, x1, y2, x2] format.

Source code in dreem/datasets/data_utils.py
def pad_bbox(bbox: ArrayLike, padding: int = 16) -> torch.Tensor:
    """Pad bounding box coordinates.

    Args:
        bbox: Bounding box in [y1, x1, y2, x2] format.
        padding: Padding to add to each side in pixels.

    Returns:
        Padded bounding box in [y1, x1, y2, x2] format.
    """
    y1, x1, y2, x2 = bbox
    y1, x1 = y1 - padding, x1 - padding
    y2, x2 = y2 + padding, x2 + padding
    return torch.Tensor([y1, x1, y2, x2])
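
For example, growing a box by 16 pixels on every side:

import torch
from dreem.datasets.data_utils import pad_bbox

bbox = torch.tensor([18.0, 68.0, 82.0, 132.0])
print(pad_bbox(bbox, padding=16))  # tensor([  2.,  52.,  98., 148.])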

parse_synthetic(xml_path, source='icy')

Parse .xml labels from synthetic data generated by ICY or the ISBI tracking challenge.

Logic adapted from https://github.com/sylvainprigent/napari-tracks-reader/blob/main/napari_tracks_reader

Parameters:

    xml_path (str): Path to the .xml file containing ICY or ISBI ground-truth trajectory labels. Required.
    source (str): Synthetic dataset type; either "icy" or "isbi". Default: 'icy'.

Returns:

    DataFrame: A pandas DataFrame containing frame index, ground-truth track id, and centroid x, y coordinates in pixels.

Source code in dreem/datasets/data_utils.py
def parse_synthetic(xml_path: str, source: str = "icy") -> pd.DataFrame:
    """Parse .xml labels from synthetic data generated by ICY or ISBI tracking challenge.

    Logic adapted from https://github.com/sylvainprigent/napari-tracks-reader/blob/main/napari_tracks_reader

    Args:
        xml_path: path to .xml file containing ICY or ISBI gt trajectory labels
        source: synthetic dataset type. Should be either icy or isbi

    Returns:
        pandas DataFrame containing frame idx, gt track id
        and centroid x,y coordinates in pixels
    """
    if source.lower() == "icy":
        root_tag = "trackgroup"
    elif source.lower() == "isbi":
        root_tag = "TrackContestISBI2012"
    else:
        raise ValueError(f"{source} source mode not supported")

    tree = et.parse(xml_path)

    root = tree.getroot()
    tracks = np.empty((0, 4))

    # get the trackgroup element
    idx_trackgroup = 0
    for i in range(len(root)):
        if root[i].tag == root_tag:
            idx_trackgroup = i
            break

    ids_map = {}
    track_id = -1
    for track_element in root[idx_trackgroup]:
        track_id += 1

        try:
            ids_map[track_element.attrib["id"]] = track_id
        except KeyError:
            pass

        for detection_element in track_element:
            row = [
                float(track_id),
                float(detection_element.attrib["t"]),
                float(detection_element.attrib["y"]),
                float(detection_element.attrib["x"]),
            ]
            tracks = np.concatenate((tracks, [row]), axis=0)

    tracks_df = pd.DataFrame(
        tracks, columns=["TRACK_ID", "FRAME", "POSITION_Y", "POSITION_X"]
    )

    tracks_df = tracks_df.apply(pd.to_numeric, errors="coerce", downcast="integer")

    return tracks_df
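
A self-contained sketch with a hand-written ICY-style file; the XML structure is assumed from the parsing logic above:

import tempfile
from dreem.datasets.data_utils import parse_synthetic

xml = """<root>
  <trackgroup>
    <track id="7">
      <detection t="0" x="10.0" y="20.0"/>
      <detection t="1" x="11.0" y="21.0"/>
    </track>
  </trackgroup>
</root>"""

with tempfile.NamedTemporaryFile("w", suffix=".xml", delete=False) as f:
    f.write(xml)

df = parse_synthetic(f.name, source="icy")
print(df)  # columns: TRACK_ID, FRAME, POSITION_Y, POSITION_X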

parse_trackmate(data_path)

Parse a TrackMate XML or CSV labels file.

Logic adapted from https://github.com/hadim/pytrackmate.

Parameters:

    data_path (str): Path to the XML or CSV file storing TrackMate trajectory labels. Required.

Returns:

    DataFrame: A pandas DataFrame containing frame number, track ids, and centroid x, y coordinates in pixels.

Source code in dreem/datasets/data_utils.py
def parse_trackmate(data_path: str) -> pd.DataFrame:
    """Parse trackmate xml or csv labels file.

    Logic adapted from https://github.com/hadim/pytrackmate.

    Args:
        data_path: string path to xml or csv file storing trackmate trajectory labels

    Returns:
        `pandas DataFrame` containing frame number, track_ids,
        and centroid x,y coordinates in pixels
    """
    if data_path.endswith(".xml"):
        root = et.fromstring(open(data_path).read())

        objects = []
        features = root.find("Model").find("FeatureDeclarations").find("SpotFeatures")
        features = [c.get("feature") for c in list(features)] + ["ID"]

        spots = root.find("Model").find("AllSpots")

        objects = []

        for frame in spots.findall("SpotsInFrame"):
            for spot in frame.findall("Spot"):
                single_object = []
                for label in features:
                    single_object.append(spot.get(label))
                objects.append(single_object)

        tracks_df = pd.DataFrame(objects, columns=features)
        tracks_df = tracks_df.astype(float)  # np.float was removed in NumPy 1.24

        filtered_track_ids = [
            int(track.get("TRACK_ID"))
            for track in root.find("Model").find("FilteredTracks").findall("TrackID")
        ]

        label_id = 0
        tracks_df["label"] = np.nan

        tracks = root.find("Model").find("AllTracks")
        for track in tracks.findall("Track"):
            track_id = int(track.get("TRACK_ID"))
            if track_id in filtered_track_ids:
                spot_ids = [
                    (
                        edge.get("SPOT_SOURCE_ID"),
                        edge.get("SPOT_TARGET_ID"),
                        edge.get("EDGE_TIME"),
                    )
                    for edge in track.findall("Edge")
                ]
                spot_ids = np.array(spot_ids).astype("float")[:, :2]
                spot_ids = set(spot_ids.flatten())

                tracks_df.loc[tracks_df["ID"].isin(spot_ids), "TRACK_ID"] = label_id
                label_id += 1

    elif data_path.endswith(".csv"):
        tracks_df = pd.read_csv(data_path, encoding="ISO-8859-1")

    else:
        raise ValueError(f"Unsupported trackmate file extension: {data_path}")

    tracks_df = tracks_df.apply(pd.to_numeric, errors="coerce", downcast="integer")

    posx_key = "POSITION_X"
    posy_key = "POSITION_Y"
    frame_key = "FRAME"
    track_key = "TRACK_ID"

    mapper = {
        "X": posx_key,
        "Y": posy_key,
        "x": posx_key,
        "y": posy_key,
        "Slice n°": frame_key,
        "Track n°": track_key,
    }

    if "t" in tracks_df:
        mapper.update({"t": frame_key})

    tracks_df = tracks_df.rename(mapper=mapper, axis=1)

    if data_path.endswith(".csv"):
        # 0 index track and frame ids
        if min(tracks_df[frame_key]) == 1:
            tracks_df[frame_key] = tracks_df[frame_key] - 1

        if min(tracks_df[track_key]) == 1:
            tracks_df[track_key] = tracks_df[track_key] - 1

    return tracks_df
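
A minimal sketch using the CSV path; the column names follow the keys the function expects:

import tempfile
from dreem.datasets.data_utils import parse_trackmate

csv = "TRACK_ID,FRAME,POSITION_X,POSITION_Y\n1,1,10.5,20.5\n1,2,11.0,21.0\n"

with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write(csv)

df = parse_trackmate(f.name)
print(df)  # FRAME and TRACK_ID are shifted to be 0-indexed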

pose_bbox(points, bbox_size)

Calculate bbox around instance pose.

Parameters:

    points (ndarray): An np.ndarray of shape (nodes, 2). Required.
    bbox_size (Union[tuple[int], int]): Size of the bbox, either an int for a square bbox or a tuple in (x, y). Required.

Returns:

    Tensor: Bounding box in [y1, x1, y2, x2] format.

Source code in dreem/datasets/data_utils.py
def pose_bbox(points: np.ndarray, bbox_size: Union[tuple[int], int]) -> torch.Tensor:
    """Calculate bbox around instance pose.

    Args:
        points: an np array of shape (nodes, 2)
        bbox_size: size of bbox either an int indicating square bbox or in (x,y)

    Returns:
        Bounding box in [y1, x1, y2, x2] format.
    """
    if isinstance(bbox_size, int):
        bbox_size = (bbox_size, bbox_size)

    c = np.nanmean(points, axis=0)
    bbox = torch.Tensor(
        [
            c[-1] - bbox_size[-1] / 2,
            c[0] - bbox_size[0] / 2,
            c[-1] + bbox_size[-1] / 2,
            c[0] + bbox_size[0] / 2,
        ]
    )
    return bbox
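
A worked example with one missing node; np.nanmean ignores it when locating the centroid:

import numpy as np
from dreem.datasets.data_utils import pose_bbox

points = np.array([[10.0, 20.0], [30.0, 40.0], [np.nan, np.nan]])
print(pose_bbox(points, bbox_size=64))  # tensor([ -2., -12.,  62.,  52.])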

resize_and_pad(img, output_size)

Resize and pad an image to fit a square output size.

Parameters:

    img (Tensor): Image as a tensor of shape (channels, height, width). Required.
    output_size (int): Height and width of the square output. Required.

Returns:

    Tensor: The image zero padded to shape (channels, output_size, output_size).

Source code in dreem/datasets/data_utils.py
def resize_and_pad(img: torch.Tensor, output_size: int) -> torch.Tensor:
    """Resize and pad an image to fit a square output size.

    Args:
        img: Image as a tensor of shape (channels, height, width).
        output_size: Integer size of height and width of output.

    Returns:
        The image zero padded to be of shape (channels, output_size, output_size).
    """
    # Figure out how to scale without breaking aspect ratio.
    img_height, img_width = img.shape[-2:]
    if img_width < img_height:  # taller
        crop_height = output_size
        scale = crop_height / img_height
        crop_width = int(img_width * scale)
    else:  # wider
        crop_width = output_size
        scale = crop_width / img_width
        crop_height = int(img_height * scale)

    # Scale without breaking aspect ratio.
    img = tvf.resize(img, size=[crop_height, crop_width])

    # Pad to square.
    img_height, img_width = img.shape[-2:]
    hp1 = int((output_size - img_width) / 2)
    vp1 = int((output_size - img_height) / 2)
    hp2 = output_size - (img_width + hp1)
    vp2 = output_size - (img_height + vp1)
    padding = (hp1, vp1, hp2, vp2)
    return tvf.pad(img, padding, 0, "constant")
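
For example, a wide image is scaled to fit the output width and zero-padded vertically:

import torch
from dreem.datasets.data_utils import resize_and_pad

img = torch.rand(3, 100, 200)  # wider than tall
out = resize_and_pad(img, output_size=128)
print(out.shape)  # torch.Size([3, 128, 128]); content is 64 x 128, padded above and below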

sorted_anchors(labels)

Sort anchor names by the number of instances containing that node, from most to least.

Parameters:

Name Type Description Default
labels Labels

a sleap_io.labels object containing all the labels for that video

required

Returns:

Type Description
list[str]

A list of anchor names sorted by most nodes to least nodes

Source code in dreem/datasets/data_utils.py
def sorted_anchors(labels: sio.Labels) -> list[str]:
    """Sort anchor names from most instances with that node to least.

    Args:
        labels: a sleap_io.Labels object containing all the labels for that video

    Returns:
        A list of anchor names sorted by most nodes to least nodes
    """
    all_anchors = labels.skeletons[0].node_names

    anchor_counts = {anchor: 0 for anchor in all_anchors}

    for i in range(len(labels)):
        lf = labels[i]
        for instance in lf:
            for anchor in all_anchors:
                x, y = instance[anchor].x, instance[anchor].y
                if np.isnan(x) or np.isnan(y):
                    anchor_counts[anchor] += 1

    sorted_anchors = sorted(anchor_counts.keys(), key=lambda k: anchor_counts[k])

    return sorted_anchors
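
A minimal sketch; "labels.slp" is a hypothetical SLEAP labels file:

import sleap_io as sio
from dreem.datasets.data_utils import sorted_anchors

labels = sio.load_slp("labels.slp")
anchors = sorted_anchors(labels)
print(anchors[0])  # the node present in the most instances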

view_training_batch(instances, num_frames=1, cmap=None)

Display a grid of images from a batch of training instances.

Parameters:

    instances (List[Dict[str, List[ndarray]]]): A list of training instances, where each instance is a dictionary containing the object crops. Required.
    num_frames (int): The number of frames to display per instance. Default: 1.
    cmap: Optional matplotlib colormap passed to imshow. Default: None.

Returns:

    None

Source code in dreem/datasets/data_utils.py
def view_training_batch(
    instances: List[Dict[str, List[np.ndarray]]], num_frames: int = 1, cmap=None
) -> None:
    """Display a grid of images from a batch of training instances.

    Args:
        instances: A list of training instances, where each instance is a
            dictionary containing the object crops.
        num_frames: The number of frames to display per instance.
        cmap: Optional matplotlib colormap passed to `imshow`.

    Returns:
        None
    """
    num_crops = len(instances[0]["crops"])
    num_columns = num_crops
    num_rows = num_frames

    base_size = 2
    fig_size = (base_size * num_columns, base_size * num_rows)

    fig, axes = plt.subplots(num_rows, num_columns, figsize=fig_size)

    for i in range(num_frames):
        for j, data in enumerate(instances[i]["crops"]):
            try:
                ax = (
                    axes[j]
                    if num_frames == 1
                    else (axes[i] if num_crops == 1 else axes[i, j])
                )

                (ax.imshow(data.T) if cmap is None else ax.imshow(data.T, cmap=cmap))
                ax.axis("off")

            except Exception as e:
                print(e)
                pass

    plt.tight_layout()
    plt.show()
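
A minimal sketch with random data; the crop shapes are hypothetical:

import numpy as np
from dreem.datasets.data_utils import view_training_batch

# Two frames, each with three random 64 x 64 RGB crops.
instances = [
    {"crops": [np.random.rand(3, 64, 64) for _ in range(3)]} for _ in range(2)
]
view_training_batch(instances, num_frames=2)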