Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ def get_git_commit_hash():
jax.config.update("jax_use_shardy_partitioner", True)


def call_pipeline(config, pipeline, prompt, negative_prompt):
def call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=None):
model_key = config.model_name
model_type = config.model_type
if num_inference_steps is None:
num_inference_steps = config.num_inference_steps
if model_type == "I2V":
image = load_image(config.image_url)
if model_key == WAN2_1:
Expand All @@ -102,7 +104,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
num_inference_steps=num_inference_steps,
guidance_scale=config.guidance_scale,
use_magcache=config.use_magcache,
magcache_thresh=config.magcache_thresh,
Expand All @@ -118,7 +120,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
num_inference_steps=num_inference_steps,
guidance_scale_low=config.guidance_scale_low,
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
Expand All @@ -135,7 +137,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
num_inference_steps=num_inference_steps,
guidance_scale=config.guidance_scale,
use_cfg_cache=config.use_cfg_cache,
use_magcache=config.use_magcache,
Expand All @@ -151,7 +153,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
height=config.height,
width=config.width,
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
num_inference_steps=num_inference_steps,
guidance_scale_low=config.guidance_scale_low,
guidance_scale_high=config.guidance_scale_high,
use_cfg_cache=config.use_cfg_cache,
Expand Down Expand Up @@ -282,9 +284,20 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log(
f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}"
)
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
# Warmup with 2 denoising steps instead of a full run: step 0 runs the
# high-noise transformer and step 1 crosses the boundary to the low-noise
# one (WAN 2.2), so every executable of the full run (both transformers,
# text encoder, VAE decode) gets compiled at a fraction of the cost. The
# step count only changes the Python loop trip count, not traced shapes.
warmup_steps = min(2, config.num_inference_steps)
max_logging.log(f"Compile warmup: {warmup_steps} denoising steps")
videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=warmup_steps)
if isinstance(videos, tuple):
videos = videos[0]
videos, warmup_trace = videos
max_logging.log(
"Warmup breakdown: "
+ ", ".join(f"{stage}={seconds:.1f}s" for stage, seconds in warmup_trace.items())
)

max_logging.log("===================== Model details =======================")
max_logging.log(f"model name: {config.model_name}")
Expand All @@ -299,13 +312,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log(f"compile_time: {compile_time}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/compile_time", compile_time, global_step=0)
saved_video_path = []
for i in range(len(videos)):
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
export_to_video(videos[i], video_path, fps=config.fps)
saved_video_path.append(video_path)
if config.output_dir.startswith("gs://"):
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)

s0 = time.perf_counter()
outputs = call_pipeline(config, pipeline, prompt, negative_prompt)
Expand All @@ -315,6 +321,13 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
videos = outputs
trace = {}
generation_time = time.perf_counter() - s0
saved_video_path = []
for i in range(len(videos)):
video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4"
export_to_video(videos[i], video_path, fps=config.fps)
saved_video_path.append(video_path)
if config.output_dir.startswith("gs://"):
upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path)
max_logging.log(f"generation_time: {generation_time}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time", generation_time, global_step=0)
Expand Down
178 changes: 129 additions & 49 deletions src/maxdiffusion/models/wan/wan_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,15 @@
limitations under the License.
"""

import os
import concurrent.futures
import json
import os
import threading
import time
from typing import Callable, Optional

import ml_dtypes
import numpy as np
import torch
import jax
import jax.numpy as jnp
Expand All @@ -28,6 +35,11 @@
CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"

# WAN 2.2 transformer and transformer_2 have byte-identical index.json files,
# i.e. ONE blob in the HF hub cache. hf_hub revalidates and rewrites cached
# blobs, so parallel transformer loads must not resolve metadata concurrently.
_HF_METADATA_LOCK = threading.Lock()


def _tuple_str_to_int(in_tuple):
out_list = []
Expand Down Expand Up @@ -273,17 +285,29 @@ def load_wan_transformer(
num_layers: int = 40,
scan_layers: bool = True,
subfolder: str = "",
cast_dtype_fn: Optional[Callable] = None,
):
if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
else:
return load_base_wan_transformer(
pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder
pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder, cast_dtype_fn
)


def _torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
"""Converts a CPU torch tensor to numpy without copying or upcasting.

bfloat16 has no native numpy dtype, so it is reinterpreted through uint16
into ml_dtypes.bfloat16 (bit-identical, zero-copy).
"""
if tensor.dtype == torch.bfloat16:
return tensor.view(torch.uint16).numpy().view(ml_dtypes.bfloat16)
return tensor.numpy()


def load_base_wan_transformer(
pretrained_model_name_or_path: str,
eval_shapes: dict,
Expand All @@ -292,8 +316,25 @@ def load_base_wan_transformer(
num_layers: int = 40,
scan_layers: bool = True,
subfolder: str = "",
cast_dtype_fn: Optional[Callable] = None,
):
device = jax.local_devices(backend=device)[0]
"""Loads WAN transformer weights from diffusers safetensors shards.

Fast path compared to the historical implementation:
- tensors are read zero-copy from the safetensors mmap (no bf16->f32
round trip through torch.float()),
- scanned block weights are written in place into one preallocated
(num_layers, ...) numpy buffer per param (the old jnp
``at[block].set`` rebuilt the full stacked array once per layer,
i.e. O(num_layers^2) copies),
- the optional ``cast_dtype_fn(flax_key) -> np.dtype`` casts each param
to its final dtype during this single copy, so no later full-tree
cast pass is needed,
- shard files are converted in parallel threads (numpy copies release
the GIL).
Returns a nested dict of numpy arrays (host memory).
"""
del device # weights stay in plain host numpy until device_put by the caller
filename = "diffusion_pytorch_model.safetensors.index.json"
local_files = False
if os.path.isdir(pretrained_model_name_or_path):
Expand All @@ -303,55 +344,94 @@ def load_base_wan_transformer(
local_files = True
elif hf_download:
# download the index file for sharded models.
index_file_path = hf_hub_download(
pretrained_model_name_or_path,
subfolder=subfolder,
filename=filename,
)
with jax.default_device(device):
# open the index file.
with open(index_file_path, "r") as f:
index_dict = json.load(f)
model_files = set()
for key in index_dict["weight_map"].keys():
model_files.add(index_dict["weight_map"][key])
with _HF_METADATA_LOCK:
index_file_path = hf_hub_download(
pretrained_model_name_or_path,
subfolder=subfolder,
filename=filename,
)
t_start = time.perf_counter()
with open(index_file_path, "r") as f:
index_dict = json.load(f)
model_files = sorted(set(index_dict["weight_map"].values()))

# turn all block numbers to strings just for matching weights.
# Later they will be turned back to ints.
random_flax_state_dict = _build_random_flax_state_dict(eval_shapes)
flax_state_dict = {}
dict_lock = threading.Lock()

def resolve_shard_path(model_file):
if local_files:
return os.path.join(pretrained_model_name_or_path, subfolder, model_file)
return hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)

def convert_chunk(ckpt_shard_path, chunk_keys):
# Each task opens its own handle: safetensors mmap open is cheap and
# per-thread handles avoid serializing get_tensor calls.
with safe_open(ckpt_shard_path, framework="pt") as f:
for pt_key in chunk_keys:
tensor = _torch_tensor_to_numpy(f.get_tensor(pt_key))
renamed_pt_key = rename_key(pt_key)
renamed_pt_key = _rename_common_wan_transformer_key(renamed_pt_key)
pt_tuple_key = tuple(renamed_pt_key.split("."))

model_files = list(model_files)
tensors = {}
for model_file in model_files:
if local_files:
ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
else:
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
# now get all the filenames for the model that need downloading
max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
block_index = None
if scan_layers and len(pt_tuple_key) >= 2 and pt_tuple_key[0] == "blocks":
block_index = int(pt_tuple_key[1])
pt_tuple_key = ("blocks",) + pt_tuple_key[2:]

if ckpt_shard_path is not None:
with safe_open(ckpt_shard_path, framework="pt") as f:
for k in f.keys():
tensors[k] = torch2jax(f.get_tensor(k))
flax_state_dict = {}
cpu = jax.local_devices(backend="cpu")[0]
# turn all block numbers to strings just for matching weights.
# Later they will be turned back to ints.
random_flax_state_dict = _build_random_flax_state_dict(eval_shapes)
for pt_key, tensor in tensors.items():
# The diffusers implementation explicitly describes this key in keys to be ignored.
if "norm_added_q" in pt_key:
continue
renamed_pt_key = rename_key(pt_key)
renamed_pt_key = _rename_common_wan_transformer_key(renamed_pt_key)
pt_tuple_key = tuple(renamed_pt_key.split("."))
flax_key, flax_tensor = get_key_and_value(
pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers
)
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
# rename_key_and_reshape_tensor only reindexes/transposes views; the
# single real copy happens on assignment into the target buffer below.
flax_key, flax_tensor = rename_key_and_reshape_tensor(
pt_tuple_key, tensor, random_flax_state_dict, scan_layers
)
flax_key = rename_for_nnx(flax_key)
flax_key = _tuple_str_to_int(flax_key)

validate_flax_state_dict(eval_shapes, flax_state_dict)
flax_state_dict = unflatten_dict(flax_state_dict)
del tensors
jax.clear_caches()
return flax_state_dict
if block_index is not None:
with dict_lock:
stacked = flax_state_dict.get(flax_key)
if stacked is None:
stacked_dtype = cast_dtype_fn(flax_key) if cast_dtype_fn else flax_tensor.dtype
stacked = np.empty((num_layers,) + flax_tensor.shape, dtype=stacked_dtype)
flax_state_dict[flax_key] = stacked
# Rows are disjoint per block, so concurrent writes need no lock.
# This assignment fuses transpose + dtype cast (RTNE, matching XLA
# convert semantics) into one pass.
stacked[block_index] = flax_tensor
else:
target_dtype = cast_dtype_fn(flax_key) if cast_dtype_fn else flax_tensor.dtype
# Copy (never keep a view) so nothing references the shard mmap.
value = np.array(flax_tensor, dtype=target_dtype, copy=True, order="C")
with dict_lock:
flax_state_dict[flax_key] = value

# Chunk keys per shard so conversion parallelizes across tensors, not just
# across the ~12 shard files. norm_added_q is explicitly ignored by the
# diffusers implementation.
chunk_size = 16
tasks = []
for model_file in model_files:
ckpt_shard_path = resolve_shard_path(model_file)
with safe_open(ckpt_shard_path, framework="pt") as f:
shard_keys = [k for k in f.keys() if "norm_added_q" not in k]
for i in range(0, len(shard_keys), chunk_size):
tasks.append((ckpt_shard_path, shard_keys[i : i + chunk_size]))
max_logging.log(
f"Load and port {pretrained_model_name_or_path} {subfolder}: {len(model_files)} shards, {len(tasks)} chunks"
)
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
futures = [executor.submit(convert_chunk, path, keys) for path, keys in tasks]
for future in concurrent.futures.as_completed(futures):
future.result() # re-raise conversion errors

validate_flax_state_dict(eval_shapes, flax_state_dict)
flax_state_dict = unflatten_dict(flax_state_dict)
max_logging.log(
f"Converted {subfolder or 'transformer'} weights to host arrays in {time.perf_counter() - t_start:.1f}s"
)
return flax_state_dict


def _is_motion_encoder_custom_weight(pt_key: str) -> bool:
Expand Down
Loading
Loading