Add Single Stage Training support to core modules #5

Open
shun31y wants to merge 5 commits into main from feature/single_stage_training

Conversation

@shun31y shun31y commented Jul 15, 2025

Why

Train end-to-end without the MaskGIT pseudo-code.

What

Add Single Stage Training (based on TA-TiTok).

How

  • LPIPS support
  • Clustering-VQ support
  • Adjust output tensor shape
  • Add ReconstructionLoss_Single_Stage to the loss module

@shun31y shun31y requested review from Copilot and ensan-hcl July 15, 2025 01:50
@shun31y shun31y self-assigned this Jul 15, 2025

This comment was marked as outdated.

@shun31y shun31y requested a review from Copilot July 16, 2025 12:25

Copilot AI left a comment

Pull Request Overview

Adds configuration flags and modules to support single-stage end-to-end training, clustering-based VQ, and enhanced perceptual loss options.

  • Introduces is_single_training flag and new ReconstructionLoss_Single_Stage to unify stage-1 and stage-2 training.
  • Adds clustering_vq support and DiagonalGaussianDistribution in VectorQuantizer.
  • Extends perceptual loss to combine LPIPS and ConvNeXt and provides 3D perceptual loss implementations.
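
The flag-based loss selection mentioned in the first bullet could look roughly like the sketch below. Only `is_single_training` and `ReconstructionLoss_Single_Stage` are names taken from this PR; the `TwoStageLoss` stub, the `select_loss_class` helper, and the dict-shaped config are assumptions made for illustration, not the code in utils/train_utils.py.

```python
class TwoStageLoss:
    """Hypothetical stand-in for the pre-existing two-stage loss class."""

    def __init__(self, config):
        self.config = config


class ReconstructionLoss_Single_Stage(TwoStageLoss):
    """Stub that only borrows the class name used in this PR."""


def select_loss_class(config: dict):
    # Assumption: the flag lives at the top level of a dict-like config.
    # Defaulting to False keeps the old behavior for existing configs.
    if config.get("is_single_training", False):
        return ReconstructionLoss_Single_Stage
    return TwoStageLoss


config = {"is_single_training": True}
loss_module = select_loss_class(config)(config)
print(type(loss_module).__name__)
```

Defaulting the flag to `False` means configs written before this change would keep selecting the two-stage path.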

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| utils/train_utils.py | Imports and selects ReconstructionLoss_Single_Stage based on is_single_training |
| modeling/titok.py | Adds is_single_training, clustering_vq, and only_finetune_decoder flags; updates encoder/decoder logic |
| modeling/quantizer/quantizer.py | Introduces clustering_vq logic in forward and adds DiagonalGaussianDistribution |
| modeling/modules/perceptual_loss.py | Implements combined LPIPS+ConvNeXt and 3D perceptual loss classes |
| modeling/modules/lpips.py | Adds LPIPS implementation with model download and VGG backbone |
| modeling/modules/losses.py | Adds ReconstructionLoss_Single_Stage subclass |
| modeling/modules/blocks.py | Updates TiTok encoder/decoder reshape paths based on is_single_training |
| modeling/modules/__init__.py | Exposes ReconstructionLoss_Single_Stage in package exports |

Comments suppressed due to low confidence (5)

modeling/titok.py:80

  • This comment is no longer accurate after adding is_single_training and only_finetune_decoder. Update it to reflect the new flag logic or remove it.
        # This should be False for stage1 and True for stage2.

modeling/quantizer/quantizer.py:131

  • The new DiagonalGaussianDistribution class lacks unit tests. Consider adding tests for sample, mode, and kl methods to ensure correct behavior.
class DiagonalGaussianDistribution(object):
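
The kind of tests suggested here could be sketched as follows. The `DiagonalGaussian` class below is a framework-free stand-in written with plain lists so the example runs on its own; it is not the PR's `DiagonalGaussianDistribution`, whose constructor and tensor handling are not shown in this thread. The KL term is the standard closed form against a unit Gaussian, which is an assumption about what `kl` is meant to compute.

```python
import math
import random


class DiagonalGaussian:
    """Illustrative stand-in, NOT the class added in quantizer.py."""

    def __init__(self, mean, logvar):
        self.mean = list(mean)
        self.logvar = list(logvar)
        self.std = [math.exp(0.5 * lv) for lv in self.logvar]

    def sample(self, rng=random):
        # Reparameterized draw: mean + std * eps, eps ~ N(0, 1)
        return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(self.mean, self.std)]

    def mode(self):
        return list(self.mean)

    def kl(self):
        # KL(N(mean, var) || N(0, I)), summed over dimensions
        return 0.5 * sum(
            m * m + math.exp(lv) - 1.0 - lv
            for m, lv in zip(self.mean, self.logvar)
        )


def test_mode_returns_mean():
    assert DiagonalGaussian([1.5, -2.0], [0.0, 0.3]).mode() == [1.5, -2.0]


def test_kl_is_zero_for_unit_gaussian():
    assert abs(DiagonalGaussian([0.0, 0.0], [0.0, 0.0]).kl()) < 1e-12


def test_kl_matches_closed_form():
    # Unit variance, mean (1, 0): KL = 0.5 * 1^2 = 0.5
    assert abs(DiagonalGaussian([1.0, 0.0], [0.0, 0.0]).kl() - 0.5) < 1e-12


def test_sample_statistics():
    rng = random.Random(0)  # fixed seed keeps the test deterministic
    dist = DiagonalGaussian([2.0], [math.log(0.25)])  # std = 0.5
    draws = [dist.sample(rng)[0] for _ in range(20000)]
    mean = sum(draws) / len(draws)
    std = math.sqrt(sum((x - mean) ** 2 for x in draws) / len(draws))
    assert abs(mean - 2.0) < 0.05
    assert abs(std - 0.5) < 0.05
```

For the real class, the same four checks would apply with tensors in place of lists and a tolerance-based comparison in place of exact equality.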

modeling/modules/lpips.py:4

  • Please replace (year) with the actual year or remove placeholder text to complete the license header.
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates. 

modeling/quantizer/quantizer.py:80

  • The variables min_encoding_indices and d used in the clustering_vq block are not defined in this scope, which will cause a NameError. Ensure that these are computed earlier in forward or passed into this block before use.
                encoding_indices = gather(min_encoding_indices)
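
One way to address this is to compute both values at the top of the branch, before they are used. Below is a framework-free sketch of a nearest-code lookup that produces `d` and `min_encoding_indices`; plain lists stand in for tensors, and the function name and input shapes are assumptions rather than the PR's code. The `gather` call from the quoted line is left out because its definition is not shown in this thread.

```python
def nearest_code_indices(z, codebook):
    """Return (d, min_encoding_indices) for a batch of latent vectors.

    z:        list of latent vectors, each a list of floats
    codebook: list of code vectors with the same dimensionality
    """
    # d[i][k] is the squared Euclidean distance between z[i] and codebook[k]
    d = [
        [sum((zi - ei) ** 2 for zi, ei in zip(z_vec, e_vec)) for e_vec in codebook]
        for z_vec in z
    ]
    # Index of the closest code for every latent vector
    min_encoding_indices = [
        min(range(len(row)), key=row.__getitem__) for row in d
    ]
    return d, min_encoding_indices


codebook = [[0.0, 0.0], [1.0, 1.0], [4.0, 4.0]]
z = [[0.9, 1.2], [3.5, 4.1], [0.1, -0.2]]
d, min_encoding_indices = nearest_code_indices(z, codebook)
print(min_encoding_indices)  # [1, 2, 0]
```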

modeling/modules/perceptual_loss.py:543

  • Index 8 is out of range for a tensor with num_frames = 8 (valid indices 0–7). Change to a valid frame index (e.g., 7) or use num_frames - 1.
    target[:, :, 8] = torch.rand(2, 3, w, h).clamp(0, 1)

        if "lpips" in model_name and "convnext_s" in model_name:
            loss_config = model_name.split('-')[-2:]
            self.loss_weight_lpips, self.loss_weight_convnext = float(loss_config[0]), float(loss_config[1])
            print(f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}")

Copilot AI Jul 16, 2025

[nitpick] Using print inside __init__ can clutter logs; consider using a logger or removing this debug statement.

Suggested change
- print(f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}")
+ logger.debug(f"self.loss_weight_lpips, self.loss_weight_convnext: {self.loss_weight_lpips}, {self.loss_weight_convnext}")
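
The suggested line only works if a `logger` object exists in the module. A minimal sketch of that setup, using the standard `logging` module, is shown below. The `parse_loss_weights` helper and the example model name are assumptions for illustration; the parsing itself mirrors the quoted snippet.

```python
import logging

# Module-level logger; the application decides handlers and levels.
logger = logging.getLogger(__name__)


def parse_loss_weights(model_name: str):
    """Hypothetical helper mirroring the weight parsing quoted above."""
    if "lpips" in model_name and "convnext_s" in model_name:
        # The last two dash-separated fields are read as the two weights.
        loss_config = model_name.split("-")[-2:]
        weight_lpips, weight_convnext = float(loss_config[0]), float(loss_config[1])
        # %-style arguments defer string formatting until the record is emitted.
        logger.debug(
            "loss_weight_lpips=%s, loss_weight_convnext=%s",
            weight_lpips,
            weight_convnext,
        )
        return weight_lpips, weight_convnext
    return None


# Assumed name format: "<loss>-<backbone>-<lpips weight>-<convnext weight>"
weights = parse_loss_weights("lpips-convnext_s-1.0-0.1")
```

With the default logging configuration the debug message is not shown, so the weights no longer appear in training logs unless the level is lowered explicitly.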

@shun31y shun31y requested a review from kentosasaki-jp August 18, 2025 07:52
@ensan-hcl

Member

@shun31y Can you upload the config file?
