Skip to content
Merged
461 changes: 461 additions & 0 deletions examples/4.data_augmentation_example.ipynb

Large diffs are not rendered by default.

240 changes: 240 additions & 0 deletions examples/nbconverted/4.data_augmentation_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
#!/usr/bin/env python
# coding: utf-8

# # Examples for incorporating monai image augmentation suite for training

# ## Dependencies

# In[ ]:


import re
import pathlib
from typing import List

import pandas as pd
from monai.transforms import (
Compose,
EnsureTyped,
RandFlipd,
RandRotate90d,
RandAffined,
RandGaussianNoised,
RandGaussianSmoothd,
RandAdjustContrastd
)

from virtual_stain_flow.datasets.base_dataset import BaseImageDataset
from virtual_stain_flow.datasets.crop_dataset import CropImageDataset
from virtual_stain_flow.datasets.monai_aug_adapter_dataset import MonaiAdapter
from virtual_stain_flow.transforms.normalizations import MaxScaleNormalize
from virtual_stain_flow.evaluation.visualization import plot_dataset_grid


# ## Pathing and Additional utils

# In[ ]:


DATA_PATH = pathlib.Path("/YOUR/DATA/PATH/")  # Change to where the download_data script outputs data

# Fail fast if the configured location is missing or is not a directory,
# so later cells do not fail with a less obvious error.
if not (DATA_PATH.exists() and DATA_PATH.is_dir()):
    raise FileNotFoundError(f"Data path {DATA_PATH} does not exist or is not a directory.")

# Matches filenames like:
#   r01c01f01p01-ch1sk1fk1fl1.tiff
# group(1) is the field prefix (e.g. "r01c01f01p01"),
# group(2) is the digits following "ch".
FIELD_RE = re.compile(
    r"(r\d{2}c\d{2}f\d{2}p01)-ch(\d+)sk1fk1fl1\.tiff$"
)


def _collect_field_prefixes(
    plate_dir: pathlib.Path,
    max_fields: int = 16,
) -> List[str]:
    """
    Scan a JUMP CPJUMP1 plate directory and collect distinct field prefixes.

    Only ``*.tiff`` files directly inside ``plate_dir`` are considered
    (the scan is not recursive). Files are visited in sorted order, so the
    returned prefixes are in sorted order of first appearance.

    Expects image filenames like:
        r01c01f01p01-ch1sk1fk1fl1.tiff

    Args:
        plate_dir: Directory containing the image files.
        max_fields: Maximum number of distinct prefixes to return.
            A value of zero or less yields an empty list.

    Returns:
        Up to ``max_fields`` distinct prefixes, e.g. ``"r01c01f01p01"``.
    """
    # Guard the limit up front: checking only after an append would let
    # one prefix through even when max_fields is 0 or negative.
    if max_fields <= 0:
        return []

    prefixes: List[str] = []
    seen = set()  # O(1) membership test; `prefixes` keeps the order
    for path in sorted(plate_dir.glob("*.tiff")):
        m = FIELD_RE.match(path.name)
        if m is None:
            continue
        prefix = m.group(1)  # e.g. "r01c01f01p01"
        if prefix in seen:
            continue
        seen.add(prefix)
        prefixes.append(prefix)
        if len(prefixes) >= max_fields:
            break
    return prefixes

def build_file_index(
    plate_dir: pathlib.Path,
    max_fields: int = 16,
) -> pd.DataFrame:
    """
    Helper function to build a file index that specifies
    the relationship of images across channels and field/fovs.
    The result can directly be supplied to BaseImageDataset to create a
    dataset with the correct image pairs.

    Args:
        plate_dir: Directory to index. Field prefixes are collected from
            the ``*.tiff`` files directly inside it; the channel files of
            each field are then searched for recursively beneath it.
        max_fields: Maximum number of distinct fields to include.

    Returns:
        A DataFrame with one row per field and one column per channel
        (``"ch<N>"``), each cell holding the image path as a string.
        Columns are sorted by name.

    Raises:
        ValueError: If no file under ``plate_dir`` matches the expected
            filename pattern.
    """
    fields = _collect_field_prefixes(
        plate_dir,
        max_fields=max_fields,
    )

    file_index_list = []
    for field in fields:
        sample = {}
        # Search beneath the directory that was passed in, not the
        # module-level DATA_PATH, so the function honours its argument.
        # Sorting makes the result deterministic when several matches
        # map to the same channel key (the last one wins).
        for chan in sorted(plate_dir.glob(f"**/{field}*.tiff")):
            match = FIELD_RE.match(chan.name)
            if match:
                sample[f"ch{match.group(2)}"] = str(chan)

        file_index_list.append(sample)

    file_index = pd.DataFrame(file_index_list)
    # Drop rows for fields where no channel file was found.
    file_index = file_index.dropna(how="all")
    if file_index.empty:
        raise ValueError(f"No files found in {plate_dir} matching the expected pattern.")

    return file_index.loc[:, sorted(file_index.columns)]


# In[3]:


# For stable wGAN, we don't want the dataset to be too small that the discriminator
# quickly memorizes the set and overpowers the generator.
# So here a bigger, 2048 FOV subset of CPJUMP1 (BF and Hoechst channel) is used as demo dataset
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1 for details
# NOTE(review): the call below caps the index at max_fields=64 field prefixes,
# which does not match the "2048 FOV" figure above -- confirm which is intended.
file_index = build_file_index(DATA_PATH, max_fields=64)
# Preview the first rows: one row per field, one column per channel key.
print(file_index.head())


# ## Create dataset from CPJUMP1 and take center crops

# In[4]:


# Create a dataset with Brightfield as input and Hoechst as target
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1
# for which channel codes correspond to which channel
# The channel keys below are column names of `file_index` ("ch<N>").
dataset = BaseImageDataset(
    file_index=file_index,
    check_exists=True,
    pil_image_mode="I;16",
    input_channel_keys=["ch7"],
    target_channel_keys=["ch5"],
)
print(f"Dataset length: {len(dataset)}")
# NOTE(review): this reads the public `input_channel_keys` but the private
# `_target_channel_keys` -- confirm whether a public `target_channel_keys`
# accessor exists and should be used here instead.
print(
    f"Input channels: {dataset.input_channel_keys}, target channels: {dataset._target_channel_keys}"
)

# Derive a crop dataset (crop_size=128) from the full-FOV dataset and
# attach a MaxScaleNormalize transform configured for '16bit' data.
cropped_dataset = CropImageDataset.from_base_dataset(
    dataset,
    crop_size=128,
    transforms=MaxScaleNormalize(
        normalization_factor='16bit'
    )
)
# Plot the first sample of the (un-augmented) crop dataset as a reference.
plot_dataset_grid(
    dataset=cropped_dataset,
    indices=[0],
    wspace=0.025,
    hspace=0.05
)


# ## Transformation example
Comment thread
wli51 marked this conversation as resolved.
Outdated

# In[5]:


# Augmentation pipeline operating on {"input": ..., "target": ...} samples.
# Spatial transforms list both keys so input and target stay aligned;
# intensity transforms list only "input" so the target is left untouched.
monai_transform = Compose([
    EnsureTyped(keys=["input", "target"]),
    RandFlipd(keys=["input", "target"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["input", "target"], prob=0.5, spatial_axis=1),
    RandRotate90d(keys=["input", "target"], prob=0.5, max_k=3),
    # NOTE(review): if samples are 2D (channel, H, W), MONAI appears to use
    # only the first `rotate_range` entry, so (0.0, 0.0, 0.15) may produce no
    # rotation at all -- confirm; a single-entry range may be what is intended.
    RandAffined(
        keys=["input", "target"],
        prob=0.7,
        rotate_range=(0.0, 0.0, 0.15),
        translate_range=(0, 0),  # no translate
        scale_range=(0.0, 0.0),  # no scale
        padding_mode="border",
    ),
    RandGaussianSmoothd(
        keys=["input"],
        prob=0.2,
        sigma_x=(0.25, 0.5),  # Gaussian blur to simulate out-of-focus
        sigma_y=(0.25, 0.5),
    ),
    RandAdjustContrastd(
        keys=["input"],
        prob=0.2,
        gamma=(0.95, 1.05),  # small variation to avoid unrealistic contrast change
        invert_image=False,
        retain_stats=True,
    ),
    RandGaussianNoised(
        keys=["input"],
        prob=0.2,
        mean=0.0,  # no bias
        std=1e-4,  # subtle additive Gaussian noise
    ),
])

# Wrap the crop dataset so every __getitem__ call runs the pipeline above.
augmented_dataset = MonaiAdapter(cropped_dataset, transform=monai_transform)


# ## Visualize the same augmented dataset multiple times to see effects of augmentation
# Note that augmentation is only applied to the crop and the shown full FOV is always un-augmented

# In[6]:


# Draw the same sample five times; each call fetches it again through the
# augmented dataset, so the random transforms are re-applied per plot.
for i in range(5):
    plot_dataset_grid(
        dataset=augmented_dataset,
        indices=[0],  # only first sample to better see difference
        wspace=0.025,
        hspace=0.05
    )


# ## Use `MonaiAdapter` for training as you would with any image dataset
#
# e.g.
# ```python
# ...
#
# # Make train loader from augmented dataset
# train_loader = DataLoader(
# augmented_dataset,
# batch_size=batch_size,
# shuffle=True,
# )
# ...
#
# # feed to trainer
# trainer = SingleGeneratorTrainer(
# model=...,
# optimizer=...,
# losses=...,
# loss_weights=...,
# device='cuda',
# train_loader=train_loader
# )
Comment thread
wli51 marked this conversation as resolved.
#
# # optionally, if want to use plot prediction callback
# plot_callback = PlotPredictionCallback(
# name="...",
# dataset=cropped_dataset, # non-augmented dataset recommended for consistency
Comment thread
wli51 marked this conversation as resolved.
Outdated
# # but augmented datasets also work here
# ...
# )
# ```
43 changes: 43 additions & 0 deletions src/virtual_stain_flow/datasets/base_wrapper_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
base_wrapper_dataset.py

Defines a simple BaseWrapperDataset scheme that can wrap any BaseImageDataset
and forward calls to it.
"""

from abc import ABC, abstractmethod
from typing import Union

from .base_dataset import BaseImageDataset
from .crop_dataset import CropImageDataset

class BaseWrapperDataset(ABC):
    """
    Abstract base class for datasets that wrap another dataset.

    The wrapped dataset is stored on ``self._dataset``. ``__len__`` is
    delegated to it, and subclasses must implement ``__getitem__`` to
    define what happens to each retrieved sample. Wrappers may be nested;
    ``original`` unwraps all layers.
    """

    def __init__(
        self,
        dataset: "Union[BaseImageDataset, CropImageDataset, BaseWrapperDataset]",
    ):
        """
        Args:
            dataset: The dataset to wrap. May itself be a
                BaseWrapperDataset, in which case wrappers are stacked.
        """
        self._dataset = dataset
        # optionally do something to the dataset

    def __len__(self) -> int:
        """Return the length of the wrapped dataset."""
        return len(self._dataset)

    @abstractmethod
    def __getitem__(self, idx):
        """
        Return the ``(input, target)`` pair for ``idx``.

        Subclasses must override this. The body below is a reference
        implementation that passes the wrapped sample through unchanged.
        """
        # retrieve images from dataset
        # (local names avoid shadowing the builtin `input`)
        input_image, target_image = self._dataset[idx]

        # do something to the input and target here
        # (e.g. apply transformations, generate crops, cache in RAM, etc.)

        return input_image, target_image

    @property
    def original(self) -> "Union[BaseImageDataset, CropImageDataset]":
        """
        Access the original underlying dataset for metadata etc.

        When wrappers are nested, this recurses through every wrapper
        layer and returns the innermost non-wrapper dataset.
        """
        if isinstance(self._dataset, BaseWrapperDataset):
            return self._dataset.original
        return self._dataset
41 changes: 41 additions & 0 deletions src/virtual_stain_flow/datasets/monai_aug_adapter_dataset.py
Comment thread
wli51 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
monai_aug_adapter_dataset.py
"""

from monai.transforms import (
Compose
)

from .base_dataset import BaseImageDataset
from .base_wrapper_dataset import BaseWrapperDataset

class MonaiAdapter(BaseWrapperDataset):
    """
    Wrapper dataset that runs samples through a MONAI dictionary transform.

    Each ``(input, target)`` pair from the wrapped dataset is packed into a
    dictionary with the keys "input" and "target", passed through the
    configured transform, and then unpacked back into an
    ``(input, target)`` tuple. Without a transform the pair is returned
    unchanged, so the adapter is only useful when a transform is given.
    """

    def __init__(
        self,
        base_dataset: BaseImageDataset,
        transform: Compose | None = None
    ):
        """
        Args:
            base_dataset: Dataset to wrap; indexing it must yield an
                ``(input, target)`` pair.
            transform: Callable applied to the ``{"input", "target"}``
                dictionary of each sample, or None to skip transformation.
        """
        super().__init__(base_dataset)
        self._transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self._dataset)

    def __getitem__(self, idx):
        """Fetch sample ``idx`` and return it, transformed if configured."""
        image_in, image_out = self._dataset[idx]

        # Nothing to apply: hand the pair back as-is.
        if self._transform is None:
            return image_in, image_out

        augmented = self._transform({"input": image_in, "target": image_out})
        return augmented["input"], augmented["target"]
26 changes: 22 additions & 4 deletions src/virtual_stain_flow/evaluation/evaluation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from ..datasets.base_dataset import BaseImageDataset
from ..datasets.crop_dataset import CropImageDataset
from virtual_stain_flow.datasets.base_wrapper_dataset import BaseWrapperDataset


def extract_samples_from_dataset(
dataset: Union[BaseImageDataset, CropImageDataset],
dataset: Union[BaseImageDataset, CropImageDataset, BaseWrapperDataset],
indices: List[int],
) -> Tuple[
List[np.ndarray],
Expand All @@ -33,7 +34,20 @@ def extract_samples_from_dataset(
or None for BaseImageDataset.
- patch_coords: List of (x, y) tuples for CropImageDataset, or None for BaseImageDataset.
"""
is_crop_dataset = isinstance(dataset, CropImageDataset)
is_wrapper_dataset = False
if isinstance(dataset, BaseWrapperDataset):
is_crop_dataset = isinstance(dataset.original, CropImageDataset)
is_wrapper_dataset = True
elif isinstance(dataset, CropImageDataset):
is_crop_dataset = True
elif isinstance(dataset, BaseImageDataset):
is_crop_dataset = False
else:
raise ValueError(
"Unsupported dataset type. Expected BaseImageDataset, CropImageDataset, or BaseWrapperDataset.")

if max(indices) >= len(dataset):
raise IndexError(f"Index out of range. Dataset length: {len(dataset)}, max index requested: {max(indices)}")
Comment thread
wli51 marked this conversation as resolved.
Outdated

inputs: List[np.ndarray] = []
targets: List[np.ndarray] = []
Expand All @@ -58,8 +72,12 @@ def extract_samples_from_dataset(
if is_crop_dataset:
# Access the original uncropped image and crop coordinates
# These are populated after __getitem__ is called
raw_images.append(dataset.original_input_image[0])
patch_coords.append((dataset.crop_info.x, dataset.crop_info.y))
if is_wrapper_dataset:
raw_images.append(dataset.original.original_input_image[0])
patch_coords.append((dataset.original.crop_info.x, dataset.original.crop_info.y))
else:
raw_images.append(dataset.original_input_image[0])
patch_coords.append((dataset.crop_info.x, dataset.crop_info.y))

return inputs, targets, raw_images, patch_coords

Expand Down
3 changes: 2 additions & 1 deletion src/virtual_stain_flow/evaluation/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ..datasets.base_dataset import BaseImageDataset
from ..datasets.crop_dataset import CropImageDataset
from ..datasets.base_wrapper_dataset import BaseWrapperDataset
from .evaluation_utils import evaluate_per_image_metric, extract_samples_from_dataset
from .predict_utils import predict_image

Expand Down Expand Up @@ -199,7 +200,7 @@ def plot_predictions_grid(


def plot_dataset_grid(
dataset: Union[BaseImageDataset, CropImageDataset],
dataset: Union[BaseImageDataset, CropImageDataset, BaseWrapperDataset],
indices: List[int],
save_path: Optional[str] = None,
**kwargs,
Expand Down
Loading