
Support Stable Diffusion and Stable Diffusion XL #1410

Open · mi804 wants to merge 8 commits into main from sd

Conversation

mi804 (Collaborator) commented Apr 24, 2026

No description provided.

gemini-code-assist (bot) left a comment

Code Review


This pull request adds support for Stable Diffusion v1.5 and SDXL, including model definitions, pipelines, and training scripts. The review identified several critical issues: the DDIMScheduler fails to support batched timesteps and has device mismatch issues, the VAE Downsample2D implementation has an incorrect padding configuration, and the training modules incorrectly use Flow Matching loss functions for standard diffusion models. Additionally, the scheduler's performance can be improved by keeping alphas_cumprod as a tensor.



def step(self, model_output, timestep, sample, to_final=False):
alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]

critical

Indexing self.alphas_cumprod by converting the timestep tensor to a list and taking the first element (int(timestep.flatten().tolist()[0])) does not support batched timesteps. During training, timestep is typically a 1D tensor with different values for each sample. This implementation incorrectly applies the same noise level to the entire batch based on the first sample's timestep. It should be updated to use tensor indexing: self.alphas_cumprod[timestep.long()]. This issue also exists in lines 84, 90, 91, 100, and 101.
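The batched fix the reviewer describes can be sketched as follows. This is an illustrative standalone snippet, not the scheduler's actual code: the beta schedule and the `gather_alpha_prod` helper are assumptions, and it presumes `alphas_cumprod` is kept as a tensor (as a later comment also recommends).

```python
import torch

# Illustrative beta schedule; the real scheduler's values may differ.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def gather_alpha_prod(alphas_cumprod, timestep, sample_ndim=4):
    # Tensor indexing keeps one noise level per batch element instead of
    # applying the first sample's timestep to the whole batch.
    alpha_prod_t = alphas_cumprod.to(timestep.device)[timestep.long()]
    # Reshape to (B, 1, 1, 1) so it broadcasts over a latent batch.
    return alpha_prod_t.view(-1, *([1] * (sample_ndim - 1)))

batched_t = torch.tensor([0, 500, 999])
alpha = gather_alpha_prod(alphas_cumprod, batched_t)
print(alpha.shape)  # torch.Size([3, 1, 1, 1])
```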

alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
timestep_id = torch.argmin((self.timesteps - timestep).abs())

high

This line will cause a RuntimeError due to a device mismatch if timestep is on a GPU (as it is in the pipeline) while self.timesteps is on the CPU (the default for torch.Tensor). You should ensure self.timesteps is moved to the same device as the input timestep.
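A device-safe version of that lookup might look like the sketch below. The inference schedule and helper name are illustrative stand-ins for the scheduler's real attributes.

```python
import torch

# Illustrative 50-step inference schedule; the scheduler builds its own.
timesteps = torch.linspace(999, 0, 50).long()

def find_timestep_id(timesteps, timestep):
    # Move the schedule to the input's device rather than pulling the
    # input back to the CPU; this avoids the device-mismatch RuntimeError
    # and a host-device sync per step.
    timesteps = timesteps.to(timestep.device)
    return torch.argmin((timesteps - timestep).abs())

t = torch.tensor(500)  # would live on CUDA inside the pipeline
tid = find_timestep_id(timesteps, t)
```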

"""
def __init__(self, in_channels, out_channels, padding=1):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)

high

The padding argument passed to the constructor is ignored because nn.Conv2d is hardcoded with padding=0. This causes the default padding=1 to behave like padding=0 since the forward method (line 404) only applies manual padding when self.padding == 0. This is inconsistent with the Downsample2D implementation in stable_diffusion_unet.py.

Suggested change
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=padding)
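Put together, a `Downsample2D` that honors the constructor's `padding` argument might look like this sketch; the class shape mirrors the snippet above, but the exact attribute names are assumptions.

```python
import torch
import torch.nn as nn

class Downsample2D(nn.Module):
    def __init__(self, in_channels, out_channels, padding=1):
        super().__init__()
        self.padding = padding
        # Pass the constructor argument through instead of hardcoding 0.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                              stride=2, padding=padding)

    def forward(self, x):
        if self.padding == 0:
            # Asymmetric right/bottom padding, as the SD VAE encoder uses.
            x = nn.functional.pad(x, (0, 1, 0, 1))
        return self.conv(x)

x = torch.randn(1, 4, 64, 64)
print(Downsample2D(4, 8, padding=1)(x).shape)  # torch.Size([1, 8, 32, 32])
```

Both padding modes now halve the spatial size; they differ only in whether the extra pixel of padding is symmetric or pushed to the right/bottom edge.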

self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),

high

Stable Diffusion v1.5 is not a Flow Matching model. It is a latent diffusion model typically trained with epsilon or v-prediction loss. Using FlowMatchSFTLoss is likely incorrect for this model architecture as it may use a different interpolation scheme (linear vs SNR-based).
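For contrast, a standard epsilon-prediction SFT loss for a DDPM-style model like SD v1.5 can be sketched as below. The `unet` callable and keyword conditioning are hypothetical stand-ins, not DiffSynth API names; the point is the SNR-based noising path, as opposed to flow matching's linear interpolation.

```python
import torch

def epsilon_sft_loss(unet, latents, timesteps, alphas_cumprod, **cond):
    noise = torch.randn_like(latents)
    # Per-sample noise level via batched tensor indexing.
    alpha_prod = alphas_cumprod.to(latents.device)[timesteps.long()]
    alpha_prod = alpha_prod.view(-1, 1, 1, 1)
    # SNR-based interpolation: sqrt(a)*x0 + sqrt(1-a)*eps, not the
    # linear t*x1 + (1-t)*x0 path used by flow matching.
    noisy_latents = (alpha_prod.sqrt() * latents
                     + (1 - alpha_prod).sqrt() * noise)
    pred = unet(noisy_latents, timesteps, **cond)
    # The model is trained to predict the added noise (epsilon).
    return torch.nn.functional.mse_loss(pred, noise)
```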

self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),

high

Stable Diffusion XL is not a Flow Matching model. Using FlowMatchSFTLoss is likely incorrect for this model architecture.

self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
if rescale_zero_terminal_snr:
self.alphas_cumprod = self.rescale_zero_terminal_snr(self.alphas_cumprod)
self.alphas_cumprod = self.alphas_cumprod.tolist()

medium

Converting alphas_cumprod to a list forces all subsequent operations (like math.sqrt and indexing in step, add_noise, etc.) to happen on the CPU or via slow scalar transfers. Keeping it as a torch.Tensor and moving it to the model's device would be significantly more efficient and allow for batched operations.
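The tensor-based alternative the reviewer suggests is straightforward: drop the `.tolist()` and keep everything vectorized. The beta values below follow SD's scaled-linear schedule but are illustrative here.

```python
import torch

# Scaled-linear schedule (linspace in sqrt-space, then squared), as used
# by Stable Diffusion; values are illustrative.
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # stays a tensor

# Batched, device-local lookup and sqrt, replacing per-sample math.sqrt
# on Python floats:
t = torch.tensor([10, 500, 990])
sqrt_alpha = alphas_cumprod[t.long()].sqrt()
```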
