Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
d16d917
Add pop method and merge operators to Context class for the convenien…
wli51 Jan 5, 2026
cea1cc3
Add DiscriminatorForwardGroup class for GAN discriminator management
wli51 Jan 5, 2026
123ece2
Add GANOrchestrator and OrchestratedStep class to abstract away the G…
wli51 Jan 5, 2026
b2aae07
Refactor GANOrchestrator methods to streamline context handling and i…
wli51 Jan 5, 2026
668a213
Add minimal testing for GANOrchestrator and implement simple discrimi…
wli51 Jan 5, 2026
adcbb1f
Enhance Context class with reserved key checks and add properties for…
wli51 Jan 7, 2026
f53527c
Refactor DiscriminatorForwardGroup to use constant keys for model and…
wli51 Jan 7, 2026
6ba51f4
Refactor GANOrchestrator to tighten type checks accessing values from…
wli51 Jan 7, 2026
146d238
Refactor AbstractBlock and Stage classes to improve output channel ha…
wli51 Jan 7, 2026
c91d2c2
Refactor type annotations in AbstractForwardGroup and its subclasses …
wli51 Jan 7, 2026
f053713
Add tests for or and ror operations and switch from returnning NotImp…
wli51 Jan 7, 2026
8e6b630
Enhance Context class with type checks for values and update pop meth…
wli51 Jan 7, 2026
83d35d1
Tighten-up typing of the loss and lossgroup modules. Due to the need …
wli51 Jan 8, 2026
b4b6edb
Refactor type annotations in AbstractTrainer class to use Tensor type…
wli51 Jan 8, 2026
fd312b9
Fix return type of from_config method in BaseModel to return BaseMode…
wli51 Jan 8, 2026
df370b9
Refactor train method in TrainerProtocol to accept variable arguments…
wli51 Jan 8, 2026
21f0d85
Tighten up type annotations in BaseModel class for to_config method
wli51 Jan 8, 2026
78876a0
Allow customizable or no output activation function in GlobalDiscrimi…
wli51 Jan 8, 2026
68f0ca2
Add model configuration logging (as artifact) to MlflowLogger
wli51 Jan 8, 2026
7fb9693
Enhance TrainerProtocol with model saving functionality and update Ml…
wli51 Jan 8, 2026
f068af2
Refactor SingleGeneratorTrainer to enhance type annotations for loss_…
wli51 Jan 8, 2026
c048d40
Add LoggingGANTrainer class for enhanced GAN training and evaluation …
wli51 Jan 8, 2026
afb60d8
Add minimal tests for LoggingWGANTrainer train_step and evaluate_step…
wli51 Jan 8, 2026
1ef9dad
Add training example for WGAN with logging and visualization
wli51 Jan 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions src/virtual_stain_flow/engine/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,36 @@ def items(self):

def keys(self):
return self._store.keys()

def pop(self, key: str, default: ContextValue = None) -> ContextValue:
"""Remove and return the value for key if key is in the context, else default."""
return self._store.pop(key, default)

def __or__(self, other: "Context") -> "Context":
"""
Merge two Context objects using the | operator.
Returns a new Context with items from both contexts.
Items from the right operand (other) take precedence in case of key conflicts.

:param other: Another Context object to merge with.
:return: A new Context object containing items from both contexts.
"""
if not isinstance(other, Context):
return NotImplemented
new_context = Context(**self._store)
new_context.add(**other._store)
return new_context

def __ror__(self, other: "Context") -> "Context":
Comment thread
wli51 marked this conversation as resolved.
"""
Reverse merge (right | operator) for Context objects.
Called when the left operand doesn't support __or__ with Context.

:param other: Another Context object to merge with.
:return: A new Context object containing items from both contexts.
"""
if not isinstance(other, Context):
return NotImplemented
new_context = Context(**other._store)
new_context.add(**self._store)
return new_context
85 changes: 85 additions & 0 deletions src/virtual_stain_flow/engine/forward_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,88 @@ def optimizer(self) -> Optional[optim.Optimizer]:
Convenience property to access the generator optimizer directly.
"""
return self._optimizers[GENERATOR_MODEL]


class DiscriminatorForwardGroup(AbstractForwardGroup):
"""
Forward group for a simple single (GAN/wGAN) discriminator workflow.
The discriminator is assumed to take in a "stack" of input and target
images concatenated along the channel dimension, and output a score/probability.
Relevant context values are input_keys, target_keys, output_keys for a
single-discriminator model, where:
- the forward is called as:
p = discriminator(stack)
- the evaluation is less straightforward, but typically involves
computing losses/metrics based on p and real/fake labels:
metric_value = metric_fn(p, real_or_fake_labels)
or perhaps involving the discrminator model itself for wasserstein distance:
metric_value = metric_fn(discriminator, stack, real_or_fake_labels)
"""

input_keys: Tuple[str, ...] = ("stack",)
target_keys: Tuple[str, ...] = ()
output_keys: Tuple[str, ...] = ("p",)
Comment thread
wli51 marked this conversation as resolved.
Outdated

def __init__(
self,
discriminator: nn.Module,
optimizer: Optional[optim.Optimizer] = None,
device: torch.device = torch.device("cpu"),
):
super().__init__(device=device)

self._models['discriminator'] = discriminator
self._models['discriminator'].to(self.device)
self._optimizers['discriminator'] = optimizer

def __call__(self, train: bool, **inputs: torch.Tensor) -> Context:
"""
Executes the forward pass, managing training/eval modes and optimizer steps.
Subclasses may override this method if needed.

:param train: Whether to run in training mode. Meant to be specified
by the trainer to switch between train/eval modes and determine
whether gradients should be computed.
:param inputs: Keyword arguments of input tensors.
"""

fp_model = self.model
fp_optimizer = self.optimizer

# 1) Stage and validate inputs/targets
ctx = Context(**self._move_tensors(inputs), **{'discriminator': fp_model })
ctx.require(self.input_keys)
ctx.require(self.target_keys)

# 2) Forward, with grad only when training
fp_model.train(mode=train)
train and fp_optimizer is not None and fp_optimizer.zero_grad(set_to_none=True)
with torch.set_grad_enabled(train):
model_inputs = [ctx[k] for k in self.input_keys] # ordered
raw = fp_model(*model_inputs)
y_tuple = self._normalize_outputs(raw)

# 3) Arity check + map outputs to names
if len(y_tuple) != len(self.output_keys):
raise ValueError(
f"Model returned {len(y_tuple)} outputs, "
f"but output_keys expects {len(self.output_keys)}"
)
outputs = {k: v for k, v in zip(self.output_keys, y_tuple)}

# 5) Return enriched context (preds available for losses/metrics)
return ctx.add(**outputs)

@property
def model(self) -> nn.Module:
"""
Convenience property to access the discriminator model directly.
"""
return self._models['discriminator']

@property
def optimizer(self) -> Optional[optim.Optimizer]:
"""
Convenience property to access the discriminator optimizer directly.
"""
return self._optimizers['discriminator']
176 changes: 176 additions & 0 deletions src/virtual_stain_flow/engine/orchestrators.py
Comment thread
wli51 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""
orchestrators.py

Collection of orchestrator classes that manages training flow
for complex models involving multiple components, such as GANs.

This is constrasted with ForwardGroup classes, which handle
the forward pass and optimization of single model components.
The addition of orchestrators helps keep ForwardGroup classes simple.

Internally, an orchestrator manages multiple ForwardGroups
and defines coordinated training steps that involve forward passes
through several components in a specific sequence.
"""

from dataclasses import dataclass
from typing import Callable, Optional

import torch
from torch import optim

from .forward_groups import GeneratorForwardGroup, DiscriminatorForwardGroup
from .context import Context
from .names import INPUTS, TARGETS, PREDS


@dataclass
class OrchestratedStep:
"""
Thin wrapper around orchestrator methods to present step-like objects
to trainers with the same interface as ForwardGroups, exposing:
- __call__(train=..., **batch) for forward pass, and
- .step() to step the optimizer
"""

forward_fn: Callable[..., Context]
optimizer: Optional[optim.Optimizer] = None

def __call__(self, train: bool, **batch) -> Context:
return self.forward_fn(train=train, **batch)

def step(self) -> None:
if self.optimizer is not None:
self.optimizer.step()


class GANOrchestrator:
"""
Orchestrator for a GAN-style setup with separate generator and discriminator
training steps.

Stores GeneratorForwardGroup and a DiscriminatorForwardGroup:
The GeneratorForwardGroup and DiscriminatorForwardGroup are the
simplified building blocks that conducts exclusively the forward pass of
either generator or discriminator).
The Orchestrator._discriminator_forward and Orchestrator._generator_forward
methods is uses these simple forward groups to build more complex steps that
enable GAN training, which requires a coordinated forward pass through both
the discriminator and generator.
"""

def __init__(
self,
generator_fg: GeneratorForwardGroup,
discriminator_fg: DiscriminatorForwardGroup,
):
"""
Initialize from already-constructed forward groups.

This keeps concerns separated: forward groups manage single-module
behavior; the orchestrator manages their composition.
"""
# simple forward group storage
self._gen_fg = generator_fg
self._disc_fg = discriminator_fg

# Public step-like objects that trainers can use directly
self.discriminator_step = OrchestratedStep(
forward_fn=self._discriminator_forward,
optimizer=self._disc_fg.optimizer,
)
self.generator_step = OrchestratedStep(
forward_fn=self._generator_forward,
optimizer=self._gen_fg.optimizer,
)

def _build_real_fake_contexts(
self,
train: bool,
gen_ctx: Context,
) -> Context:
"""
Given a generator context containing inputs / targets / preds,
generates the real and fake stacks by concatenating the true
input with the true target or predicted target along the
channel dimension. The result stacks serve as direct inputs
to the discriminator.

The discriminator is then run on both stacks to produce
outputs scores of if it thinks the provided stack is real.

:param train: Whether the model is in training mode.
:param gen_ctx: The Context produced by the generator forward pass,
containing at least INPUTS, TARGETS, and PREDS tensors.
:return: A merged Context containing outputs from both
the real and fake discriminator passes, as well as
the original generator context.
"""
# Stack along channel dim: [inputs, targets] vs [inputs, preds]
real_stack = torch.cat([gen_ctx[INPUTS], gen_ctx[TARGETS]], dim=1)
fake_stack = torch.cat([gen_ctx[INPUTS], gen_ctx[PREDS]], dim=1)

# Real batch: D(x, y_true)
ctx_real = self._disc_fg(train=train, stack=real_stack)
ctx_real["real_stack"] = ctx_real.pop("stack")
ctx_real["p_real_as_real"] = ctx_real.pop("p")

# Fake batch: D(x, y_fake)
ctx_fake = self._disc_fg(train=train, stack=fake_stack)
ctx_fake["fake_stack"] = ctx_fake.pop("stack")
ctx_fake["p_fake_as_real"] = ctx_fake.pop("p")

# Merge: real info, fake info, and generator info
return ctx_real | ctx_fake | gen_ctx

def _discriminator_forward(
self,
train: bool,
inputs: torch.Tensor,
targets: torch.Tensor,
) -> Context:
"""
Forward step to train only the discriminator.

:param train: Whether the model is in training mode.
:param inputs: The input tensor for the models.
:param targets: The target tensor for the models.
:return: A Context containing discriminator outputs for both
real and fake stacks, as well as the original generator context.
"""
# Generator is always eval for discriminator updates
gen_ctx = self._gen_fg(train=False, inputs=inputs, targets=targets)

return self._build_real_fake_contexts(train=train, gen_ctx=gen_ctx)

def _generator_forward(
self,
train: bool,
inputs: torch.Tensor,
targets: torch.Tensor,
) -> Context:
"""
Forward step to train only the generator.

:param train: Whether the model is in training mode.
:param inputs: The input tensor for the models.
:param targets: The target tensor for the models.
:return: A Context containing generator outputs plus
p_fake_as_real from the discriminator.
"""

# Generate predictions and then run discriminator on fake stack
gen_ctx = self._gen_fg(train=train,inputs=inputs,targets=targets)
fake_stack = torch.cat([gen_ctx[INPUTS], gen_ctx[PREDS]], dim=1)
disc_ctx = self._disc_fg(train=train, stack=fake_stack)

# Attach discriminator score to the generator context and return.
return gen_ctx.add(p_fake_as_real=disc_ctx["p"])

@property
def generator_forward_group(self) -> GeneratorForwardGroup:
return self._gen_fg

@property
def discriminator_forward_group(self) -> DiscriminatorForwardGroup:
return self._disc_fg
65 changes: 65 additions & 0 deletions tests/engine/conftest.py

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double checked and found the tests all pass, nice job. Given the great stuff going on with testing you might consider making testing a part of automated jobs with a GitHub Actions job. The benefit here is that it'd provide you and reviewers confidence about the changes within the context of the pull request (similar to how the pre-commit checks operate).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I will save the CI test integration for a separate future PR as a careful pass of the test suite to mark and ensure skipping of cuda dependent tests and add cpu alternatives to tests that are only written for cuda as cuda enabled gpu is not offered on the free github action runners (unless I am wrong).

Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,43 @@ def forward(self, x):
return MultiOutputConv()


@pytest.fixture
def simple_discriminator():
"""
Simple discriminator model for GAN testing.
Takes concatenated input/target stack (B, 6, H, W) -> outputs score (B, 1)
Uses conv + global average pooling + linear layer.
"""
class SimpleDiscriminator(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(
in_channels=6, # stacked input + target
out_channels=16,
kernel_size=3,
padding=1,
bias=True
)
self.pool = nn.AdaptiveAvgPool2d(1) # Global average pooling
self.fc = nn.Linear(16, 1) # Output single score

def forward(self, x):
x = self.conv(x)
x = torch.relu(x)
x = self.pool(x) # (B, 16, 1, 1)
x = x.flatten(1) # (B, 16)
x = self.fc(x) # (B, 1)
return x

return SimpleDiscriminator()


@pytest.fixture
def random_stack():
"""Random stack tensor (batch=2, channels=6, height=8, width=8) for discriminator."""
return torch.randn(2, 6, 8, 8)


@pytest.fixture
def sample_inputs():
"""Create sample inputs for loss computation."""
Expand Down Expand Up @@ -208,3 +245,31 @@ def forward_pass_context_eval(forward_group, random_input, random_target, torch_
inputs=random_input.to(torch_device),
targets=random_target.to(torch_device),
)


@pytest.fixture
def disc_optimizer(simple_discriminator):
"""Create an Adam optimizer for the discriminator model."""
import torch.optim as optim
return optim.Adam(simple_discriminator.parameters(), lr=1e-3)


@pytest.fixture
def discriminator_forward_group(simple_discriminator, disc_optimizer, torch_device):
"""Create a DiscriminatorForwardGroup with the simple discriminator and optimizer."""
from virtual_stain_flow.engine.forward_groups import DiscriminatorForwardGroup
return DiscriminatorForwardGroup(
discriminator=simple_discriminator,
optimizer=disc_optimizer,
device=torch_device,
)


@pytest.fixture
def gan_orchestrator(forward_group, discriminator_forward_group):
"""Create a GANOrchestrator with generator and discriminator forward groups."""
from virtual_stain_flow.engine.orchestrators import GANOrchestrator
return GANOrchestrator(
generator_fg=forward_group,
discriminator_fg=discriminator_forward_group,
)
Loading