Speed up WAN 2.2 checkpoint load and compile (no output change) #438

Open
syhuang22 wants to merge 1 commit into AI-Hypercomputer:main from syhuang22:reduce-load-compile-clean
@syhuang22

Collaborator

Summary

Cuts WAN 2.2 inference startup time (checkpoint load plus first-run compile) without changing any generated output. On v7x-8 (dp2-cp4 2D-ring, 720p / 81f / 40 steps):

| Phase | Before | After |
| --- | --- | --- |
| Load (checkpoint) | 141s (warm) / 585s (cold) | ~53s |
| Compile (first full run) | 144s | ~21s |
| Denoise | 109.8s | 109.1s (unchanged) |

The generated video is byte-for-byte identical to the pre-change output (same seed, same md5), so this is a pure startup-latency change with no numerical impact.

What changed

1. Rewrote the transformer weight loader (models/wan/wan_utils.py)
The diffusers checkpoints are FP32. The previous loader upcast every tensor through torch.float() and, for scanned layers, rebuilt the full stacked (num_layers, ...) array once per layer via jnp.ndarray.at[block].set(...) — an O(num_layers²) copy chain (~300s per 14B transformer). The new loader:

  • reads tensors zero-copy from the safetensors mmap (bf16 via a uint16 view into ml_dtypes.bfloat16, no FP32 round-trip),
  • pre-allocates one (num_layers, ...) numpy buffer per param and writes each layer's row in place,
  • fuses the final dtype cast (weights_dtype, with the same norm/embedder FP32 exclusions) into that single copy,
  • converts shards concurrently in a thread pool.

Result: ~7s per transformer, output verified bitwise-identical to the old loader across the full weight tree.
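
The pre-allocate-and-write-in-place idea can be sketched roughly as below. This is a minimal illustration, not the code in models/wan/wan_utils.py: `stack_layers` and `convert_params` are hypothetical names, `np.float16` stands in for `weights_dtype` so the sketch needs only NumPy, and the safetensors mmap / bf16 view and the norm/embedder FP32 exclusions are not shown.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def stack_layers(per_layer, out_dtype):
    """Stack per-layer tensors into one (num_layers, ...) array, one copy per layer."""
    first = np.asarray(per_layer[0])
    # Single allocation for the whole stacked parameter.
    stacked = np.empty((len(per_layer), *first.shape), dtype=out_dtype)
    for i, tensor in enumerate(per_layer):
        # In-place row write; the dtype cast happens during this same copy.
        np.copyto(stacked[i], np.asarray(tensor), casting="same_kind")
    return stacked


def convert_params(params, out_dtype, max_workers=4):
    """Convert several parameters concurrently (stand-in for per-shard parallelism)."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {
            name: pool.submit(stack_layers, layers, out_dtype)
            for name, layers in params.items()
        }
        return {name: future.result() for name, future in futures.items()}


# Toy weights: 3 layers, two parameters per layer (shapes are made up).
rng = np.random.default_rng(0)
params = {
    "attn.kernel": [rng.standard_normal((4, 4), dtype=np.float32) for _ in range(3)],
    "ffn.kernel": [rng.standard_normal((4, 8), dtype=np.float32) for _ in range(3)],
}
converted = convert_params(params, np.float16)
```

Each layer is copied exactly once, so the work is linear in the number of layers, whereas rebuilding the full stacked array on every per-layer update copies it num_layers times.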

2. Staged host→device transfer (pipelines/wan/wan_pipeline.py)
Replicated params are the bulk of the bytes. A direct replicated device_put broadcasts the same bytes over every device's (slow) host link. Instead, large replicated params are staged sharded along dim 0 (each device receives only 1/N of the bytes over PCIe) and replicated on-device via an identity jit over ICI. Measured ~4.6× faster for the replicated weight set (45.8s → ~10s).
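
A rough sketch of the staging pattern, assuming a 1D mesh over all local devices. `staged_replicated_put` and the `"stage"` axis name are hypothetical and are not taken from pipelines/wan/wan_pipeline.py; the size threshold for what counts as a "large" param is omitted.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def staged_replicated_put(host_array, mesh, axis_name="stage"):
    """Place host_array fully replicated, staging it sharded along dim 0 first."""
    replicated = NamedSharding(mesh, P())
    num_devices = mesh.devices.size
    if host_array.ndim == 0 or host_array.shape[0] % num_devices != 0:
        # Fall back to a plain replicated transfer when dim 0 does not split evenly.
        return jax.device_put(host_array, replicated)
    # Stage: each device receives only its 1/N slice of dim 0 from the host.
    staged = jax.device_put(host_array, NamedSharding(mesh, P(axis_name)))
    # Replicate on-device: identity jit whose output sharding is fully replicated.
    return jax.jit(lambda x: x, out_shardings=replicated)(staged)


mesh = Mesh(np.array(jax.devices()), ("stage",))
weights = np.arange(8 * len(jax.devices()) * 4, dtype=np.float32).reshape(-1, 4)
on_device = staged_replicated_put(weights, mesh)
```

On a single-device host this degenerates to an ordinary transfer; the benefit described above only appears when several devices share the host link.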

3. Parallel component loading (pipelines/wan/wan_pipeline_2_2.py, wan_pipeline_i2v_2p2.py)
VAE / tokenizer / text-encoder / scheduler load on a background thread, and the two 14B transformers load concurrently, so the small components are hidden behind transformer conversion. A shared-blob metadata lock is added because WAN 2.2's transformer and transformer_2 have byte-identical config.json / index.json (one blob in the HF hub cache), which hf_hub may rewrite on revalidation — concurrent readers could otherwise observe a half-written file.
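
The orchestration can be sketched with the standard library alone. Everything here is illustrative: `load_pipeline`, `read_shared_metadata`, the fake loaders, and the placeholder config value are hypothetical stand-ins, not the functions added in this PR.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# One process-wide lock guarding reads of metadata files that share a cache blob.
_shared_metadata_lock = threading.Lock()


def read_shared_metadata(read_fn):
    """Serialize metadata reads so no reader sees a file mid-rewrite."""
    with _shared_metadata_lock:
        return read_fn()


def load_pipeline(small_loaders, transformer_loaders):
    """Load small components on one worker while the transformers load concurrently."""
    def load_small():
        return {name: load() for name, load in small_loaders.items()}

    with ThreadPoolExecutor(max_workers=1 + len(transformer_loaders)) as pool:
        small_future = pool.submit(load_small)
        big_futures = {name: pool.submit(load) for name, load in transformer_loaders.items()}
        components = dict(small_future.result())
        for name, future in big_futures.items():
            components[name] = future.result()
    return components


def fake_transformer(name):
    # Placeholder config; a real loader would parse config.json under the lock.
    config = read_shared_metadata(lambda: {"num_layers": 40})
    return (name, config["num_layers"])


components = load_pipeline(
    small_loaders={"vae": lambda: "vae", "tokenizer": lambda: "tokenizer"},
    transformer_loaders={
        "transformer": lambda: fake_transformer("transformer"),
        "transformer_2": lambda: fake_transformer("transformer_2"),
    },
)
```

The lock only covers the short metadata read, so the expensive weight conversion still runs in parallel.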

4. Two-step compile warmup (generate_wan.py)
The first (warmup) call now runs only 2 denoising steps instead of a full run. Step 0 uses the high-noise transformer and step 1 crosses the boundary to the low-noise one, so every executable of the full run (both transformers, text encoder, VAE decode) still gets compiled — the step count only changes the Python loop trip count, not any traced shape. The exported video now comes from the timed full run rather than the warmup. When compile_text_encoder is set, the text encoder's torch.compile graph is also warmed during loading (behind the load thread) instead of inside the first pipeline call.
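
The coverage argument can be checked with a toy model of the expert switch. The linear schedule and the 0.875 boundary below are assumptions made for illustration only; the real timesteps come from the scheduler and the boundary from the model config.

```python
def experts_used(num_steps, boundary=0.875):
    """Return which transformer each denoising step would call (toy schedule)."""
    # Hypothetical linear schedule from 1.0 downward, one value per step.
    timesteps = [1.0 - i / num_steps for i in range(num_steps)]
    return ["high_noise" if t >= boundary else "low_noise" for t in timesteps]


warmup = experts_used(2)      # two-step warmup
full_run = experts_used(40)   # full 40-step run
one_step = experts_used(1)    # a single step never reaches the low-noise expert
```

Under these assumptions the two-step warmup touches the same set of transformers as the 40-step run, while a one-step warmup would leave the low-noise transformer uncompiled.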

Verification

  • New loader output is bitwise-equal to the previous loader over the full transformer weight tree (checked twice).
  • Staged device_put output verified bitwise on TPU.
  • Denoise time 109.1s across 5 consecutive full runs (baseline 109.8s — no regression).
  • Final generated video md5-identical to the pre-change output.
  • Wan checkpointer / kv-cache / vace unit tests pass.
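
For readers who want to reproduce checks of this kind, a minimal sketch follows. The helper names are hypothetical and this is not the verification script used for the PR; it compares raw bytes rather than values, so NaNs compare equal to themselves and 0.0 is distinguished from -0.0.

```python
import hashlib
import os
import tempfile

import numpy as np


def bitwise_equal(a, b):
    """True only if dtype, shape, and raw bytes all match."""
    a, b = np.asarray(a), np.asarray(b)
    return a.dtype == b.dtype and a.shape == b.shape and a.tobytes() == b.tobytes()


def trees_bitwise_equal(old, new):
    """Recursively compare two nested dicts of arrays."""
    if isinstance(old, dict) or isinstance(new, dict):
        if not (isinstance(old, dict) and isinstance(new, dict)):
            return False
        if old.keys() != new.keys():
            return False
        return all(trees_bitwise_equal(old[k], new[k]) for k in old)
    return bitwise_equal(old, new)


def md5_of_file(path, chunk_size=1 << 20):
    """Stream a file through md5 without loading it all into memory."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# Toy usage: a weight tree containing a NaN, and a small stand-in "video" file.
old_tree = {"blocks": {"kernel": np.array([np.nan, 1.0], dtype=np.float32)}}
new_tree = {"blocks": {"kernel": old_tree["blocks"]["kernel"].copy()}}
same = trees_bitwise_equal(old_tree, new_tree)

with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"example video bytes")
file_digest = md5_of_file(tmp.name)
os.unlink(tmp.name)
```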

Independence

Touches only the WAN load/compile/orchestration path; no attention-kernel files. Disjoint from #436 (fixed-m attention) — no shared files, no shared symbols.

…ICI-staged device_put, parallel component loads, 2-step warmup
@syhuang22 syhuang22 requested a review from entrpn as a code owner July 3, 2026 07:37