
Fix prior preservation shape mismatch in TimestepEmbedding and PixArtAlphaTextProjection #13947

Open
Liauuu wants to merge 2 commits into huggingface:main from Liauuu:fix-flux-mat-mul

Liauuu commented Jun 14, 2026

Summary

This PR fixes the error "RuntimeError: mat1 and mat2 shapes cannot be multiplied", which occurs during FLUX DreamBooth training when --with_prior_preservation is enabled (reported in #12494).

The failure happens inside PixArtAlphaTextProjection.linear_1 (via CombinedTimestepGuidanceTextProjEmbeddings / CombinedTimestepTextProjEmbeddings in the FLUX transformer), with shapes such as (2×1536) passed to a layer expecting (N×768).

Root cause

Prior preservation training concatenates instance and class samples so the transformer can process both in a single forward pass. In affected code paths, pooled text embeddings are sometimes concatenated along the feature dimension (dim=-1) instead of the batch dimension (dim=0).

For example:

  • Expected: [batch_size * 2, 768] (instance + class stacked on batch)
  • Observed: [batch_size, 1536] (instance + class concatenated horizontally)
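The difference between the two layouts can be sketched with plain tensors. This is a minimal illustration, not code from the training script; the names instance and class_prior and the batch size of 2 are assumptions chosen to match the shapes above.

```python
import torch

batch_size, features = 2, 768

# Hypothetical pooled text embeddings for the instance and class prompts.
instance = torch.randn(batch_size, features)
class_prior = torch.randn(batch_size, features)

# Intended layout: stack along the batch dimension.
stacked = torch.cat([instance, class_prior], dim=0)

# Problematic layout: concatenate along the feature dimension.
widened = torch.cat([instance, class_prior], dim=-1)

print(stacked.shape)  # torch.Size([4, 768])
print(widened.shape)  # torch.Size([2, 1536])
```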

Because PixArtAlphaTextProjection and TimestepEmbedding define linear_1 with in_features=768 (or the configured channel size), a last dimension of 1536 (768 × 2) triggers the matmul failure:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x1536 and 768x3072)

This matches the stack trace in #12494 from train_dreambooth_lora_flux_advanced.py with --with_prior_preservation.
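The same failure can likely be reproduced in isolation with a bare linear layer. The sketch below assumes a standalone nn.Linear(768, 3072) as a stand-in for linear_1 (the sizes are taken from the error message, not from the diffusers source) and feeds it a [2, 1536] input.

```python
import torch
from torch import nn

# Stand-in for linear_1; sizes inferred from the reported error message.
linear_1 = nn.Linear(in_features=768, out_features=3072)

# Input whose feature dimension was doubled by a dim=-1 concatenation.
widened = torch.randn(2, 1536)

try:
    linear_1(widened)
    error_message = None
except RuntimeError as exc:
    # The last dimension (1536) does not match in_features (768).
    error_message = str(exc)

print(error_message)
```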

Solution

Introduce a small shared helper, _unstack_doubled_features, in src/diffusers/models/embeddings.py:

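import torch

def _unstack_doubled_features(tensor, expected_features):
    # [B, 2F] -> [2B, F]: undo a feature-dim concat of two batches.
    if tensor.shape[-1] == expected_features * 2:
        first, second = tensor.chunk(2, dim=-1)
        return torch.cat([first, second], dim=0)
    # Any other shape passes through unchanged.
    return tensor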

Apply it at the start of:

  • TimestepEmbedding.forward — on sample, and on condition when cond_proj is used
  • PixArtAlphaTextProjection.forward — on caption (pooled projections)
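As a rough illustration of where the call sits, the sketch below wires the helper into a simplified projection module. ToyTextProjection is a hypothetical stand-in written for this example; its layer sizes and activation are assumptions, and it is not the actual PixArtAlphaTextProjection implementation.

```python
import torch
from torch import nn


def _unstack_doubled_features(tensor, expected_features):
    # Helper as proposed in this PR: [B, 2F] -> [2B, F], otherwise a no-op.
    if tensor.shape[-1] == expected_features * 2:
        first, second = tensor.chunk(2, dim=-1)
        return torch.cat([first, second], dim=0)
    return tensor


class ToyTextProjection(nn.Module):
    """Hypothetical, simplified projection used only for illustration."""

    def __init__(self, in_features=768, hidden_size=3072):
        super().__init__()
        self.linear_1 = nn.Linear(in_features, hidden_size)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, caption):
        # Normalize the layout before the first linear layer runs.
        caption = _unstack_doubled_features(caption, self.linear_1.in_features)
        return self.linear_2(self.act(self.linear_1(caption)))


proj = ToyTextProjection()
out_wide = proj(torch.randn(2, 1536))   # doubled features, re-stacked on batch
out_normal = proj(torch.randn(2, 768))  # already correct, passes through

print(out_wide.shape)    # torch.Size([4, 3072])
print(out_normal.shape)  # torch.Size([2, 3072])
```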

When the last dimension is exactly 2 × in_features, the helper splits on dim=-1 and re-stacks on dim=0, converting [B, 2F] → [2B, F] before the linear layers run. Downstream modules then receive batch-doubled embeddings as intended for prior preservation.

When inputs already have the correct shape ([B, F]), the helper is a no-op, so normal inference and training without prior preservation are unchanged.
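Both behaviours can be checked directly on tensors. The sketch below assumes the doubled input was produced by concatenating an instance half and a class half on dim=-1 (the names and values are illustrative); under that assumption the helper recovers the same row order as a dim=0 concatenation, and a correctly shaped input is returned as the same object.

```python
import torch


def _unstack_doubled_features(tensor, expected_features):
    # Helper as proposed in this PR: [B, 2F] -> [2B, F], otherwise a no-op.
    if tensor.shape[-1] == expected_features * 2:
        first, second = tensor.chunk(2, dim=-1)
        return torch.cat([first, second], dim=0)
    return tensor


# Distinct fill values make the two halves easy to tell apart.
instance = torch.zeros(2, 768)
class_prior = torch.ones(2, 768)

widened = torch.cat([instance, class_prior], dim=-1)   # [2, 1536]
restored = _unstack_doubled_features(widened, 768)      # [4, 768]
expected = torch.cat([instance, class_prior], dim=0)    # intended layout

same_order = torch.equal(restored, expected)

# A correctly shaped input is returned untouched (same object).
normal = torch.randn(2, 768)
untouched = _unstack_doubled_features(normal, 768) is normal

print(restored.shape, same_order, untouched)
```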

Changes

  • src/diffusers/models/embeddings.py
    • Add _unstack_doubled_features
    • Call it from TimestepEmbedding.forward and PixArtAlphaTextProjection.forward

Test plan

Fixes #12494

Liauuu and others added 2 commits June 14, 2026 21:00
…s feature dim

When --with_prior_preservation is enabled during FLUX DreamBooth training,
pooled text projections can arrive with horizontally concatenated features
(e.g. [2, 1536] instead of [4, 768]), causing a RuntimeError in
PixArtAlphaTextProjection and TimestepEmbedding linear layers.

Add a shared _unstack_doubled_features helper that detects a last-dimension
exactly 2x in_features, splits on dim=-1, and re-stacks on dim=0 before
the linear projections. Normal inputs pass through unchanged.

Fixes huggingface#12494

Co-authored-by: Cursor <cursoragent@cursor.com>
Liauuu force-pushed the fix-flux-mat-mul branch from a8aecf0 to eec51dc June 14, 2026 12:00
github-actions bot added the size/S (PR with diff < 50 LOC) label and removed the lora, tests, utils, and size/M (PR with diff < 200 LOC) labels Jun 14, 2026

Labels

fixes-issue, models, size/S (PR with diff < 50 LOC)

Development

Successfully merging this pull request may close these issues.

i think there is something wrong with new/latest scripts. RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x1536 and 768x3072)
