Fixed-m softmax for pure-Ulysses WAN 2.2 attention (+ CFG batch-fold fix) #436

Open
syhuang22 wants to merge 1 commit into AI-Hypercomputer:main from syhuang22:fixed-m-ulysses

@syhuang22 (Collaborator) commented Jul 2, 2026

Fixed-m softmax for pure-Ulysses WAN 2.2 attention (+ CFG batch-fold fix)

Brings the fixed-m softmax fast path (as productionized in
DiffusionServing PR#28) to maxdiffusion's pure-Ulysses ulysses_custom
attention, and fixes a 2x CFG-batch layout regression that previously made
pure Ulysses at CP8 uncompetitive.

Results — WAN2.2-A14B, 720p, 81 frames, CFG on, v7x-8 (CP8 pure Ulysses)

| config | s/step (denoise) | vs previous row |
| --- | --- | --- |
| before this PR (cp8, CFG) | 7.00 | |
| + CFG batch-fold fix | 3.43 | −51% |
| + tile bq8448/bkv1024 | 3.10 | −10% |
| + fixed-m (ulysses_custom_fixed_m) | 2.95 | −5% |
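
As a quick arithmetic check, the percentage column is consistent with each row being compared against the row directly above it, not against the 7.00 s/step starting point. The sketch below recomputes both views from the reported step times; the variable names are illustrative and not taken from the PR.

```python
# Step times (s/step) as reported in the results table, in row order.
step_times = [7.00, 3.43, 3.10, 2.95]

# Change of each row relative to the row directly above it.
step_over_step = [
    (after - before) / before
    for before, after in zip(step_times, step_times[1:])
]

# Cumulative change of the last row relative to the first row.
cumulative = (step_times[-1] - step_times[0]) / step_times[0]
overall_speedup = step_times[0] / step_times[-1]

for delta in step_over_step:
    print(f"{delta * 100:+.1f}% vs previous row")
print(f"{cumulative * 100:+.1f}% cumulative, {overall_speedup:.2f}x overall")
```

Under this reading the three changes stack to roughly a 2.4x end-to-end speedup over the starting configuration.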

What's in the change

  1. CFG batch-fold fix (attention_flax.py): at CFG batch=2 XLA placed
    the size-2 batch dim in the tile sublanes (T(2,128) layout) on the
    Ulysses all-to-all path, quadrupling the cost of every op touching those
    tensors inside the scanned layers. Folding the batch into the heads axis
    around the exchange ([B,H,S,D] -> [1,B*H,S,D], identity math — each
    (batch, head) pair is an independent attention problem) restores the
    healthy batch=1 compilation: 7.00 -> 3.43 s/step.
  2. fixed-m softmax (ulysses_custom_fixed_m): eligible heads skip the
    per-block running-max update (the VPU-bound part of online softmax);
    outlier heads fall back to online softmax via a per-head Cauchy-Schwarz
    bound gate (k-smoothing keeps the bound flush-free). The fold also makes
    the gate operate per (batch, head) problem, i.e. strictly tighter.
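
The "identity math" claim in item 1 can be checked in isolation. The following is a minimal JAX sketch, not the code in attention_flax.py: it uses a plain reference attention and toy shapes, omits the Ulysses all-to-all entirely, and the helper names (`fold_batch_into_heads`, `unfold_batch_from_heads`) are hypothetical. It only shows that folding `[B,H,S,D]` into `[1,B*H,S,D]` and unfolding afterwards leaves the attention output unchanged.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v):
    """Plain softmax attention over [B, H, S, D] inputs (reference only)."""
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
    return jnp.einsum("bhqk,bhkd->bhqd", jax.nn.softmax(scores, axis=-1), v)


def fold_batch_into_heads(x):
    """[B, H, S, D] -> [1, B*H, S, D]; hypothetical helper, a pure reshape."""
    b, h, s, d = x.shape
    return x.reshape(1, b * h, s, d)


def unfold_batch_from_heads(x, batch):
    """[1, B*H, S, D] -> [B, H, S, D]; inverse of fold_batch_into_heads."""
    _, bh, s, d = x.shape
    return x.reshape(batch, bh // batch, s, d)


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
B, H, S, D = 2, 5, 64, 16  # toy sizes; B=2 mirrors the CFG batch
q = jax.random.normal(key_q, (B, H, S, D), dtype=jnp.float32)
k = jax.random.normal(key_k, (B, H, S, D), dtype=jnp.float32)
v = jax.random.normal(key_v, (B, H, S, D), dtype=jnp.float32)

reference = attention(q, k, v)
folded = attention(*(fold_batch_into_heads(t) for t in (q, k, v)))
restored = unfold_batch_from_heads(folded, B)

# Each (batch, head) pair is an independent problem, so the two results
# should agree up to float32 rounding.
max_diff = float(jnp.max(jnp.abs(reference - restored)))
```

The reshape only changes which axis carries the size-2 dimension; whether that avoids the `T(2,128)` layout on a given XLA version is something to confirm from the compiled HLO, not from this sketch.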

fixed-m vs online — kernel-level equivalence

Same q/k/v through both paths (production config: base2 exp, k-smoothing,
per-head mk gate), post-fold CP8 shape (10 head-problems, seq 75600,
head_dim 128):

| Metric | Value |
| --- | --- |
| Absolute max error | 2.441e-04 |
| Mean absolute error (MAE) | 1.320e-05 |
| Mean error (bias) | −3.057e-08 ≈ 0 |
| RMSE | 2.120e-05 |
| Relative-L2 | 3.491e-03 |

Relative-L2 sits at the bf16 output quantization floor (2^-8 ≈ 3.9e-3):
the fixed-m/online difference is within the output dtype's inherent
precision. Forced-all (every head through the fixed-m fast path, worst
case) gives identical metrics. Matches the DiffusionServing PR#28 gate
(rel-L2 2.95e-3, measured there on f32 outputs).
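
To make the comparison concrete, here is a single-head JAX sketch of the two paths and a gate between them. It is an illustration under stated assumptions, not the Pallas kernel from this PR: it runs in float32 with natural-base `exp`, the gate threshold `FLUSH_THRESHOLD` and the rule `2 * bound <= threshold` are assumptions chosen so that `exp(s - m)` cannot underflow, and k-smoothing is modelled as subtracting the per-head mean of k (which shifts every score in a row by the same constant and so leaves the softmax unchanged).

```python
import jax
import jax.numpy as jnp

FLUSH_THRESHOLD = 80.0  # assumed: exp(-80) is still a normal float32 value


def online_attention(q, k, v, block_kv):
    """Blockwise online softmax for one head; q: [Sq, D], k and v: [Skv, D]."""
    scale = q.shape[-1] ** -0.5
    m = jnp.full((q.shape[0], 1), -jnp.inf)
    l = jnp.zeros((q.shape[0], 1))
    acc = jnp.zeros((q.shape[0], v.shape[-1]))
    for start in range(0, k.shape[0], block_kv):
        s = (q @ k[start:start + block_kv].T) * scale
        m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))  # running max
        correction = jnp.exp(m - m_new)  # rescale what was accumulated so far
        p = jnp.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v[start:start + block_kv]
        m = m_new
    return acc / l


def fixed_m_attention(q, k, v, block_kv, m_fixed):
    """Same accumulation with a precomputed per-row m: no max, no rescale."""
    scale = q.shape[-1] ** -0.5
    l = jnp.zeros((q.shape[0], 1))
    acc = jnp.zeros((q.shape[0], v.shape[-1]))
    for start in range(0, k.shape[0], block_kv):
        s = (q @ k[start:start + block_kv].T) * scale
        p = jnp.exp(s - m_fixed)
        l = l + p.sum(axis=-1, keepdims=True)
        acc = acc + p @ v[start:start + block_kv]
    return acc / l


def gated_attention(q, k, v, block_kv):
    """Use fixed-m when the Cauchy-Schwarz bound rules out underflow."""
    k = k - k.mean(axis=0, keepdims=True)  # k-smoothing (assumed form)
    scale = q.shape[-1] ** -0.5
    q_norm = jnp.linalg.norm(q, axis=-1, keepdims=True)
    k_norm_max = jnp.max(jnp.linalg.norm(k, axis=-1))
    m_bound = scale * q_norm * k_norm_max  # |q.k| * scale <= m_bound per row
    # Scores lie in [-m_bound, m_bound], so s - m_bound >= -2 * m_bound.
    eligible = bool(2.0 * jnp.max(m_bound) <= FLUSH_THRESHOLD)
    if eligible:
        return fixed_m_attention(q, k, v, block_kv, m_bound), True
    return online_attention(q, k, v, block_kv), False


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (128, 16))
k = jax.random.normal(key_k, (256, 16))
v = jax.random.normal(key_v, (256, 16))

reference = online_attention(q, k, v, block_kv=64)
fast, used_fixed_m = gated_attention(q, k, v, block_kv=64)
rel_l2 = float(jnp.linalg.norm(fast - reference) / jnp.linalg.norm(reference))

# An artificially scaled "outlier" head should trip the gate and fall back.
_, outlier_used_fixed_m = gated_attention(50.0 * q, k, v, block_kv=64)
```

In float32 the two paths agree far more closely than the bf16 figures in the table; the table's relative-L2 reflects bf16 output rounding, which this sketch does not model.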

Tile notes (v7, seq 75600)

  • 106-combo 4-D sweep (block_q x block_kv x block_kv_compute x
    block_kv_compute_in): the plain optimum is 6912/2048/2048/2048; the
    fixed-m optimum shifts to an MXU-bound shape (9472/1024/1024/512)
    because the VPU pressure is gone.
  • bkv=1024-family tiles lose ~2% at kernel level but win ~10% end-to-end:
    the smaller f32 score tile (34.6 MB vs 48.2 MB) leaves VMEM headroom for
    the latency-hiding scheduler to overlap the scan body. Tile choices must
    be validated end-to-end.
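
The 34.6 MB figure is consistent with a float32 score tile of block_q x block_kv elements for the bq8448/bkv1024 tile from the results table. That reading is an assumption; the PR does not spell out how the tile size is computed, and the tile behind the 48.2 MB figure is not stated, so it is not reproduced here.

```python
def f32_tile_megabytes(block_q: int, block_kv: int) -> float:
    """Size of a block_q x block_kv float32 tile in MB (1 MB = 1e6 bytes)."""
    bytes_per_f32 = 4
    return block_q * block_kv * bytes_per_f32 / 1e6


# Tile from the results table: bq8448 / bkv1024.
score_tile_mb = f32_tile_megabytes(8448, 1024)
print(f"{score_tile_mb:.1f} MB")  # prints "34.6 MB"
```

A helper like this is only a first filter for VMEM headroom; as the notes above say, the final tile choice still has to be validated end-to-end.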

@syhuang22 syhuang22 requested a review from entrpn as a code owner July 2, 2026 01:38
@syhuang22 syhuang22 requested review from Perseus14 and csgoogle July 2, 2026 01:44