# cell_tracking_dataset

`dreem.datasets.cell_tracking_dataset`
Module containing cell tracking challenge dataset.
## CellTrackingDataset

Bases: `BaseDataset`
Dataset for loading cell tracking challenge data.
Source code in dreem/datasets/cell_tracking_dataset.py
```python
class CellTrackingDataset(BaseDataset):
    """Dataset for loading cell tracking challenge data."""

    def __init__(
        self,
        raw_images: list[list[str]],
        gt_images: list[list[str]],
        padding: int = 5,
        crop_size: int = 20,
        chunk: bool = False,
        clip_length: int = 10,
        mode: str = "train",
        augmentations: Optional[dict] = None,
        n_chunks: Union[int, float] = 1.0,
        seed: Optional[int] = None,
        gt_list: Optional[list[str]] = None,
    ):
        """Initialize CellTrackingDataset.

        Args:
            raw_images: paths to raw microscopy images
            gt_images: paths to gt label images
            padding: amount of padding around object crops
            crop_size: the size of the object crops
            chunk: whether or not to chunk the dataset into batches
            clip_length: the number of frames in each chunk
            mode: `train` or `val`. Determines whether this dataset is used for
                training or validation. Currently doesn't affect dataset logic.
            augmentations: An optional dict mapping augmentations to parameters.
                The keys should map directly to augmentation classes in
                albumentations. Example:
                    augs = {
                        'Rotate': {'limit': [-90, 90]},
                        'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0},
                        'RandomContrast': {'limit': 0.2}
                    }
            n_chunks: Number of chunks to subsample from. Can either be a
                fraction of the dataset (i.e. (0, 1.0]) or a number of chunks.
            seed: set a seed for reproducibility
            gt_list: An optional list of paths to .txt files containing gt ids
                stored in cell tracking challenge format: "track_id",
                "start_frame", "end_frame", "parent_id"
        """
        super().__init__(
            gt_images,
            raw_images,
            padding,
            crop_size,
            chunk,
            clip_length,
            mode,
            augmentations,
            n_chunks,
            seed,
            gt_list,
        )

        self.videos = raw_images
        self.labels = gt_images
        self.chunk = chunk
        self.clip_length = clip_length
        self.crop_size = crop_size
        self.padding = padding
        self.mode = mode.lower()
        self.n_chunks = n_chunks
        self.seed = seed

        # if self.seed is not None:
        #     np.random.seed(self.seed)

        # Augmentations are only applied in training mode.
        if augmentations and self.mode == "train":
            self.augmentations = data_utils.build_augmentations(augmentations)
        else:
            self.augmentations = None

        # Parse the space-delimited cell tracking challenge track files.
        if gt_list is not None:
            self.gt_list = [
                pd.read_csv(
                    gtf,
                    delimiter=" ",
                    header=None,
                    names=["track_id", "start_frame", "end_frame", "parent_id"],
                )
                for gtf in gt_list
            ]
        else:
            self.gt_list = None

        self.frame_idx = [torch.arange(len(image)) for image in self.labels]

        # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
        # used in call to get_instances()
        self.create_chunks()

    def get_indices(self, idx: int) -> tuple:
        """Retrieve label and frame indices given batch index.

        Args:
            idx: the index of the batch.

        Returns:
            the label and frame indices corresponding to a batch.
        """
        return self.label_idx[idx], self.chunked_frame_idx[idx]

    def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Frame]:
        """Get an element of the dataset.

        Args:
            label_idx: index of the labels
            frame_idx: index of the frames

        Returns:
            a list of Frame objects containing frame metadata and Instance
            objects. See `dreem.io.data_structures` for more info.
        """
        image = self.videos[label_idx]
        gt = self.labels[label_idx]

        if self.gt_list is not None:
            gt_list = self.gt_list[label_idx]
        else:
            gt_list = None

        frames = []
        for i in frame_idx:
            instances, gt_track_ids, centroids, bboxes = [], [], [], []
            i = int(i)

            img = image[i]
            gt_sec = gt[i]

            img = np.array(Image.open(img))
            gt_sec = np.array(Image.open(gt_sec))

            # Min-max rescale 16-bit microscopy images to 8-bit.
            if img.dtype == np.uint16:
                img = ((img - img.min()) * (1 / (img.max() - img.min()) * 255)).astype(
                    np.uint8
                )

            if gt_list is None:
                unique_instances = np.unique(gt_sec)
            else:
                unique_instances = gt_list["track_id"].unique()

            for instance in unique_instances:
                # not all instances are in the frame, and they also label the
                # background instance as zero
                if instance in gt_sec and instance != 0:
                    mask = gt_sec == instance
                    center_of_mass = measurements.center_of_mass(mask)

                    # scipy returns yx
                    x, y = center_of_mass[::-1]

                    bbox = data_utils.pad_bbox(
                        data_utils.get_bbox([int(x), int(y)], self.crop_size),
                        padding=self.padding,
                    )

                    gt_track_ids.append(int(instance))
                    centroids.append([x, y])
                    bboxes.append(bbox)

            # albumentations wants (spatial, channels), ensure correct dims
            if self.augmentations is not None:
                for transform in self.augmentations:
                    # for occlusion simulation, can remove if we don't want
                    if isinstance(transform, A.CoarseDropout):
                        transform.fill_value = random.randint(0, 255)

                augmented = self.augmentations(
                    image=img,
                    keypoints=np.vstack(centroids),
                )
                img, centroids = augmented["image"], augmented["keypoints"]

            # Add a channel dimension: (1, H, W).
            img = torch.Tensor(img).unsqueeze(0)

            for j in range(len(gt_track_ids)):
                crop = data_utils.crop_bbox(img, bboxes[j])

                instances.append(
                    Instance(
                        gt_track_id=gt_track_ids[j],
                        pred_track_id=-1,
                        bbox=bboxes[j],
                        crop=crop,
                    )
                )

            # Shuffle instance order within each frame during training.
            if self.mode == "train":
                np.random.shuffle(instances)

            frames.append(
                Frame(
                    video_id=label_idx,
                    frame_id=i,
                    img_shape=img.shape,
                    instances=instances,
                )
            )
        return frames
```
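For orientation, a minimal construction sketch. The file paths below are hypothetical and just mirror the cell tracking challenge directory layout; per the types above, `raw_images` and `gt_images` take one inner list of frame paths per video:

```python
from dreem.datasets.cell_tracking_dataset import CellTrackingDataset

# Hypothetical paths: one video with three frames of raw images and gt labels.
dataset = CellTrackingDataset(
    raw_images=[["01/t000.tif", "01/t001.tif", "01/t002.tif"]],
    gt_images=[
        ["01_GT/TRA/man_track000.tif", "01_GT/TRA/man_track001.tif", "01_GT/TRA/man_track002.tif"]
    ],
    padding=5,
    crop_size=20,
    chunk=True,
    clip_length=2,
    mode="train",
    augmentations={"Rotate": {"limit": [-90, 90]}},
    n_chunks=1.0,
    seed=42,
)
```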
### `__init__(raw_images, gt_images, padding=5, crop_size=20, chunk=False, clip_length=10, mode='train', augmentations=None, n_chunks=1.0, seed=None, gt_list=None)`
Initialize CellTrackingDataset.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `raw_images` | `list[list[str]]` | paths to raw microscopy images | *required* |
| `gt_images` | `list[list[str]]` | paths to gt label images | *required* |
| `padding` | `int` | amount of padding around object crops | `5` |
| `crop_size` | `int` | the size of the object crops | `20` |
| `chunk` | `bool` | whether or not to chunk the dataset into batches | `False` |
| `clip_length` | `int` | the number of frames in each chunk | `10` |
| `mode` | `str` | `train` or `val`. Determines whether this dataset is used for training or validation. Currently doesn't affect dataset logic. | `'train'` |
| `augmentations` | `Optional[dict]` | An optional dict mapping augmentations to parameters. The keys should map directly to augmentation classes in albumentations. Example: `augs = {'Rotate': {'limit': [-90, 90]}, 'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0}, 'RandomContrast': {'limit': 0.2}}` | `None` |
| `n_chunks` | `Union[int, float]` | Number of chunks to subsample from. Can either be a fraction of the dataset (i.e. (0, 1.0]) or a number of chunks. | `1.0` |
| `seed` | `Optional[int]` | set a seed for reproducibility | `None` |
| `gt_list` | `Optional[list[str]]` | An optional list of paths to .txt files containing gt ids stored in cell tracking challenge format: "track_id", "start_frame", "end_frame", "parent_id" | `None` |
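To make the key-to-class mapping concrete: `data_utils.build_augmentations` is not shown on this page, but a dict in the format above could be expanded into an albumentations pipeline along these lines. This is an illustration of the mapping only, not the library's actual implementation:

```python
import albumentations as A

augs = {
    "Rotate": {"limit": [-90, 90]},
    "GaussianBlur": {"blur_limit": (3, 7), "sigma_limit": 0},
}

# Illustrative only: look up each class by name on the albumentations module
# and instantiate it with its parameter dict.
pipeline = A.Compose(
    [getattr(A, name)(**params) for name, params in augs.items()],
    keypoint_params=A.KeypointParams(format="xy"),
)
```

The `keypoint_params` matters here because `get_instances` passes the instance centroids to the pipeline as keypoints alongside the image.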
Source code in dreem/datasets/cell_tracking_dataset.py
```python
def __init__(
    self,
    raw_images: list[list[str]],
    gt_images: list[list[str]],
    padding: int = 5,
    crop_size: int = 20,
    chunk: bool = False,
    clip_length: int = 10,
    mode: str = "train",
    augmentations: Optional[dict] = None,
    n_chunks: Union[int, float] = 1.0,
    seed: Optional[int] = None,
    gt_list: Optional[list[str]] = None,
):
    """Initialize CellTrackingDataset.

    Args:
        raw_images: paths to raw microscopy images
        gt_images: paths to gt label images
        padding: amount of padding around object crops
        crop_size: the size of the object crops
        chunk: whether or not to chunk the dataset into batches
        clip_length: the number of frames in each chunk
        mode: `train` or `val`. Determines whether this dataset is used for
            training or validation. Currently doesn't affect dataset logic.
        augmentations: An optional dict mapping augmentations to parameters.
            The keys should map directly to augmentation classes in
            albumentations. Example:
                augs = {
                    'Rotate': {'limit': [-90, 90]},
                    'GaussianBlur': {'blur_limit': (3, 7), 'sigma_limit': 0},
                    'RandomContrast': {'limit': 0.2}
                }
        n_chunks: Number of chunks to subsample from. Can either be a
            fraction of the dataset (i.e. (0, 1.0]) or a number of chunks.
        seed: set a seed for reproducibility
        gt_list: An optional list of paths to .txt files containing gt ids
            stored in cell tracking challenge format: "track_id",
            "start_frame", "end_frame", "parent_id"
    """
    super().__init__(
        gt_images,
        raw_images,
        padding,
        crop_size,
        chunk,
        clip_length,
        mode,
        augmentations,
        n_chunks,
        seed,
        gt_list,
    )

    self.videos = raw_images
    self.labels = gt_images
    self.chunk = chunk
    self.clip_length = clip_length
    self.crop_size = crop_size
    self.padding = padding
    self.mode = mode.lower()
    self.n_chunks = n_chunks
    self.seed = seed

    # if self.seed is not None:
    #     np.random.seed(self.seed)

    # Augmentations are only applied in training mode.
    if augmentations and self.mode == "train":
        self.augmentations = data_utils.build_augmentations(augmentations)
    else:
        self.augmentations = None

    # Parse the space-delimited cell tracking challenge track files.
    if gt_list is not None:
        self.gt_list = [
            pd.read_csv(
                gtf,
                delimiter=" ",
                header=None,
                names=["track_id", "start_frame", "end_frame", "parent_id"],
            )
            for gtf in gt_list
        ]
    else:
        self.gt_list = None

    self.frame_idx = [torch.arange(len(image)) for image in self.labels]

    # Method in BaseDataset. Creates label_idx and chunked_frame_idx to be
    # used in call to get_instances()
    self.create_chunks()
```
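The `gt_list` files follow the cell tracking challenge track-file layout: one space-delimited row per track. A standalone sketch of the same parse the constructor performs (the file path and contents are hypothetical):

```python
import pandas as pd

# Hypothetical man_track.txt contents, one row per track:
# 1 0 49 0
# 2 10 49 1
tracks = pd.read_csv(
    "01_GT/TRA/man_track.txt",
    delimiter=" ",
    header=None,
    names=["track_id", "start_frame", "end_frame", "parent_id"],
)
print(tracks["track_id"].unique())  # e.g. [1 2]
```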
### `get_indices(idx)`
Retrieve label and frame indices given batch index.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `idx` | `int` | the index of the batch. | *required* |

Returns:

| Type | Description |
|---|---|
| `tuple` | the label and frame indices corresponding to a batch. |
Source code in dreem/datasets/cell_tracking_dataset.py

```python
def get_indices(self, idx: int) -> tuple:
    """Retrieve label and frame indices given batch index.

    Args:
        idx: the index of the batch.

    Returns:
        the label and frame indices corresponding to a batch.
    """
    return self.label_idx[idx], self.chunked_frame_idx[idx]
```
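Assuming a dataset constructed as in the earlier sketch, usage looks like:

```python
label_idx, frame_idx = dataset.get_indices(0)
# label_idx picks the video; frame_idx holds that chunk's frame indices.
```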
### `get_instances(label_idx, frame_idx)`
Get an element of the dataset.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `label_idx` | `List[int]` | index of the labels | *required* |
| `frame_idx` | `List[int]` | index of the frames | *required* |

Returns:

| Type | Description |
|---|---|
| `List[Frame]` | a list of Frame objects containing frame metadata and Instance objects. See `dreem.io.data_structures` for more info. |
Source code in dreem/datasets/cell_tracking_dataset.py
```python
def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> List[Frame]:
    """Get an element of the dataset.

    Args:
        label_idx: index of the labels
        frame_idx: index of the frames

    Returns:
        a list of Frame objects containing frame metadata and Instance
        objects. See `dreem.io.data_structures` for more info.
    """
    image = self.videos[label_idx]
    gt = self.labels[label_idx]

    if self.gt_list is not None:
        gt_list = self.gt_list[label_idx]
    else:
        gt_list = None

    frames = []
    for i in frame_idx:
        instances, gt_track_ids, centroids, bboxes = [], [], [], []
        i = int(i)

        img = image[i]
        gt_sec = gt[i]

        img = np.array(Image.open(img))
        gt_sec = np.array(Image.open(gt_sec))

        # Min-max rescale 16-bit microscopy images to 8-bit.
        if img.dtype == np.uint16:
            img = ((img - img.min()) * (1 / (img.max() - img.min()) * 255)).astype(
                np.uint8
            )

        if gt_list is None:
            unique_instances = np.unique(gt_sec)
        else:
            unique_instances = gt_list["track_id"].unique()

        for instance in unique_instances:
            # not all instances are in the frame, and they also label the
            # background instance as zero
            if instance in gt_sec and instance != 0:
                mask = gt_sec == instance
                center_of_mass = measurements.center_of_mass(mask)

                # scipy returns yx
                x, y = center_of_mass[::-1]

                bbox = data_utils.pad_bbox(
                    data_utils.get_bbox([int(x), int(y)], self.crop_size),
                    padding=self.padding,
                )

                gt_track_ids.append(int(instance))
                centroids.append([x, y])
                bboxes.append(bbox)

        # albumentations wants (spatial, channels), ensure correct dims
        if self.augmentations is not None:
            for transform in self.augmentations:
                # for occlusion simulation, can remove if we don't want
                if isinstance(transform, A.CoarseDropout):
                    transform.fill_value = random.randint(0, 255)

            augmented = self.augmentations(
                image=img,
                keypoints=np.vstack(centroids),
            )
            img, centroids = augmented["image"], augmented["keypoints"]

        # Add a channel dimension: (1, H, W).
        img = torch.Tensor(img).unsqueeze(0)

        for j in range(len(gt_track_ids)):
            crop = data_utils.crop_bbox(img, bboxes[j])

            instances.append(
                Instance(
                    gt_track_id=gt_track_ids[j],
                    pred_track_id=-1,
                    bbox=bboxes[j],
                    crop=crop,
                )
            )

        # Shuffle instance order within each frame during training.
        if self.mode == "train":
            np.random.shuffle(instances)

        frames.append(
            Frame(
                video_id=label_idx,
                frame_id=i,
                img_shape=img.shape,
                instances=instances,
            )
        )
    return frames
```
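Putting the two methods together, a hedged sketch of pulling one chunk of frames (assuming `Frame` exposes its constructor arguments as attributes):

```python
label_idx, frame_idx = dataset.get_indices(0)
frames = dataset.get_instances(label_idx, frame_idx)

for frame in frames:
    # Each Frame carries per-frame metadata plus its Instance crops.
    print(frame.frame_id, len(frame.instances))
```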