Fix prior preservation shape mismatch in TimestepEmbedding and PixArtAlphaTextProjection #13947
Open
Liauuu wants to merge 2 commits into
…s feature dim

When --with_prior_preservation is enabled during FLUX DreamBooth training, pooled text projections can arrive with horizontally concatenated features (e.g. [2, 1536] instead of [4, 768]), causing a RuntimeError in the PixArtAlphaTextProjection and TimestepEmbedding linear layers.

Add a shared _unstack_doubled_features helper that detects a last dimension of exactly 2x in_features, splits on dim=-1, and re-stacks on dim=0 before the linear projections. Normal inputs pass through unchanged.

Fixes huggingface#12494

Co-authored-by: Cursor <cursoragent@cursor.com>
Summary
This PR fixes a `RuntimeError: mat1 and mat2 shapes cannot be multiplied` that occurs during FLUX DreamBooth training when `--with_prior_preservation` is enabled (reported in #12494).

The failure happens inside `PixArtAlphaTextProjection.linear_1` (via `CombinedTimestepGuidanceTextProjEmbeddings` / `CombinedTimestepTextProjEmbeddings` in the FLUX transformer), with shapes such as `(2×1536)` passed to a layer expecting `(N×768)`.

Root cause
Prior preservation training concatenates instance and class samples so the transformer can process both in a single forward pass. In affected code paths, pooled text embeddings are sometimes concatenated along the feature dimension (`dim=-1`) instead of the batch dimension (`dim=0`).

For example:

- `[batch_size * 2, 768]` (instance + class stacked on batch)
- `[batch_size, 1536]` (instance + class concatenated horizontally)

Because `PixArtAlphaTextProjection` and `TimestepEmbedding` define `linear_1` with `in_features=768` (or the configured channel size), a last dimension of `1536` (768 × 2) triggers the matmul failure.

This matches the stack trace in #12494 from `train_dreambooth_lora_flux_advanced.py` with `--with_prior_preservation`.

Solution
Introduce a small shared helper, `_unstack_doubled_features`, in `src/diffusers/models/embeddings.py`.

Apply it at the start of:

- `TimestepEmbedding.forward`: on `sample`, and on `condition` when `cond_proj` is used
- `PixArtAlphaTextProjection.forward`: on `caption` (pooled projections)

When the last dimension is exactly `2 × in_features`, the helper splits on `dim=-1` and re-stacks on `dim=0`, converting `[B, 2F] → [2B, F]` before the linear layers run. Downstream modules then receive batch-doubled embeddings as intended for prior preservation.

When inputs already have the correct shape (`[B, F]`), the helper is a no-op, so normal inference and training without prior preservation are unchanged.

Changes
- `src/diffusers/models/embeddings.py`
  - `_unstack_doubled_features`
  - `TimestepEmbedding.forward` and `PixArtAlphaTextProjection.forward`

Test plan

- Run `train_dreambooth_lora_flux_advanced.py` (or `train_dreambooth_lora_flux.py`) with `--with_prior_preservation`, `--class_data_dir`, and `--class_prompt`; confirm training starts without the `(2×1536)` matmul error
- Run without `--with_prior_preservation`; confirm behavior is unchanged

Fixes #12494
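
For reviewers who want to check the shape logic without launching a training run, here is a minimal standalone sketch of the behavior described under Root cause and Solution. It is not the code from this PR: `unstack_doubled_features_sketch` is a hypothetical stand-in for `_unstack_doubled_features`, its signature is an assumption, and the output width of the toy `nn.Linear` is arbitrary. It only illustrates the described split on `dim=-1` and re-stack on `dim=0`.

```python
import torch
from torch import nn


def unstack_doubled_features_sketch(x: torch.Tensor, in_features: int) -> torch.Tensor:
    """Hypothetical stand-in for the PR's helper (not the actual diffusers code).

    If the last dimension is exactly 2 * in_features, split it in half on
    dim=-1 and re-stack the halves on dim=0: [B, 2F] -> [2B, F].
    Any other input is returned unchanged.
    """
    if x.shape[-1] == 2 * in_features:
        first_half, second_half = x.chunk(2, dim=-1)
        return torch.cat([first_half, second_half], dim=0)
    return x


in_features = 768
out_features = 16  # arbitrary for this sketch; the real layer is wider
linear_1 = nn.Linear(in_features, out_features)

# Two instance samples and two class samples of pooled text embeddings.
instance = torch.randn(2, in_features)
class_samples = torch.randn(2, in_features)

stacked = torch.cat([instance, class_samples], dim=0)   # [4, 768], intended layout
doubled = torch.cat([instance, class_samples], dim=-1)  # [2, 1536], problematic layout

# The doubled layout cannot be multiplied with a weight expecting 768 features.
try:
    linear_1(doubled)
    raised = False
except RuntimeError:
    raised = True

# After unstacking, the tensor matches the batch-stacked layout exactly.
repaired = unstack_doubled_features_sketch(doubled, in_features)
projected = linear_1(repaired)

# Correctly shaped input passes through untouched.
passthrough = unstack_doubled_features_sketch(stacked, in_features)
```

In this sketch the instance rows end up first and the class rows second, which matches concatenation on `dim=0`. One design caveat worth discussing in review: a check on the last dimension alone cannot distinguish a doubled tensor from a legitimate input whose feature size happens to equal `2 × in_features`.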