-
Notifications
You must be signed in to change notification settings - Fork 1
Dev gan trainer #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev gan trainer #24
Changes from 5 commits
d16d917
cea1cc3
123ece2
b2aae07
668a213
adcbb1f
f53527c
6ba51f4
146d238
c91d2c2
f053713
8e6b630
83d35d1
b4b6edb
fd312b9
df370b9
21f0d85
78876a0
68f0ca2
7fb9693
f068af2
c048d40
afb60d8
1ef9dad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
wli51 marked this conversation as resolved.
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| """ | ||
| orchestrators.py | ||
|
|
||
| Collection of orchestrator classes that manages training flow | ||
| for complex models involving multiple components, such as GANs. | ||
|
|
||
| This is constrasted with ForwardGroup classes, which handle | ||
| the forward pass and optimization of single model components. | ||
| The addition of orchestrators helps keep ForwardGroup classes simple. | ||
|
|
||
| Internally, an orchestrator manages multiple ForwardGroups | ||
| and defines coordinated training steps that involve forward passes | ||
| through several components in a specific sequence. | ||
| """ | ||
|
|
||
| from dataclasses import dataclass | ||
| from typing import Callable, Optional | ||
|
|
||
| import torch | ||
| from torch import optim | ||
|
|
||
| from .forward_groups import GeneratorForwardGroup, DiscriminatorForwardGroup | ||
| from .context import Context | ||
| from .names import INPUTS, TARGETS, PREDS | ||
|
|
||
|
|
||
| @dataclass | ||
| class OrchestratedStep: | ||
| """ | ||
| Thin wrapper around orchestrator methods to present step-like objects | ||
| to trainers with the same interface as ForwardGroups, exposing: | ||
| - __call__(train=..., **batch) for forward pass, and | ||
| - .step() to step the optimizer | ||
| """ | ||
|
|
||
| forward_fn: Callable[..., Context] | ||
| optimizer: Optional[optim.Optimizer] = None | ||
|
|
||
| def __call__(self, train: bool, **batch) -> Context: | ||
| return self.forward_fn(train=train, **batch) | ||
|
|
||
| def step(self) -> None: | ||
| if self.optimizer is not None: | ||
| self.optimizer.step() | ||
|
|
||
|
|
||
| class GANOrchestrator: | ||
| """ | ||
| Orchestrator for a GAN-style setup with separate generator and discriminator | ||
| training steps. | ||
|
|
||
| Stores GeneratorForwardGroup and a DiscriminatorForwardGroup: | ||
| The GeneratorForwardGroup and DiscriminatorForwardGroup are the | ||
| simplified building blocks that conducts exclusively the forward pass of | ||
| either generator or discriminator). | ||
| The Orchestrator._discriminator_forward and Orchestrator._generator_forward | ||
| methods is uses these simple forward groups to build more complex steps that | ||
| enable GAN training, which requires a coordinated forward pass through both | ||
| the discriminator and generator. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| generator_fg: GeneratorForwardGroup, | ||
| discriminator_fg: DiscriminatorForwardGroup, | ||
| ): | ||
| """ | ||
| Initialize from already-constructed forward groups. | ||
|
|
||
| This keeps concerns separated: forward groups manage single-module | ||
| behavior; the orchestrator manages their composition. | ||
| """ | ||
| # simple forward group storage | ||
| self._gen_fg = generator_fg | ||
| self._disc_fg = discriminator_fg | ||
|
|
||
| # Public step-like objects that trainers can use directly | ||
| self.discriminator_step = OrchestratedStep( | ||
| forward_fn=self._discriminator_forward, | ||
| optimizer=self._disc_fg.optimizer, | ||
| ) | ||
| self.generator_step = OrchestratedStep( | ||
| forward_fn=self._generator_forward, | ||
| optimizer=self._gen_fg.optimizer, | ||
| ) | ||
|
|
||
| def _build_real_fake_contexts( | ||
| self, | ||
| train: bool, | ||
| gen_ctx: Context, | ||
| ) -> Context: | ||
| """ | ||
| Given a generator context containing inputs / targets / preds, | ||
| generates the real and fake stacks by concatenating the true | ||
| input with the true target or predicted target along the | ||
| channel dimension. The result stacks serve as direct inputs | ||
| to the discriminator. | ||
|
|
||
| The discriminator is then run on both stacks to produce | ||
| outputs scores of if it thinks the provided stack is real. | ||
|
|
||
| :param train: Whether the model is in training mode. | ||
| :param gen_ctx: The Context produced by the generator forward pass, | ||
| containing at least INPUTS, TARGETS, and PREDS tensors. | ||
| :return: A merged Context containing outputs from both | ||
| the real and fake discriminator passes, as well as | ||
| the original generator context. | ||
| """ | ||
| # Stack along channel dim: [inputs, targets] vs [inputs, preds] | ||
| real_stack = torch.cat([gen_ctx[INPUTS], gen_ctx[TARGETS]], dim=1) | ||
| fake_stack = torch.cat([gen_ctx[INPUTS], gen_ctx[PREDS]], dim=1) | ||
|
|
||
| # Real batch: D(x, y_true) | ||
| ctx_real = self._disc_fg(train=train, stack=real_stack) | ||
| ctx_real["real_stack"] = ctx_real.pop("stack") | ||
| ctx_real["p_real_as_real"] = ctx_real.pop("p") | ||
|
|
||
| # Fake batch: D(x, y_fake) | ||
| ctx_fake = self._disc_fg(train=train, stack=fake_stack) | ||
| ctx_fake["fake_stack"] = ctx_fake.pop("stack") | ||
| ctx_fake["p_fake_as_real"] = ctx_fake.pop("p") | ||
|
|
||
| # Merge: real info, fake info, and generator info | ||
| return ctx_real | ctx_fake | gen_ctx | ||
|
|
||
| def _discriminator_forward( | ||
| self, | ||
| train: bool, | ||
| inputs: torch.Tensor, | ||
| targets: torch.Tensor, | ||
| ) -> Context: | ||
| """ | ||
| Forward step to train only the discriminator. | ||
|
|
||
| :param train: Whether the model is in training mode. | ||
| :param inputs: The input tensor for the models. | ||
| :param targets: The target tensor for the models. | ||
| :return: A Context containing discriminator outputs for both | ||
| real and fake stacks, as well as the original generator context. | ||
| """ | ||
| # Generator is always eval for discriminator updates | ||
| gen_ctx = self._gen_fg(train=False, inputs=inputs, targets=targets) | ||
|
|
||
| return self._build_real_fake_contexts(train=train, gen_ctx=gen_ctx) | ||
|
|
||
| def _generator_forward( | ||
| self, | ||
| train: bool, | ||
| inputs: torch.Tensor, | ||
| targets: torch.Tensor, | ||
| ) -> Context: | ||
| """ | ||
| Forward step to train only the generator. | ||
|
|
||
| :param train: Whether the model is in training mode. | ||
| :param inputs: The input tensor for the models. | ||
| :param targets: The target tensor for the models. | ||
| :return: A Context containing generator outputs plus | ||
| p_fake_as_real from the discriminator. | ||
| """ | ||
|
|
||
| # Generate predictions and then run discriminator on fake stack | ||
| gen_ctx = self._gen_fg(train=train,inputs=inputs,targets=targets) | ||
| fake_stack = torch.cat([gen_ctx[INPUTS], gen_ctx[PREDS]], dim=1) | ||
| disc_ctx = self._disc_fg(train=train, stack=fake_stack) | ||
|
|
||
| # Attach discriminator score to the generator context and return. | ||
| return gen_ctx.add(p_fake_as_real=disc_ctx["p"]) | ||
|
|
||
| @property | ||
| def generator_forward_group(self) -> GeneratorForwardGroup: | ||
| return self._gen_fg | ||
|
|
||
| @property | ||
| def discriminator_forward_group(self) -> DiscriminatorForwardGroup: | ||
| return self._disc_fg |
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I double checked and found the tests all pass, nice job. Given the great stuff going on with testing you might consider making testing a part of automated jobs with a GitHub Actions job. The benefit here is that it'd provide you and reviewers confidence about the changes within the context of the pull request (similar to how the pre-commit checks operate).
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I will save the CI test integration for a separate future PR as a careful pass of the test suite to mark and ensure skipping of cuda dependent tests and add cpu alternatives to tests that are only written for cuda as cuda enabled gpu is not offered on the free github action runners (unless I am wrong). |
Uh oh!
There was an error while loading. Please reload this page.