Fixed-m softmax for pure-Ulysses WAN 2.2 attention (+ CFG batch-fold fix) #436
Open
syhuang22 wants to merge 1 commit into
Brings the fixed-m softmax fast path (as productionized in
DiffusionServing PR#28) to maxdiffusion's pure-Ulysses
`ulysses_custom` attention, and fixes a 2x CFG-batch layout regression
that previously made pure Ulysses at CP8 uncompetitive.
Results — WAN2.2-A14B, 720p, 81 frames, CFG on, v7x-8 (CP8 pure Ulysses)
[Results table not recovered; surviving labels: `bq8448/bkv1024`, `ulysses_custom_fixed_m`]

What's in the change
- CFG batch-fold fix (`attention_flax.py`): at CFG batch=2, XLA placed
  the size-2 batch dim in the tile sublanes (`T(2,128)` layout) on the
  Ulysses all-to-all path, quadrupling the cost of every op touching those
  tensors inside the scanned layers. Folding the batch into the heads axis
  around the exchange (`[B,H,S,D] -> [1,B*H,S,D]`) is identity math, since
  each (batch, head) pair is an independent attention problem, and it
  restores the healthy batch=1 compilation: 7.00 -> 3.43 s/step.
- Fixed-m softmax (`ulysses_custom_fixed_m`): eligible heads skip the
  per-block running-max update (the VPU-bound part of online softmax);
  outlier heads fall back to online softmax via a per-head Cauchy-Schwarz
  bound gate (k-smoothing keeps the bound flush-free). The fold also makes
  the gate operate per (batch, head) problem, i.e. strictly tighter.
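The claim that the fold is identity math can be checked with a small NumPy sketch. This is not the code in `attention_flax.py`: the reference `attention` function, the helper name `attention_batch_folded`, and the toy shapes are all assumptions made for illustration.

```python
import numpy as np

def attention(q, k, v):
    """Plain softmax attention over [B, H, S, D] tensors (reference only)."""
    scores = np.einsum("bhqd,bhkd->bhqk", q, k) / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p = p / p.sum(axis=-1, keepdims=True)
    return np.einsum("bhqk,bhkd->bhqd", p, v)

def attention_batch_folded(q, k, v):
    """Hypothetical helper: fold batch into heads, attend, then unfold."""
    b, h, s, d = q.shape
    fold = lambda x: x.reshape(1, b * h, s, d)   # [B,H,S,D] -> [1,B*H,S,D]
    out = attention(fold(q), fold(k), fold(v))
    return out.reshape(b, h, s, d)               # back to [B,H,S,D]

rng = np.random.default_rng(0)
# Toy shapes (assumption): CFG batch 2, 5 heads, 16 tokens, head_dim 8.
q, k, v = (rng.standard_normal((2, 5, 16, 8)) for _ in range(3))
max_abs_diff = np.abs(attention(q, k, v) - attention_batch_folded(q, k, v)).max()
```

The reshape only relabels which (batch, head) slot each independent attention problem occupies, so `max_abs_diff` should be zero up to floating-point rounding.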
fixed-m vs online — kernel-level equivalence
Same q/k/v through both paths (production config: base2 exp, k-smoothing,
per-head mk gate), post-fold CP8 shape (10 head-problems, seq 75600,
head_dim 128):
Relative-L2 sits at the bf16 output quantization floor (2^-8 ≈ 3.9e-3):
the fixed-m/online difference is within the output dtype's inherent
precision. Forced-all (every head through the fixed-m fast path, worst
case) gives identical metrics. Matches the DiffusionServing PR#28 gate
(rel-L2 2.95e-3, measured there on f32 outputs).
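For readers unfamiliar with the two paths, here is a minimal single-head NumPy sketch of online softmax, a fixed-m variant, and a Cauchy-Schwarz gate. It is an illustration under stated assumptions, not the kernel in this PR: k-smoothing is assumed to mean subtracting the per-head mean of k over the sequence, the threshold `max_bound=40.0` is hypothetical, and all function names are made up for the example. It runs in float64, so it does not reproduce the bf16 numbers above.

```python
import numpy as np

def online_softmax_attention(q, k, v, block_kv):
    """Online softmax: the running max m is updated and applied every KV block."""
    m = np.full((q.shape[0], 1), -np.inf)      # running row max
    l = np.zeros((q.shape[0], 1))              # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[1]))   # running unnormalized output
    for start in range(0, k.shape[0], block_kv):
        kb, vb = k[start:start + block_kv], v[start:start + block_kv]
        s = q @ kb.T
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)              # rescales earlier blocks
        p = np.exp(s - m_new)
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        acc = alpha * acc + p @ vb
        m = m_new
    return acc / l

def fixed_m_attention(q, k, v, block_kv, m_fixed):
    """Fixed-m: one constant shift, so no per-block max and no rescaling."""
    l = np.zeros((q.shape[0], 1))
    acc = np.zeros((q.shape[0], v.shape[1]))
    for start in range(0, k.shape[0], block_kv):
        kb, vb = k[start:start + block_kv], v[start:start + block_kv]
        p = np.exp(q @ kb.T - m_fixed)
        l += p.sum(axis=-1, keepdims=True)
        acc += p @ vb
    return acc / l

def cauchy_schwarz_bound(q, k):
    """|q_i . k_j| <= max_i ||q_i|| * max_j ||k_j|| for every pair (i, j)."""
    return np.linalg.norm(q, axis=-1).max() * np.linalg.norm(k, axis=-1).max()

def gated_attention(q, k, v, block_kv, max_bound=40.0):
    # Assumed k-smoothing: subtracting the mean key shifts every score of a
    # query by the same constant, which leaves the softmax unchanged.
    k = k - k.mean(axis=0, keepdims=True)
    bound = cauchy_schwarz_bound(q, k)
    # Hypothetical gate: scores lie in [-bound, bound], so exp(s - bound)
    # has exponent >= -2 * bound; 40 keeps that above the f32 flush range.
    if bound <= max_bound:
        return fixed_m_attention(q, k, v, block_kv, bound), True
    return online_softmax_attention(q, k, v, block_kv), False

rng = np.random.default_rng(0)
d = 16
q = rng.standard_normal((32, d)) / np.sqrt(d)  # softmax scale folded into q
k = rng.standard_normal((64, d)) + 3.0         # offset keys, removed by smoothing
v = rng.standard_normal((64, d))
ref = online_softmax_attention(q, k, v, block_kv=16)
out, used_fixed_m = gated_attention(q, k, v, block_kv=16)
rel_l2 = np.linalg.norm(out - ref) / np.linalg.norm(ref)
```

In this toy setup the gate takes the fixed-m path and `rel_l2` is at float64 rounding level; the relative-L2 figures quoted above come from bf16 outputs on the real kernel, where the output dtype sets the floor.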
Tile notes (v7, seq 75600)
- Tiles are given as
  `block_q x block_kv x block_kv_compute x block_kv_compute_in`. Plain
  optimum: `6912/2048/2048/2048`; the fixed-m optimum shifts to an
  MXU-bound tile (`9472/1024/1024/512`) because the VPU pressure is gone.
- `bkv=1024`-family tiles lose ~2% at kernel level but win ~10% end-to-end:
  the smaller f32 score tile (34.6 MB vs 48.2 MB) leaves VMEM headroom for
  the latency-hiding scheduler to overlap the scan body. Tile choices must
  be validated end-to-end.
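The score-tile sizes can be sanity-checked with simple arithmetic, assuming the f32 score tile is `block_q x block_kv` elements at 4 bytes each and that MB means 10^6 bytes. The helper name is hypothetical. The `8448/1024` pairing is the `bq8448/bkv1024` tile named in the results; the `5888/2048` pairing is an inference obtained by solving for the stated 48.2 MB and is not named anywhere in this PR.

```python
def f32_score_tile_mb(block_q, block_kv):
    """Size of a block_q x block_kv f32 score tile in MB (10^6 bytes)."""
    bytes_per_f32 = 4
    return block_q * block_kv * bytes_per_f32 / 1e6

# Tile named in the results: bq8448 / bkv1024.
small_tile_mb = f32_score_tile_mb(8448, 1024)

# ASSUMPTION: block_q=5888 with block_kv=2048 is inferred from the 48.2 MB
# figure; the PR text does not state which tile that number refers to.
large_tile_mb = f32_score_tile_mb(5888, 2048)

print(f"{small_tile_mb:.1f} MB vs {large_tile_mb:.1f} MB")
```

Under these assumptions the two values round to the 34.6 MB and 48.2 MB quoted above.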