Conversation
Code Review
This pull request adds support for Stable Diffusion v1.5 and SDXL, including model definitions, pipelines, and training scripts. The review identified several critical issues: the DDIMScheduler fails to support batched timesteps and has device mismatch issues, the VAE Downsample2D implementation has an incorrect padding configuration, and the training modules incorrectly use Flow Matching loss functions for standard diffusion models. Additionally, the scheduler's performance can be improved by keeping alphas_cumprod as a tensor.
```python
def step(self, model_output, timestep, sample, to_final=False):
    alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
```
Indexing self.alphas_cumprod by converting the timestep tensor to a list and taking the first element (int(timestep.flatten().tolist()[0])) does not support batched timesteps. During training, timestep is typically a 1D tensor with different values for each sample. This implementation incorrectly applies the same noise level to the entire batch based on the first sample's timestep. It should be updated to use tensor indexing: self.alphas_cumprod[timestep.long()]. This issue also exists in lines 84, 90, 91, 100, and 101.
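For reference, a minimal sketch of the batched lookup (helper name hypothetical; it assumes alphas_cumprod is kept as a tensor, as suggested further down):

```python
import torch

# Hypothetical helper: gather a per-sample alpha_prod_t from a batched timestep
# tensor and reshape it so it broadcasts over the latent dimensions.
def gather_alpha_prod(alphas_cumprod: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor) -> torch.Tensor:
    alpha_prod_t = alphas_cumprod.to(sample.device)[timestep.long()]
    return alpha_prod_t.view(-1, *([1] * (sample.dim() - 1)))
```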
```python
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
if isinstance(timestep, torch.Tensor):
    timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())
```
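The review summary also mentions device-mismatch issues in this scheduler; in the context above, timestep is moved to the CPU so it can be compared against self.timesteps. One hedged alternative (assuming self.timesteps is a 1D tensor held by the scheduler) is to move the comparison onto the timestep's device, which also keeps the lookup batched:

```python
import torch

# Hypothetical sketch: find the nearest scheduler timestep without forcing the
# incoming timestep onto the CPU; a batched (B,) timestep yields (B,) indices.
def nearest_timestep_id(timesteps: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
    diff = (timesteps.to(timestep.device).float().unsqueeze(0) - timestep.float().reshape(-1, 1)).abs()
    return diff.argmin(dim=1)
```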
| """ | ||
| def __init__(self, in_channels, out_channels, padding=1): | ||
| super().__init__() | ||
| self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0) |
The padding argument passed to the constructor is ignored because nn.Conv2d is hardcoded with padding=0. This causes the default padding=1 to behave like padding=0 since the forward method (line 404) only applies manual padding when self.padding == 0. This is inconsistent with the Downsample2D implementation in stable_diffusion_unet.py.
```diff
-        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
```
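Beyond the one-line suggestion, a minimal sketch of how the constructor and forward can stay consistent (assuming, as the comment above describes, that forward applies manual padding only when padding == 0):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Downsample2D(nn.Module):
    # Hypothetical sketch: the constructor's padding is forwarded to nn.Conv2d,
    # and manual padding is applied only when padding == 0.
    def __init__(self, in_channels, out_channels, padding=1):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)

    def forward(self, x):
        if self.padding == 0:
            # pad right/bottom by one so the stride-2, kernel-3 conv still halves H and W
            x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(x)
```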
```python
self.task_to_loss = {
    "sft:data_process": lambda pipe, *args: args,
    "direct_distill:data_process": lambda pipe, *args: args,
    "sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
```
```python
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
if rescale_zero_terminal_snr:
    self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
self.alphas_cumprod = self.alphas_cumprod.tolist()
```
Converting alphas_cumprod to a list forces all subsequent operations (like math.sqrt and indexing in step, add_noise, etc.) to happen on the CPU or via slow scalar transfers. Keeping it as a torch.Tensor and moving it to the model's device would be significantly more efficient and allow for batched operations.
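A hedged sketch of the tensor-based alternative (the schedule constants below are the usual SD defaults, assumed here rather than taken from this PR):

```python
import torch

# Keep alphas_cumprod as a tensor so sqrt and indexing stay vectorized and can run on GPU.
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2  # assumed "scaled_linear" schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def add_noise(sample: torch.Tensor, noise: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
    # gather per-sample cumulative alphas on the sample's device, then broadcast over latent dims
    alpha_prod_t = alphas_cumprod.to(sample.device)[timestep.long()]
    alpha_prod_t = alpha_prod_t.view(-1, *([1] * (sample.dim() - 1)))
    return alpha_prod_t.sqrt() * sample + (1.0 - alpha_prod_t).sqrt() * noise
```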