Skip to content

Dmd2 for flux2-klein-base-4B#1503

Open
yjy415 wants to merge 4 commits into
modelscope:mainfrom
yjy415:dmd2-flux2-klein-4B
Open

Dmd2 for flux2-klein-base-4B#1503
yjy415 wants to merge 4 commits into
modelscope:mainfrom
yjy415:dmd2-flux2-klein-4B

Conversation

@yjy415

@yjy415 yjy415 commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

No description provided.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces DMD2 (Distribution Matching Distillation 2) training support for Flux2, adding configuration, loss functions, a discriminator, and training scripts, as well as extending the Flux2 image pipeline to support intermediate feature extraction. The review feedback highlights critical issues regarding PyTorch DDP compatibility, specifically pointing out that accessing custom attributes on DDP-wrapped models will raise errors and that dynamically toggling requires_grad during training breaks DDP bucket synchronization. Additionally, the feedback identifies a 1000x scaling discrepancy in the timestep embedding during feature extraction and a potential crash in the group-normalization channel division logic.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +728 to +736
model = prepared[0]
prepared_tail = list(prepared[1:])
optimizers["student"] = prepared_tail.pop(0)
optimizers["fake_score"] = prepared_tail.pop(0)
optimizers["student_scheduler"] = prepared_tail.pop(0)
optimizers["fake_score_scheduler"] = prepared_tail.pop(0)
if model.discriminator is not None:
optimizers["discriminator"] = prepared_tail.pop(0)
optimizers["discriminator_scheduler"] = prepared_tail.pop(0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

After accelerator.prepare, the model is wrapped in a DistributedDataParallel (DDP) container. Accessing custom attributes like model.discriminator directly on the wrapped model will raise an AttributeError under DDP. We should unwrap the model using accelerator.unwrap_model(model) before accessing its custom attributes.

Suggested change
model = prepared[0]
prepared_tail = list(prepared[1:])
optimizers["student"] = prepared_tail.pop(0)
optimizers["fake_score"] = prepared_tail.pop(0)
optimizers["student_scheduler"] = prepared_tail.pop(0)
optimizers["fake_score_scheduler"] = prepared_tail.pop(0)
if model.discriminator is not None:
optimizers["discriminator"] = prepared_tail.pop(0)
optimizers["discriminator_scheduler"] = prepared_tail.pop(0)
model = prepared[0]
unwrapped_model = accelerator.unwrap_model(model)
prepared_tail = list(prepared[1:])
optimizers["student"] = prepared_tail.pop(0)
optimizers["fake_score"] = prepared_tail.pop(0)
optimizers["student_scheduler"] = prepared_tail.pop(0)
optimizers["fake_score_scheduler"] = prepared_tail.pop(0)
if unwrapped_model.discriminator is not None:
optimizers["discriminator"] = prepared_tail.pop(0)
optimizers["discriminator_scheduler"] = prepared_tail.pop(0)

Comment on lines +751 to +752
if iteration % config.student_update_freq == 0 and config.student_grad_clip_norm is not None and config.student_grad_clip_norm > 0:
accelerator.clip_grad_norm_(_trainable_params(model.pipe.dit), config.student_grad_clip_norm)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Accessing model.pipe directly on the DDP-wrapped model will raise an AttributeError. We should use the unwrapped model reference to access the underlying student pipeline.

Suggested change
if iteration % config.student_update_freq == 0 and config.student_grad_clip_norm is not None and config.student_grad_clip_norm > 0:
accelerator.clip_grad_norm_(_trainable_params(model.pipe.dit), config.student_grad_clip_norm)
if iteration % config.student_update_freq == 0 and config.student_grad_clip_norm is not None and config.student_grad_clip_norm > 0:
accelerator.clip_grad_norm_(_trainable_params(unwrapped_model.pipe.dit), config.student_grad_clip_norm)

Comment on lines +680 to +682
timestep = timestep.to(latents.dtype)
guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000
temb = dit.time_guidance_embed(timestep, guidance)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In the feature-return path of model_fn_flux2, timestep is passed directly to dit.time_guidance_embed without being divided by 1000, whereas in the standard path it is divided by 1000. This causes a 1000x scaling discrepancy in the time embeddings, leading to incorrect feature extraction. We should divide timestep by 1000 before passing it to time_guidance_embed.

Suggested change
timestep = timestep.to(latents.dtype)
guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000
temb = dit.time_guidance_embed(timestep, guidance)
timestep = timestep.to(latents.dtype) / 1000
guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000
temb = dit.time_guidance_embed(timestep, guidance)

Comment on lines +210 to +221
def set_dmd2_train_phase(module, student_phase: bool):
module.pipe.dit.train(student_phase)
for name, param in module.pipe.dit.named_parameters():
param.requires_grad = student_phase and name in module._dmd2_student_param_names

module.fake_score.train(not student_phase)
for name, param in module.fake_score.named_parameters():
param.requires_grad = (not student_phase) and name in module._dmd2_fake_score_param_names

if module.discriminator is not None:
module.discriminator.train(not student_phase)
module.discriminator.requires_grad_(not student_phase)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Dynamically toggling requires_grad on and off during the training loop (inside set_dmd2_train_phase) breaks PyTorch DDP's bucket synchronization mechanism, leading to silent failures or runtime errors. Instead, we should keep requires_grad=True statically for all trainable parameters and only toggle .train() / .eval() mode on the submodules.

Suggested change
def set_dmd2_train_phase(module, student_phase: bool):
module.pipe.dit.train(student_phase)
for name, param in module.pipe.dit.named_parameters():
param.requires_grad = student_phase and name in module._dmd2_student_param_names
module.fake_score.train(not student_phase)
for name, param in module.fake_score.named_parameters():
param.requires_grad = (not student_phase) and name in module._dmd2_fake_score_param_names
if module.discriminator is not None:
module.discriminator.train(not student_phase)
module.discriminator.requires_grad_(not student_phase)
def set_dmd2_train_phase(module, student_phase: bool):
module.pipe.dit.train(student_phase)
module.fake_score.train(not student_phase)
if module.discriminator is not None:
module.discriminator.train(not student_phase)

Comment on lines +99 to +107
def _get_optimal_groups(num_channels):
if num_channels <= 32:
groups = max(1, num_channels // 4)
else:
groups = 32
while groups > 1 and num_channels % groups != 0:
groups -= 1
assert num_channels % groups == 0, f"{num_channels} not divisible by {groups}"
return groups

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In _get_optimal_groups, if num_channels <= 32 and is not divisible by num_channels // 4 (e.g., num_channels = 14), the assertion assert num_channels % groups == 0 will fail, causing a crash. We should use a more robust logic to find the largest valid divisor of num_channels from a standard set of group sizes.

def _get_optimal_groups(num_channels):
    for groups in [32, 16, 8, 4, 2, 1]:
        if num_channels >= groups and num_channels % groups == 0:
            return groups
    return 1

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant