Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modeling/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .base_model import BaseModel
from .ema_model import EMAModel
from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, MLMLoss
from .losses import ReconstructionLoss_Stage1, ReconstructionLoss_Stage2, ReconstructionLoss_Single_Stage, MLMLoss
from .blocks import TiTokEncoder, TiTokDecoder, UViTBlock
from .maskgit_vqgan import Decoder as Pixel_Decoder
from .maskgit_vqgan import VectorQuantizer as Pixel_Quantizer
35 changes: 25 additions & 10 deletions modeling/modules/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from collections import OrderedDict
import einops
from typing import Optional
from einops.layers.torch import Rearrange


class ResidualAttentionBlock(nn.Module):
Expand Down Expand Up @@ -214,6 +215,7 @@ def __init__(self, config):
self.model_size = config.model.vq_model.vit_enc_model_size
self.num_latent_tokens = config.model.vq_model.num_latent_tokens
self.token_size = config.model.vq_model.token_size
self.is_legacy = not config.model.vq_model.is_single_training

self.width = {
"small": 512,
Expand Down Expand Up @@ -273,10 +275,15 @@ def forward(self, pixel_values, latent_tokens, needs_width_reduction=True):

latent_tokens = x[:, 1+self.grid_size**2:]
latent_tokens = self.ln_post(latent_tokens)
if needs_width_reduction:

if self.is_legacy:
latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1)
latent_tokens = self.conv_out(latent_tokens)
latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)
else:
latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)

latent_tokens = self.conv_out(latent_tokens)
latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)

return latent_tokens


Expand All @@ -291,6 +298,7 @@ def __init__(self, config):
self.model_size = config.model.vq_model.vit_dec_model_size
self.num_latent_tokens = config.model.vq_model.num_latent_tokens
self.token_size = config.model.vq_model.token_size
self.is_legacy = not config.model.vq_model.is_single_training
self.width = {
"small": 512,
"base": 768,
Expand Down Expand Up @@ -325,13 +333,20 @@ def __init__(self, config):
))
self.ln_post = nn.LayerNorm(self.width)

self.ffn = nn.Sequential(
nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
nn.Tanh(),
nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
)
self.conv_out = nn.Identity()

if self.is_legacy:
self.ffn = nn.Sequential(
nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
nn.Tanh(),
nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
)
self.conv_out = nn.Identity()
else:
self.ffn = nn.Sequential(
nn.Conv2d(self.width, self.patch_size * self.patch_size *3, 1, padding=0, bias=True),
Rearrange("b (p1 p2 c) h w -> b c (h p1) (w p2)",
p1=self.patch_size, p2=self.patch_size))
self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

def forward(self, z_quantized):
N, C, H, W = z_quantized.shape
if self.strict_length_assertion:
Expand Down
71 changes: 71 additions & 0 deletions modeling/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,3 +274,74 @@ def forward(self, inputs: torch.Tensor, targets: torch.Tensor,
# we only compute correct tokens on masked tokens
correct_tokens = ((torch.argmax(inputs, dim=1) == targets) * weights).sum(dim=1) / (weights.sum(1) + 1e-8)
return loss, {"loss": loss, "correct_tokens": correct_tokens.mean()}


class ReconstructionLoss_Single_Stage(ReconstructionLoss_Stage2):
def __init__(self, config):
super().__init__(config) # initialise Stage‑2 machinery

def _forward_generator(
self,
inputs: torch.Tensor,
reconstructions: torch.Tensor,
extra_result_dict: Mapping[Text, torch.Tensor],
global_step,
) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
# ------------------------------------------------------------------
# reconstruction + perceptual
# ------------------------------------------------------------------
if self.reconstruction_loss == "l1":
recon_loss = F.l1_loss(inputs, reconstructions, reduction="mean")
elif self.reconstruction_loss == "l2":
recon_loss = F.mse_loss(inputs, reconstructions, reduction="mean")
else:
raise ValueError(
f"Unsupported reconstruction_loss '{self.reconstruction_loss}'."
)

recon_loss = recon_loss * self.reconstruction_weight
perceptual_loss = self.perceptual_loss(inputs, reconstructions).mean()

discriminator_factor = (
self.discriminator_factor
if self.should_discriminator_be_trained(global_step)
else 0.0
)
g_gan_loss = torch.zeros((), device=inputs.device)
if discriminator_factor > 0.0 and self.discriminator_weight > 0.0:
for p in self.discriminator.parameters():
p.requires_grad = False
g_gan_loss = -self.discriminator(reconstructions).mean()

d_weight = self.discriminator_weight

q_loss = extra_result_dict["quantizer_loss"]
bottleneck_term = self.quantizer_weight * q_loss
# prepare log items
bottleneck_log = {
"quantizer_loss": bottleneck_term.detach(),
"commitment_loss": extra_result_dict["commitment_loss"].detach(),
"codebook_loss": extra_result_dict["codebook_loss"].detach(),
}

total_loss = (
recon_loss
+ self.perceptual_weight * perceptual_loss
+ bottleneck_term
+ d_weight * discriminator_factor * g_gan_loss
)

loss_dict = {
"total_loss": total_loss.detach(),
"reconstruction_loss": recon_loss.detach(),
"perceptual_loss": (self.perceptual_weight * perceptual_loss).detach(),
"weighted_gan_loss": (
d_weight * discriminator_factor * g_gan_loss
).detach(),
"discriminator_factor": torch.as_tensor(discriminator_factor),
"d_weight": d_weight,
"gan_loss": g_gan_loss.detach(),
**bottleneck_log,
}

return total_loss, loss_dict
184 changes: 184 additions & 0 deletions modeling/modules/lpips.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""This file contains code for LPIPS.

This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.

Reference:
https://github.com/richzhang/PerceptualSimilarity/
https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py
https://github.com/CompVis/taming-transformers/blob/master/taming/util.py
"""

import os
import hashlib
import requests
from collections import namedtuple
from tqdm import tqdm

import torch
import torch.nn as nn

from torchvision import models

_LPIPS_MEAN = [-0.030, -0.088, -0.188]
_LPIPS_STD = [0.458, 0.448, 0.450]


URL_MAP = {
"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
}

CKPT_MAP = {
"vgg_lpips": "vgg.pth"
}

MD5_MAP = {
"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
}


def download(url, local_path, chunk_size=1024):
os.makedirs(os.path.split(local_path)[0], exist_ok=True)
with requests.get(url, stream=True) as r:
total_size = int(r.headers.get("content-length", 0))
with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
with open(local_path, "wb") as f:
for data in r.iter_content(chunk_size=chunk_size):
if data:
f.write(data)
pbar.update(chunk_size)


def md5_hash(path):
with open(path, "rb") as f:
content = f.read()
return hashlib.md5(content).hexdigest()


def get_ckpt_path(name, root, check=False):
assert name in URL_MAP
path = os.path.join(root, CKPT_MAP[name])
if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
download(URL_MAP[name], path)
md5 = md5_hash(path)
assert md5 == MD5_MAP[name], md5
return path


class LPIPS(nn.Module):
# Learned perceptual metric.
def __init__(self, use_dropout=True):
super().__init__()
self.scaling_layer = ScalingLayer()
self.chns = [64, 128, 256, 512, 512] # vg16 features
self.net = vgg16(pretrained=True, requires_grad=False)
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.load_pretrained()
for param in self.parameters():
param.requires_grad = False

def load_pretrained(self):
workspace = os.environ.get('WORKSPACE', '')
VGG_PATH = get_ckpt_path("vgg_lpips", os.path.join(workspace, "models/vgg_lpips.pth"), check=True)
self.load_state_dict(torch.load(VGG_PATH, map_location=torch.device("cpu")), strict=False)

def forward(self, input, target):
# Notably, the LPIPS w/ pre-trained weights expect the input in the range of [-1, 1].
# However, our codebase assumes all inputs are in range of [0, 1], and thus a scaling is needed.
input = input * 2. - 1.
target = target * 2. - 1.
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
outs0, outs1 = self.net(in0_input), self.net(in1_input)
feats0, feats1, diffs = {}, {}, {}
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
for kk in range(len(self.chns)):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2

res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
val = res[0]
for l in range(1, len(self.chns)):
val += res[l]
return val


class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer("shift", torch.Tensor(_LPIPS_MEAN)[None, :, None, None])
self.register_buffer("scale", torch.Tensor(_LPIPS_STD)[None, :, None, None])

def forward(self, inp):
return (inp - self.shift) / self.scale


class NetLinLayer(nn.Module):
"""A single linear layer which does a 1x1 conv."""

def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = (
[
nn.Dropout(),
]
if (use_dropout)
else []
)
layers += [
nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),
]
self.model = nn.Sequential(*layers)


class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False

def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out


def normalize_tensor(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
return x.mean([2, 3], keepdim=keepdim)
Loading