cell_tracking_dataset

dreem.datasets.cell_tracking_dataset

Module containing cell tracking challenge dataset.

Classes:

CellTrackingDataset: Dataset for loading cell tracking challenge data.

CellTrackingDataset

Bases: BaseDataset

Dataset for loading cell tracking challenge data.
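
A minimal usage sketch, assuming a Cell Tracking Challenge style directory layout; the paths, file names, and parameter values below are placeholders rather than anything this page prescribes.

from dreem.datasets.cell_tracking_dataset import CellTrackingDataset

# Hypothetical CTC-style file lists; one inner list per dataset/video.
gt_list = [[
    "/data/Fluo-N2DH-SIM+/01_GT/TRA/man_track000.tif",
    "/data/Fluo-N2DH-SIM+/01_GT/TRA/man_track001.tif",
]]
raw_img_list = [[
    "/data/Fluo-N2DH-SIM+/01/t000.tif",
    "/data/Fluo-N2DH-SIM+/01/t001.tif",
]]

dataset = CellTrackingDataset(
    gt_list=gt_list,
    raw_img_list=raw_img_list,
    data_dirs=["/data/Fluo-N2DH-SIM+"],
    crop_size=128,          # or a list with one crop size per entry in data_dirs
    chunk=True,
    clip_length=8,
    mode="train",
)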

Methods:

__init__: Initialize CellTrackingDataset.
get_indices: Retrieve label and frame indices given batch index.
get_instances: Get an element of the dataset.

Source code in dreem/datasets/cell_tracking_dataset.py
class CellTrackingDataset(BaseDataset):
    """Dataset for loading cell tracking challenge data."""

    def __init__(
        self,
        gt_list: list[list[str]],
        raw_img_list: list[list[str]],
        data_dirs: Optional[list[str]] = None,
        padding: int = 5,
        crop_size: int = 20,
        chunk: bool = False,
        clip_length: int = 10,
        mode: str = "train",
        augmentations: dict | None = None,
        n_chunks: int | float = 1.0,
        seed: int | None = None,
        max_batching_gap: int = 15,
        use_tight_bbox: bool = False,
        ctc_track_meta: list[str] | None = None,
        **kwargs,
    ):
        """Initialize CellTrackingDataset.

        Args:
            gt_list: filepaths of gt label images in a list of lists (each list corresponds to a dataset)
            raw_img_list: filepaths of original tif images in a list of lists (each list corresponds to a dataset)
            data_dirs: paths to data directories
            padding: amount of padding around object crops
            crop_size: the size of the object crops. Can be either:
                - An integer specifying a single crop size for all objects
                - A list of integers specifying different crop sizes for different data directories
            chunk: whether or not to chunk the dataset into batches
            clip_length: the number of frames in each chunk
            mode: `train` or `val`. Determines whether this dataset is used for
                training or validation. Currently doesn't affect dataset logic
            augmentations: An optional dict mapping augmentations to parameters. The keys
                should map directly to augmentation classes in albumentations. Example:
                    augs = {
                        'Rotate': {'limit': [-90, 90]},
                        'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0},
                        'RandomContrast': {'limit': 0.2}
                    }
            n_chunks: Number of chunks to subsample from.
                Can be either a fraction of the dataset (i.e. in (0, 1.0]) or a number of chunks
            seed: set a seed for reproducibility
            max_batching_gap: the max number of frames that can be unlabelled before starting a new batch
            use_tight_bbox: whether to use tight bounding box (around keypoints) instead of the default square bounding box
            ctc_track_meta: filepaths of man_track.txt files in a list of lists (each list corresponds to a dataset)
        """
        super().__init__(
            gt_list,
            raw_img_list,
            padding,
            crop_size,
            chunk,
            clip_length,
            mode,
            augmentations,
            n_chunks,
            seed,
            ctc_track_meta,
        )

        self.raw_img_list = raw_img_list
        self.gt_list = gt_list
        self.ctc_track_meta = ctc_track_meta
        self.data_dirs = data_dirs
        self.chunk = chunk
        self.clip_length = clip_length
        self.crop_size = crop_size
        self.padding = padding
        self.mode = mode.lower()
        self.n_chunks = n_chunks
        self.seed = seed
        self.max_batching_gap = max_batching_gap
        self.use_tight_bbox = use_tight_bbox
        self.skeleton = sio.Skeleton(nodes=["centroid"])
        if not isinstance(self.data_dirs, list):
            self.data_dirs = [self.data_dirs]

        if not isinstance(self.crop_size, list):
            # make a list so it's handled consistently when multiple crop sizes are used
            if len(self.data_dirs) > 0:  # for test mode, data_dirs is []
                self.crop_size = [self.crop_size] * len(self.data_dirs)
            else:
                self.crop_size = [self.crop_size]

        if len(self.data_dirs) > 0 and len(self.crop_size) != len(self.data_dirs):
            raise ValueError(
                f"If a list of crop sizes or data directories is given, "
                f"they must have the same length but got {len(self.crop_size)} "
                f"and {len(self.data_dirs)}"
            )

        # if self.seed is not None:
        #     np.random.seed(self.seed)

        if augmentations and self.mode == "train":
            self.augmentations = data_utils.build_augmentations(augmentations)
        else:
            self.augmentations = None

        # parse CTC track metadata (man_track.txt) into a dataframe per dataset
        if self.ctc_track_meta is not None:
            self.list_df_track_meta = [
                pd.read_csv(
                    gtf,
                    delimiter=" ",
                    header=None,
                    names=["track_id", "start_frame", "end_frame", "parent_id"],
                )
                for gtf in self.ctc_track_meta
            ]
        else:
            self.list_df_track_meta = None
        # frame indices for each dataset; list of lists (each list corresponds to a dataset)
        self.frame_idx = [torch.arange(len(gt_dataset)) for gt_dataset in self.gt_list]

        # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
        # used in call to get_instances()
        self.create_chunks_other()

    def get_indices(self, idx: int) -> tuple:
        """Retrieve label and frame indices given batch index.

        Args:
            idx: the index of the batch.

        Returns:
            the label and frame indices corresponding to a batch.
        """
        return self.label_idx[idx], self.chunked_frame_idx[idx]

    def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:
        """Get an element of the dataset.

        Args:
            label_idx: index of the labels
            frame_idx: index of the frames

        Returns:
            a list of Frame objects containing frame metadata and Instance Objects.
            See `dreem.io.data_structures` for more info.
        """
        image_paths = self.raw_img_list[label_idx]
        gt_paths = self.gt_list[label_idx]

        if self.list_df_track_meta is not None:
            df_track_meta = self.list_df_track_meta[label_idx]
        else:
            df_track_meta = None

        # get the correct crop size based on the video
        video_par_path = Path(image_paths[0]).parent.parent
        if len(self.data_dirs) > 0:
            crop_size = self.crop_size[0]
            for j, data_dir in enumerate(self.data_dirs):
                if Path(data_dir) == video_par_path:
                    crop_size = self.crop_size[j]
                    break
        else:
            crop_size = self.crop_size[0]

        frames = []
        max_crop_h, max_crop_w = 0, 0
        for i in frame_idx:
            instances, gt_track_ids, centroids, dict_centroids, bboxes = (
                [],
                [],
                [],
                {},
                [],
            )

            i = int(i)

            img = image_paths[i]
            gt_sec = gt_paths[i]

            img = np.array(Image.open(img))
            gt_sec = np.array(Image.open(gt_sec))

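            # min-max normalize 16-bit images into the uint8 range [0, 255]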
            if img.dtype == np.uint16:
                img = ((img - img.min()) * (1 / (img.max() - img.min()) * 255)).astype(
                    np.uint8
                )
            # if df_track_meta is None:
            unique_instances = np.unique(gt_sec)
            # else:
            # unique_instances = df_track_meta["track_id"].unique()

            for instance in unique_instances:
                # not all instances are in the frame, and they also label the
                # background instance as zero
                if instance in gt_sec and instance != 0:
                    mask = gt_sec == instance
                    center_of_mass = measurements.center_of_mass(mask)

                    # scipy returns yx
                    x, y = center_of_mass[::-1]

                    if self.use_tight_bbox:
                        bbox = data_utils.get_tight_bbox_masks(mask)
                    else:
                        bbox = data_utils.pad_bbox(
                            data_utils.get_bbox([int(x), int(y)], crop_size),
                            padding=self.padding,
                        )

                    gt_track_ids.append(int(instance))
                    centroids.append([x, y])
                    dict_centroids[int(instance)] = [x, y]
                    bboxes.append(bbox)

            # albumentations wants (spatial, channels), ensure correct dims
            if self.augmentations is not None:
                for transform in self.augmentations:
                    # for occlusion simulation, can remove if we don't want
                    if isinstance(transform, A.CoarseDropout):
                        transform.fill_value = random.randint(0, 255)

                augmented = self.augmentations(
                    image=img,
                    keypoints=np.vstack(centroids),
                )

                img, centroids = augmented["image"], augmented["keypoints"]

            img = torch.Tensor(img).unsqueeze(0)

            for j in range(len(gt_track_ids)):
                # just formatting for compatibility with Instance class
                instance_centroid = {
                    "centroid": np.array(dict_centroids[gt_track_ids[j]])
                }
                pose = {"centroid": dict_centroids[gt_track_ids[j]]}  # more formatting
                crop = data_utils.crop_bbox(img, bboxes[j])
                c, h, w = crop.shape
                if h > max_crop_h:
                    max_crop_h = h
                if w > max_crop_w:
                    max_crop_w = w

                instances.append(
                    Instance(
                        gt_track_id=gt_track_ids[j],
                        pred_track_id=-1,
                        centroid=instance_centroid,
                        skeleton=self.skeleton,
                        point_scores=np.array([1.0]),
                        instance_score=np.array([1.0]),
                        pose=pose,
                        bbox=bboxes[j],
                        crop=crop,
                    )
                )

            if self.mode == "train":
                np.random.shuffle(instances)

            frames.append(
                Frame(
                    video_id=label_idx,
                    frame_id=i,
                    vid_file=Path(image_paths[0]).parent.name,
                    img_shape=img.shape,
                    instances=instances,
                )
            )

        # pad bbox to max size
        if self.use_tight_bbox:
            # bound the max crop size to the user defined crop size
            max_crop_h = crop_size if max_crop_h == 0 else min(max_crop_h, crop_size)
            max_crop_w = crop_size if max_crop_w == 0 else min(max_crop_w, crop_size)
            # gather all the crops
            for frame in frames:
                for instance in frame.instances:
                    data_utils.pad_variable_size_crops(
                        instance, (max_crop_h, max_crop_w)
                    )

        return frames

__init__(gt_list, raw_img_list, data_dirs=None, padding=5, crop_size=20, chunk=False, clip_length=10, mode='train', augmentations=None, n_chunks=1.0, seed=None, max_batching_gap=15, use_tight_bbox=False, ctc_track_meta=None, **kwargs)

Initialize CellTrackingDataset.

Parameters:

gt_list (list[list[str]], required):
    filepaths of gt label images in a list of lists (each list corresponds to a dataset)
raw_img_list (list[list[str]], required):
    filepaths of original tif images in a list of lists (each list corresponds to a dataset)
data_dirs (Optional[list[str]], default None):
    paths to data directories
padding (int, default 5):
    amount of padding around object crops
crop_size (int, default 20):
    the size of the object crops. Can be either:
        - An integer specifying a single crop size for all objects
        - A list of integers specifying different crop sizes for different data directories
chunk (bool, default False):
    whether or not to chunk the dataset into batches
clip_length (int, default 10):
    the number of frames in each chunk
mode (str, default 'train'):
    'train' or 'val'. Determines whether this dataset is used for training or validation. Currently doesn't affect dataset logic
augmentations (dict | None, default None):
    an optional dict mapping augmentations to parameters. The keys should map directly to augmentation classes in albumentations (a sketch of building such a pipeline follows this list). Example:
        augs = {
            'Rotate': {'limit': [-90, 90]},
            'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0},
            'RandomContrast': {'limit': 0.2}
        }
n_chunks (int | float, default 1.0):
    number of chunks to subsample from. Can be either a fraction of the dataset (i.e. in (0, 1.0]) or a number of chunks
seed (int | None, default None):
    set a seed for reproducibility
max_batching_gap (int, default 15):
    the max number of frames that can be unlabelled before starting a new batch
use_tight_bbox (bool, default False):
    whether to use a tight bounding box (around keypoints) instead of the default square bounding box
ctc_track_meta (list[str] | None, default None):
    filepaths of man_track.txt files in a list of lists (each list corresponds to a dataset)
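
build_augmentations itself is not shown on this page, so the snippet below is only a plausible sketch of how such a mapping could be turned into an albumentations pipeline; the helper name and the keypoint handling are assumptions, not dreem's actual implementation.

import albumentations as A

def build_augmentations_sketch(augmentations: dict) -> A.Compose:
    # Hypothetical helper: look up each key as an albumentations class and
    # instantiate it with the given keyword arguments.
    transforms = [getattr(A, name)(**params) for name, params in augmentations.items()]
    # Instance centroids are passed as keypoints in get_instances, so they must
    # be declared here for albumentations to transform them with the image.
    return A.Compose(transforms, keypoint_params=A.KeypointParams(format="xy"))

augs = {
    "Rotate": {"limit": [-90, 90]},
    "GaussianBlur": {"blur_limit": (3, 7), "sigma_limit": 0},
    "RandomContrast": {"limit": 0.2},  # folded into RandomBrightnessContrast in recent albumentations releases
}
pipeline = build_augmentations_sketch(augs)
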
Source code in dreem/datasets/cell_tracking_dataset.py
def __init__(
    self,
    gt_list: list[list[str]],
    raw_img_list: list[list[str]],
    data_dirs: Optional[list[str]] = None,
    padding: int = 5,
    crop_size: int = 20,
    chunk: bool = False,
    clip_length: int = 10,
    mode: str = "train",
    augmentations: dict | None = None,
    n_chunks: int | float = 1.0,
    seed: int | None = None,
    max_batching_gap: int = 15,
    use_tight_bbox: bool = False,
    ctc_track_meta: list[str] | None = None,
    **kwargs,
):
    """Initialize CellTrackingDataset.

    Args:
        gt_list: filepaths of gt label images in a list of lists (each list corresponds to a dataset)
        raw_img_list: filepaths of original tif images in a list of lists (each list corresponds to a dataset)
        data_dirs: paths to data directories
        padding: amount of padding around object crops
        crop_size: the size of the object crops. Can be either:
            - An integer specifying a single crop size for all objects
            - A list of integers specifying different crop sizes for different data directories
        chunk: whether or not to chunk the dataset into batches
        clip_length: the number of frames in each chunk
        mode: `train` or `val`. Determines whether this dataset is used for
            training or validation. Currently doesn't affect dataset logic
        augmentations: An optional dict mapping augmentations to parameters. The keys
            should map directly to augmentation classes in albumentations. Example:
                augs = {
                    'Rotate': {'limit': [-90, 90]},
                    'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0},
                    'RandomContrast': {'limit': 0.2}
                }
        n_chunks: Number of chunks to subsample from.
            Can be either a fraction of the dataset (i.e. in (0, 1.0]) or a number of chunks
        seed: set a seed for reproducibility
        max_batching_gap: the max number of frames that can be unlabelled before starting a new batch
        use_tight_bbox: whether to use tight bounding box (around keypoints) instead of the default square bounding box
        ctc_track_meta: filepaths of man_track.txt files in a list of lists (each list corresponds to a dataset)
    """
    super().__init__(
        gt_list,
        raw_img_list,
        padding,
        crop_size,
        chunk,
        clip_length,
        mode,
        augmentations,
        n_chunks,
        seed,
        ctc_track_meta,
    )

    self.raw_img_list = raw_img_list
    self.gt_list = gt_list
    self.ctc_track_meta = ctc_track_meta
    self.data_dirs = data_dirs
    self.chunk = chunk
    self.clip_length = clip_length
    self.crop_size = crop_size
    self.padding = padding
    self.mode = mode.lower()
    self.n_chunks = n_chunks
    self.seed = seed
    self.max_batching_gap = max_batching_gap
    self.use_tight_bbox = use_tight_bbox
    self.skeleton = sio.Skeleton(nodes=["centroid"])
    if not isinstance(self.data_dirs, list):
        self.data_dirs = [self.data_dirs]

    if not isinstance(self.crop_size, list):
        # make a list so it's handled consistently when multiple crop sizes are used
        if len(self.data_dirs) > 0:  # for test mode, data_dirs is []
            self.crop_size = [self.crop_size] * len(self.data_dirs)
        else:
            self.crop_size = [self.crop_size]

    if len(self.data_dirs) > 0 and len(self.crop_size) != len(self.data_dirs):
        raise ValueError(
            f"If a list of crop sizes or data directories is given, "
            f"they must have the same length but got {len(self.crop_size)} "
            f"and {len(self.data_dirs)}"
        )

    # if self.seed is not None:
    #     np.random.seed(self.seed)

    if augmentations and self.mode == "train":
        self.augmentations = data_utils.build_augmentations(augmentations)
    else:
        self.augmentations = None

    # parse CTC track metadata (man_track.txt) into a dataframe per dataset
    if self.ctc_track_meta is not None:
        self.list_df_track_meta = [
            pd.read_csv(
                gtf,
                delimiter=" ",
                header=None,
                names=["track_id", "start_frame", "end_frame", "parent_id"],
            )
            for gtf in self.ctc_track_meta
        ]
    else:
        self.list_df_track_meta = None
    # frame indices for each dataset; list of lists (each list corresponds to a dataset)
    self.frame_idx = [torch.arange(len(gt_dataset)) for gt_dataset in self.gt_list]

    # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
    # used in call to get_instances()
    self.create_chunks_other()

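For context, each man_track.txt file from the Cell Tracking Challenge is a space-delimited table with one row per track (track id, first frame, last frame, parent track id, where a parent of 0 means no parent). The rows below are invented for illustration; the read_csv call mirrors the one used in the constructor above.

import pandas as pd

# Hypothetical man_track.txt contents (track_id start_frame end_frame parent_id):
#   1 0 45 0    track 1 spans frames 0-45 and has no parent
#   2 0 30 0
#   3 31 45 2   track 3 appears at frame 31 as a daughter of track 2
df_track_meta = pd.read_csv(
    "man_track.txt",
    delimiter=" ",
    header=None,
    names=["track_id", "start_frame", "end_frame", "parent_id"],
)
divisions = df_track_meta[df_track_meta["parent_id"] != 0]
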
get_indices(idx)

Retrieve label and frame indices given batch index.

Parameters:

idx (int, required):
    the index of the batch.

Returns:

tuple:
    the label and frame indices corresponding to a batch.

Source code in dreem/datasets/cell_tracking_dataset.py
def get_indices(self, idx: int) -> tuple:
    """Retrieve label and frame indices given batch index.

    Args:
        idx: the index of the batch.

    Returns:
        the label and frame indices corresponding to a batch.
    """
    return self.label_idx[idx], self.chunked_frame_idx[idx]
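
Assuming a dataset constructed as in the example near the top of this page, a single chunk can be materialized by pairing get_indices with get_instances (in normal use, iterating the dataset through BaseDataset is expected to do this for you; that behavior is an assumption about the base class, which is not shown here).

label_idx, frame_idx = dataset.get_indices(0)   # first chunk
frames = dataset.get_instances(label_idx, frame_idx)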

get_instances(label_idx, frame_idx)

Get an element of the dataset.

Parameters:

label_idx (list[int], required):
    index of the labels
frame_idx (list[int], required):
    index of the frames

Returns:

list[Frame]:
    a list of Frame objects containing frame metadata and Instance objects. See dreem.io.data_structures for more info.

Source code in dreem/datasets/cell_tracking_dataset.py
def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Frame]:
    """Get an element of the dataset.

    Args:
        label_idx: index of the labels
        frame_idx: index of the frames

    Returns:
        a list of Frame objects containing frame metadata and Instance Objects.
        See `dreem.io.data_structures` for more info.
    """
    image_paths = self.raw_img_list[label_idx]
    gt_paths = self.gt_list[label_idx]

    if self.list_df_track_meta is not None:
        df_track_meta = self.list_df_track_meta[label_idx]
    else:
        df_track_meta = None

    # get the correct crop size based on the video
    video_par_path = Path(image_paths[0]).parent.parent
    if len(self.data_dirs) > 0:
        crop_size = self.crop_size[0]
        for j, data_dir in enumerate(self.data_dirs):
            if Path(data_dir) == video_par_path:
                crop_size = self.crop_size[j]
                break
    else:
        crop_size = self.crop_size[0]

    frames = []
    max_crop_h, max_crop_w = 0, 0
    for i in frame_idx:
        instances, gt_track_ids, centroids, dict_centroids, bboxes = (
            [],
            [],
            [],
            {},
            [],
        )

        i = int(i)

        img = image_paths[i]
        gt_sec = gt_paths[i]

        img = np.array(Image.open(img))
        gt_sec = np.array(Image.open(gt_sec))

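        # min-max normalize 16-bit images into the uint8 range [0, 255]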
        if img.dtype == np.uint16:
            img = ((img - img.min()) * (1 / (img.max() - img.min()) * 255)).astype(
                np.uint8
            )
        # if df_track_meta is None:
        unique_instances = np.unique(gt_sec)
        # else:
        # unique_instances = df_track_meta["track_id"].unique()

        for instance in unique_instances:
            # not all instances are in the frame, and they also label the
            # background instance as zero
            if instance in gt_sec and instance != 0:
                mask = gt_sec == instance
                center_of_mass = measurements.center_of_mass(mask)

                # scipy returns yx
                x, y = center_of_mass[::-1]

                if self.use_tight_bbox:
                    bbox = data_utils.get_tight_bbox_masks(mask)
                else:
                    bbox = data_utils.pad_bbox(
                        data_utils.get_bbox([int(x), int(y)], crop_size),
                        padding=self.padding,
                    )

                gt_track_ids.append(int(instance))
                centroids.append([x, y])
                dict_centroids[int(instance)] = [x, y]
                bboxes.append(bbox)

        # albumentations wants (spatial, channels), ensure correct dims
        if self.augmentations is not None:
            for transform in self.augmentations:
                # for occlusion simulation, can remove if we don't want
                if isinstance(transform, A.CoarseDropout):
                    transform.fill_value = random.randint(0, 255)

            augmented = self.augmentations(
                image=img,
                keypoints=np.vstack(centroids),
            )

            img, centroids = augmented["image"], augmented["keypoints"]

        img = torch.Tensor(img).unsqueeze(0)

        for j in range(len(gt_track_ids)):
            # just formatting for compatibility with Instance class
            instance_centroid = {
                "centroid": np.array(dict_centroids[gt_track_ids[j]])
            }
            pose = {"centroid": dict_centroids[gt_track_ids[j]]}  # more formatting
            crop = data_utils.crop_bbox(img, bboxes[j])
            c, h, w = crop.shape
            if h > max_crop_h:
                max_crop_h = h
            if w > max_crop_w:
                max_crop_w = w

            instances.append(
                Instance(
                    gt_track_id=gt_track_ids[j],
                    pred_track_id=-1,
                    centroid=instance_centroid,
                    skeleton=self.skeleton,
                    point_scores=np.array([1.0]),
                    instance_score=np.array([1.0]),
                    pose=pose,
                    bbox=bboxes[j],
                    crop=crop,
                )
            )

        if self.mode == "train":
            np.random.shuffle(instances)

        frames.append(
            Frame(
                video_id=label_idx,
                frame_id=i,
                vid_file=Path(image_paths[0]).parent.name,
                img_shape=img.shape,
                instances=instances,
            )
        )

    # pad bbox to max size
    if self.use_tight_bbox:
        # bound the max crop size to the user defined crop size
        max_crop_h = crop_size if max_crop_h == 0 else min(max_crop_h, crop_size)
        max_crop_w = crop_size if max_crop_w == 0 else min(max_crop_w, crop_size)
        # gather all the crops
        for frame in frames:
            for instance in frame.instances:
                data_utils.pad_variable_size_crops(
                    instance, (max_crop_h, max_crop_w)
                )

    return frames
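
The Frame and Instance classes live in dreem.io; the attribute access below assumes they expose the fields passed to their constructors here (frame_id, instances, gt_track_id, crop), so check dreem.io.data_structures before relying on it.

for frame in frames:
    for instance in frame.instances:
        # gt_track_id and crop mirror the constructor arguments used above;
        # crop is the (C, H, W) image patch cut out around the cell centroid.
        print(frame.frame_id, instance.gt_track_id, tuple(instance.crop.shape))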