Dmd2 for flux2-klein-base-4B #1503
Code Review
This pull request introduces DMD2 (Distribution Matching Distillation 2) training support for Flux2, adding configuration, loss functions, a discriminator, and training scripts, as well as extending the Flux2 image pipeline to support intermediate feature extraction. The review feedback highlights critical issues regarding PyTorch DDP compatibility, specifically pointing out that accessing custom attributes on DDP-wrapped models will raise errors and that dynamically toggling requires_grad during training breaks DDP bucket synchronization. Additionally, the feedback identifies a 1000x scaling discrepancy in the timestep embedding during feature extraction and a potential crash in the group-normalization channel division logic.
model = prepared[0]
prepared_tail = list(prepared[1:])
optimizers["student"] = prepared_tail.pop(0)
optimizers["fake_score"] = prepared_tail.pop(0)
optimizers["student_scheduler"] = prepared_tail.pop(0)
optimizers["fake_score_scheduler"] = prepared_tail.pop(0)
if model.discriminator is not None:
    optimizers["discriminator"] = prepared_tail.pop(0)
    optimizers["discriminator_scheduler"] = prepared_tail.pop(0)
In multi-process runs, accelerator.prepare returns the model wrapped in a DistributedDataParallel (DDP) container. The wrapper does not forward custom attributes, so accessing model.discriminator directly on it raises an AttributeError. Unwrap the model with accelerator.unwrap_model(model) before accessing its custom attributes.
Suggested change:
model = prepared[0]
unwrapped_model = accelerator.unwrap_model(model)
prepared_tail = list(prepared[1:])
optimizers["student"] = prepared_tail.pop(0)
optimizers["fake_score"] = prepared_tail.pop(0)
optimizers["student_scheduler"] = prepared_tail.pop(0)
optimizers["fake_score_scheduler"] = prepared_tail.pop(0)
if unwrapped_model.discriminator is not None:
    optimizers["discriminator"] = prepared_tail.pop(0)
    optimizers["discriminator_scheduler"] = prepared_tail.pop(0)
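The failure mode described above can be illustrated without a distributed setup. The following is a minimal sketch in plain Python: `Wrapper`, `ToyDMD2Module`, and `unwrap` are hypothetical stand-ins invented for illustration, not the PR's classes and not the Accelerate API. It only assumes that a DDP-style container keeps the real model under `.module` and does not forward unknown attributes.

```python
class Wrapper:
    """Hypothetical stand-in for a DDP-style container.

    It stores the real model under `.module` and, like the behaviour
    described in the review, does not forward custom attributes.
    """

    def __init__(self, module):
        self.module = module


class ToyDMD2Module:
    """Hypothetical stand-in for the training module with an optional discriminator."""

    def __init__(self, discriminator=None):
        self.discriminator = discriminator


def unwrap(model):
    """Peel off nested `.module` wrappers (illustrative only).

    In the training script the review suggests accelerator.unwrap_model instead.
    """
    while hasattr(model, "module"):
        model = model.module
    return model


wrapped = Wrapper(ToyDMD2Module(discriminator=None))

try:
    wrapped.discriminator  # custom attribute lives on the inner module
except AttributeError as exc:
    print("direct access failed:", exc)

# Access through the unwrapped reference works.
print("discriminator via unwrap:", unwrap(wrapped).discriminator)
```

Keeping one unwrapped reference around, as the suggested change does, avoids repeating the unwrap call at every attribute access.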
if iteration % config.student_update_freq == 0 and config.student_grad_clip_norm is not None and config.student_grad_clip_norm > 0:
    accelerator.clip_grad_norm_(_trainable_params(model.pipe.dit), config.student_grad_clip_norm)
Accessing model.pipe directly on the DDP-wrapped model will raise an AttributeError. We should use the unwrapped model reference to access the underlying student pipeline.
Suggested change:
if iteration % config.student_update_freq == 0 and config.student_grad_clip_norm is not None and config.student_grad_clip_norm > 0:
    accelerator.clip_grad_norm_(_trainable_params(unwrapped_model.pipe.dit), config.student_grad_clip_norm)
timestep = timestep.to(latents.dtype)
guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000
temb = dit.time_guidance_embed(timestep, guidance)
In the feature-return path of model_fn_flux2, timestep is passed directly to dit.time_guidance_embed without being divided by 1000, whereas in the standard path it is divided by 1000. This causes a 1000x scaling discrepancy in the time embeddings, leading to incorrect feature extraction. We should divide timestep by 1000 before passing it to time_guidance_embed.
Suggested change:
timestep = timestep.to(latents.dtype) / 1000
guidance = None if embedded_guidance is None else embedded_guidance.to(latents.dtype) * 1000
temb = dit.time_guidance_embed(timestep, guidance)
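One way to keep the two code paths from drifting apart again is to apply the scaling in a single place. The sketch below is only an illustration: `scale_time_inputs` is a hypothetical helper that does not exist in the PR, and the factors (timestep divided by 1000, guidance multiplied by 1000) are taken from the suggested change above. It uses plain floats, but the same arithmetic would apply to tensors.

```python
def scale_time_inputs(timestep, embedded_guidance=None):
    """Hypothetical helper: scale timestep and guidance in one place.

    Both the standard path and the feature-return path could call this
    before time_guidance_embed, so neither can skip the division.
    """
    scaled_timestep = timestep / 1000
    guidance = None if embedded_guidance is None else embedded_guidance * 1000
    return scaled_timestep, guidance


# Without guidance, only the timestep is rescaled.
print(scale_time_inputs(500.0))

# With guidance, both values are rescaled.
print(scale_time_inputs(250.0, 4.0))
```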
def set_dmd2_train_phase(module, student_phase: bool):
    module.pipe.dit.train(student_phase)
    for name, param in module.pipe.dit.named_parameters():
        param.requires_grad = student_phase and name in module._dmd2_student_param_names

    module.fake_score.train(not student_phase)
    for name, param in module.fake_score.named_parameters():
        param.requires_grad = (not student_phase) and name in module._dmd2_fake_score_param_names

    if module.discriminator is not None:
        module.discriminator.train(not student_phase)
        module.discriminator.requires_grad_(not student_phase)
Toggling requires_grad on and off inside the training loop (in set_dmd2_train_phase) conflicts with PyTorch DDP, which registers its gradient hooks and builds its reduction buckets from the parameters that require gradients at the time the model is wrapped. Changing that set afterwards can lead to runtime errors or to gradients that are silently left unsynchronized. Instead, keep requires_grad=True statically for all trainable parameters and toggle only .train() / .eval() mode on the submodules.
Suggested change:
def set_dmd2_train_phase(module, student_phase: bool):
    module.pipe.dit.train(student_phase)
    module.fake_score.train(not student_phase)
    if module.discriminator is not None:
        module.discriminator.train(not student_phase)
def _get_optimal_groups(num_channels):
    if num_channels <= 32:
        groups = max(1, num_channels // 4)
    else:
        groups = 32
        while groups > 1 and num_channels % groups != 0:
            groups -= 1
    assert num_channels % groups == 0, f"{num_channels} not divisible by {groups}"
    return groups
In _get_optimal_groups, when num_channels <= 32 the group count is set to num_channels // 4 without checking that it divides num_channels. For example, num_channels = 14 gives groups = 3, and 14 % 3 != 0, so the assertion assert num_channels % groups == 0 fails and the call crashes. A more robust approach is to pick the largest group size from a standard set that evenly divides num_channels.
Suggested change:
def _get_optimal_groups(num_channels):
    for groups in [32, 16, 8, 4, 2, 1]:
        if num_channels >= groups and num_channels % groups == 0:
            return groups
    return 1
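To see how the two versions behave side by side, the sketch below runs both on a few channel counts. The function bodies follow the snippets in this review; the names `original_groups`, `suggested_groups`, and `try_groups` are hypothetical, and nesting the `while` loop under the `else` branch is an assumption based on the crash the comment describes.

```python
def original_groups(num_channels):
    # As read from the reviewed snippet; the while loop is assumed to
    # belong to the else branch, so small channel counts are not adjusted.
    if num_channels <= 32:
        groups = max(1, num_channels // 4)
    else:
        groups = 32
        while groups > 1 and num_channels % groups != 0:
            groups -= 1
    assert num_channels % groups == 0, f"{num_channels} not divisible by {groups}"
    return groups


def suggested_groups(num_channels):
    # Largest power-of-two group size (up to 32) that divides num_channels.
    for groups in [32, 16, 8, 4, 2, 1]:
        if num_channels >= groups and num_channels % groups == 0:
            return groups
    return 1


def try_groups(fn, num_channels):
    """Return the group count, or 'AssertionError' if fn fails its own check."""
    try:
        return fn(num_channels)
    except AssertionError:
        return "AssertionError"


for n in (14, 16, 48, 96):
    print(n, try_groups(original_groups, n), try_groups(suggested_groups, n))
```

Under these assumptions the suggested version no longer fails for 14 channels, but it also returns different group counts for inputs that already worked (for example 16 and 48), so it may be worth confirming that this change in normalization granularity is intended.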