Speed up WAN 2.2 checkpoint load and compile (no output change) #438
Open
syhuang22 wants to merge 1 commit into
…ICI-staged device_put, parallel component loads, 2-step warmup
Summary
Cuts WAN 2.2 inference startup dramatically without changing any generated output. On v7x-8 (dp2-cp4 2D-ring, 720p / 81f / 40 steps):
The generated video is byte-for-byte identical to the pre-change output (same seed, same md5), so this is a pure startup-latency change with no numerical impact.
What changed
1. Rewrote the transformer weight loader (`models/wan/wan_utils.py`)

The diffusers checkpoints are FP32. The previous loader upcast every tensor through `torch.float()` and, for scanned layers, rebuilt the full stacked `(num_layers, ...)` array once per layer via `jnp.ndarray.at[block].set(...)`, an O(num_layers²) copy chain (~300s per 14B transformer). The new loader:

- … `uint16` view into `ml_dtypes.bfloat16`, no FP32 round-trip),
- … `(num_layers, ...)` numpy buffer per param and writes each layer's row in place,
- … `weights_dtype`, with the same norm/embedder FP32 exclusions) into that single copy,

Result: ~7s per transformer, output verified bitwise-identical to the old loader across the full weight tree.
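The difference between the two stacking patterns can be illustrated with a small NumPy sketch. This is not the code in `wan_utils.py`: the function names are hypothetical, `np.float16` stands in for bfloat16 (which plain NumPy does not provide), and the per-layer full copy only mimics what a functional `.at[...].set(...)` update does outside `jit`.

```python
import numpy as np


def stack_layers_quadratic(layers):
    """Old pattern (sketch): every layer update copies the whole stacked array."""
    stacked = np.zeros((len(layers), *layers[0].shape), dtype=layers[0].dtype)
    for i, layer in enumerate(layers):
        updated = stacked.copy()  # full copy per layer -> O(num_layers^2) bytes moved
        updated[i] = layer
        stacked = updated
    return stacked


def stack_layers_in_place(layers, dtype):
    """New pattern (sketch): one preallocated buffer, one row write per layer."""
    buf = np.empty((len(layers), *layers[0].shape), dtype=dtype)
    for i, layer in enumerate(layers):
        buf[i] = layer  # NumPy casts to buf.dtype as part of this single copy
    return buf


# Toy "checkpoint": 6 layers of FP32 weights (hypothetical shapes).
rng = np.random.default_rng(0)
layers = [rng.standard_normal((4, 8)).astype(np.float32) for _ in range(6)]

old = stack_layers_quadratic(layers).astype(np.float16)  # stack first, cast afterwards
new = stack_layers_in_place(layers, np.float16)          # cast folded into the row write

# Same bytes either way; only the amount of copying differs.
assert old.tobytes() == new.tobytes()
```

Comparing the raw bytes of both results is the same kind of check the PR describes for the real loader, applied here to toy data only.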
2. Staged host→device transfer (`pipelines/wan/wan_pipeline.py`)

Replicated params are the bulk of the bytes. A direct replicated `device_put` broadcasts the same bytes over every device's (slow) host link. Instead, large replicated params are staged sharded along dim 0 (each device receives only 1/N of the bytes over PCIe) and replicated on-device via an identity `jit` over ICI. Measured ~4.6× faster for the replicated weight set (45.8s → ~10s).

3. Parallel component loading (`pipelines/wan/wan_pipeline_2_2.py`, `wan_pipeline_i2v_2p2.py`)

VAE / tokenizer / text-encoder / scheduler load on a background thread, and the two 14B transformers load concurrently, so the small components are hidden behind transformer conversion. A shared-blob metadata lock is added because WAN 2.2's `transformer` and `transformer_2` have byte-identical `config.json`/`index.json` (one blob in the HF hub cache), which `hf_hub` may rewrite on revalidation; concurrent readers could otherwise observe a half-written file.

4. Two-step compile warmup (`generate_wan.py`)

The first (warmup) call now runs only 2 denoising steps instead of a full run. Step 0 uses the high-noise transformer and step 1 crosses the boundary to the low-noise one, so every executable of the full run (both transformers, text encoder, VAE decode) still gets compiled: the step count only changes the Python loop trip count, not any traced shape. The exported video now comes from the timed full run rather than the warmup. When `compile_text_encoder` is set, the text encoder's `torch.compile` graph is also warmed during loading (behind the load thread) instead of inside the first pipeline call.

Verification

… `device_put` output verified bitwise on TPU.

Independence
Touches only the WAN load/compile/orchestration path; no attention-kernel files. Disjoint from #436 (fixed-m attention) — no shared files, no shared symbols.
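
For readers who want a concrete picture of item 3 under "What changed", the concurrent loads and the metadata lock might be structured roughly as follows. This is a stdlib-only sketch under stated assumptions: `read_shared_config`, `load_transformer`, `load_small_components`, and `load_pipeline` are hypothetical stand-ins, not functions from the PR, and the real loaders do file I/O and weight conversion where this sketch returns placeholder values.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# Hypothetical lock guarding metadata files that both transformers share.
_metadata_lock = threading.Lock()


def read_shared_config(name):
    # Stand-in for reading config.json / index.json from a shared cache blob.
    # Holding the lock means no reader sees the file while another thread
    # is in the middle of revalidating or rewriting it.
    with _metadata_lock:
        return {"component": name, "source": "shared-blob"}


def load_transformer(name):
    config = read_shared_config(name)  # only the metadata read is serialized
    # The heavy weight conversion would run here, outside the lock.
    return f"{config['component']}-loaded"


def load_small_components():
    # Stand-in for the VAE / tokenizer / text encoder / scheduler loads.
    return ["vae", "tokenizer", "text_encoder", "scheduler"]


def load_pipeline():
    # Small components and both transformers are submitted together, so the
    # small loads overlap with the transformer conversion.
    with ThreadPoolExecutor(max_workers=3) as pool:
        small = pool.submit(load_small_components)
        high = pool.submit(load_transformer, "transformer")
        low = pool.submit(load_transformer, "transformer_2")
        return small.result(), high.result(), low.result()


small, t_high, t_low = load_pipeline()
```

Keeping the lock scope limited to the metadata read is the design point: the expensive work still runs in parallel, and only the short shared-file access is serialized.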