Add MagCache inference acceleration for Wan2.2 (T2V + I2V) #433
HadarIngonyama wants to merge 1 commit into
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Perseus14
left a comment
Thanks for the PR. I have added some comments. PTAL!
Please run a manual linting test.
pip install pylint pyink==23.10.0 pytype==2024.2.27
pyink src/maxdiffusion --check --diff --color --pyink-indentation=2 --line-length=125
Additionally could you also squash the commits?
944f130 to
4669443
Compare
LGTM! I was able to reproduce the results for both T2V and I2V. @prishajain1 Could you also take a look? @HadarIngonyama Please rebase with main. Your current PR includes linting changes to ltx2 lora, which is fixed in main.
mbohlool
left a comment
two minor comments, otherwise looks good.
use_cfg_cache: bool = False,
use_sen_cache: bool = False,
use_kv_cache: bool = False,
use_magcache: bool = False,
The default values in the next 4 lines are inconsistent with wan_pipeline_2_2.py. Can you use the same default values, please?
cache_count = 0
for step in range(num_inference_steps):
  t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
Can you create this array outside of the loop and index into it inside the loop, the same as in wan_pipeline_2_2.py?
Port MagCache acceleration to the Wan 2.2 dual-transformer pipelines.
- wan_pipeline_2_2.py / wan_pipeline_i2v_2p2.py: MagCache skip path for the dual-transformer loop: per-phase forced-compute (retention) zones, residual reset at the high->low boundary, and a single interleaved mag_ratios_base curve spanning both phases (indexed by global step). I2V additionally handles the image condition (concat with latents + BFHWC<->BCFHW transposes).
- generate_wan.py: pass use_magcache / magcache_thresh / magcache_K / retention_ratio through to the 2.2 pipelines.
- base_wan_27b.yml (T2V): flow_shift=12.0 + MagCache params + official A14B mag_ratios_base. base_wan_i2v_27b.yml (I2V): boundary_ratio=0.900, flow_shift=5.0 + official I2V-A14B mag_ratios_base.
- tests: wan2_2_magcache_test.py (host-side validation/schedule/core tests + a TPU-only end-to-end smoke test).
- README: document MagCache for Wan 2.2 (settings, speedup, SSIM/PSNR).
4669443 to
bf91020
Compare
Add MagCache inference acceleration for Wan2.2 (T2V + I2V)
Summary
This PR adds MagCache support to the Wan2.2 dual-transformer pipelines (both T2V and I2V), extending the existing Wan2.1 T2V MagCache support. MagCache skips the transformer blocks and reuses the cached block residual when the accumulated magnitude-ratio error stays below a threshold. It uses a precalibrated per-step mag_ratios_base curve, so the skip schedule is deterministic (no data-dependent control flow, TPU/JIT friendly).
Measured speedups vs the dense render: ~1.82× for T2V and ~1.75× for I2V, with visually near-indistinguishable output.
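The skip rule described above can be sketched on the host in a few lines. This is a minimal illustration modeled on the publicly described MagCache rule (accumulate the per-step magnitude ratio, sum the deviation from 1, skip while the error stays under the threshold and at most K consecutive steps are skipped). It is not the code in this PR: the function name `magcache_skip_schedule` and the curve values are hypothetical, and the real pipelines additionally handle two phases and an interleaved curve.

```python
def magcache_skip_schedule(mag_ratios, thresh, K, retention_ratio):
  """Return one bool per step: True means reuse the cached residual (skip)."""
  num_steps = len(mag_ratios)
  # Assumption: the first retention_ratio fraction of steps is always computed.
  retention_steps = int(num_steps * retention_ratio)
  acc_ratio, acc_err, acc_steps = 1.0, 0.0, 0
  skips = []
  for step, ratio in enumerate(mag_ratios):
    skip = False
    if step >= retention_steps:
      acc_ratio *= ratio                # accumulated magnitude ratio since last compute
      acc_steps += 1                    # consecutive skip candidates
      acc_err += abs(1.0 - acc_ratio)   # accumulated skip error
      if acc_err < thresh and acc_steps <= K:
        skip = True
      else:
        # Compute this step and reset the accumulators.
        acc_ratio, acc_err, acc_steps = 1.0, 0.0, 0
    skips.append(skip)
  return skips


# Made-up curve for illustration only; not the official mag_ratios_base.
toy_curve = [1.0, 1.0, 0.99, 0.99, 0.99, 0.99, 0.5, 0.99]
schedule = magcache_skip_schedule(toy_curve, thresh=0.04, K=2, retention_ratio=0.25)
```

Because the schedule depends only on the curve and the scalar settings, it can be computed once before sampling, which is what keeps the denoising loop free of data-dependent branches.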
What's included
- T2V pipeline (wan_pipeline_2_2.py): MagCache skip path for the dual transformer: a single interleaved mag_ratios_base curve spanning both the high-noise and low-noise phases, a per-phase forced-compute (retention) zone, and an explicit cached-residual reset at the high→low transformer boundary.
- I2V pipeline (wan_pipeline_i2v_2p2.py): the same skip path adapted for the image-conditioned pipeline (image condition concatenated with the latents, with the required BFHWC↔BCFHW transposes).
- generate_wan.py: threads use_magcache / magcache_thresh / magcache_K / retention_ratio through to both 2.2 pipelines.
- base_wan_27b.yml (T2V): MagCache params + official mag_ratios_base, and flow_shift defaulted to 12.0 (see note below).
- base_wan_i2v_27b.yml (I2V): MagCache params + official I2V-A14B mag_ratios_base, with boundary_ratio=0.900 to align the high→low switch with the curve (flow_shift stays at the I2V default of 5.0).
- Tests (wan2_2_magcache_test.py): host-side validation/schedule/core tests plus a TPU-only end-to-end smoke test.
Important: flow_shift alignment
mag_ratios_base is calibrated against where the high→low noise boundary lands, which flow_shift controls. Wan2.2 T2V requires flow_shift=12.0 (the official A14B sampling shift). The previous default of 5.0 moved the boundary several steps out of phase, so MagCache skipped at the wrong steps and quality dropped. This PR sets the correct default, which also fixes the off-spec dense baseline. For I2V the official shift is 5.0, paired with boundary_ratio=0.900.
Results
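To make the flow_shift and boundary_ratio settings used in these results more concrete, the sketch below counts how many steps fall in the high-noise phase for a given shift. It rests on two assumptions that are not taken from this PR: a uniform sigma grid from 1.0 toward 0.0, and a switch to the low-noise transformer once the shifted sigma drops below boundary_ratio. The helper names are hypothetical, and the real scheduler's grid may differ, so the counts are illustrative only.

```python
def shifted_sigmas(num_steps, shift):
  """Apply the standard flow-matching time shift to an assumed uniform grid."""
  sigmas = [1.0 - i / num_steps for i in range(num_steps)]  # assumed grid
  return [shift * s / (1.0 + (shift - 1.0) * s) for s in sigmas]


def high_noise_steps(num_steps, shift, boundary_ratio):
  """Count steps whose shifted sigma is still at or above the boundary."""
  return sum(1 for s in shifted_sigmas(num_steps, shift) if s >= boundary_ratio)


# Same boundary, two shifts: the switch step moves by several steps.
steps_shift_5 = high_noise_steps(40, shift=5.0, boundary_ratio=0.9)
steps_shift_12 = high_noise_steps(40, shift=12.0, boundary_ratio=0.9)
```

On this assumed grid the high-noise phase lasts 15 steps at shift 5.0 and 23 steps at shift 12.0, which shows why a curve calibrated for one shift lines up poorly with a run that uses the other.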
Measured on a v7x (720×1280, 81 frames, 40 steps); reference = dense (use_magcache=False) render with the same seed/config:
- T2V: flow_shift=12.0, thresh=0.04, K=2
- I2V: flow_shift=5.0, boundary_ratio=0.900, thresh=0.06, K=2
The reference-based metrics mostly reflect trajectory divergence (caching nudges the sampler onto a different but equally plausible sample) rather than visible degradation; cached clips are visually hard to tell apart from dense. I2V scores higher because the image conditioning anchors the trajectory. Recalibrating mag_ratios_base for a specific dtype/attention kernel can tighten the metric gap further.
Usage
MagCache is one of several mutually-exclusive caching strategies (CFG Cache, SenCache, MagCache) — enable only one at a time.
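A guard for the one-strategy-at-a-time rule could look like the sketch below. The flag names use_cfg_cache, use_sen_cache and use_magcache come from the signature quoted in the review; the helper `check_single_cache_strategy` itself is hypothetical and is not claimed to exist in the repository.

```python
def check_single_cache_strategy(use_cfg_cache=False, use_sen_cache=False, use_magcache=False):
  """Return the name of the enabled caching strategy, or None if none is enabled.

  Raises ValueError when more than one strategy is requested.
  """
  flags = (
      ("use_cfg_cache", use_cfg_cache),
      ("use_sen_cache", use_sen_cache),
      ("use_magcache", use_magcache),
  )
  enabled = [name for name, on in flags if on]
  if len(enabled) > 1:
    raise ValueError(f"Caching strategies are mutually exclusive, got: {', '.join(enabled)}")
  return enabled[0] if enabled else None


selected = check_single_cache_strategy(use_magcache=True)
```

Failing early on a conflicting configuration is cheaper than discovering it after a long compile or a full render.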
Testing
wan2_2_magcache_test.py host-side tests pass (schedule/core logic).