diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index a30c4d633..ac2e0aa5c 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -89,9 +89,11 @@ def get_git_commit_hash(): jax.config.update("jax_use_shardy_partitioner", True) -def call_pipeline(config, pipeline, prompt, negative_prompt): +def call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=None): model_key = config.model_name model_type = config.model_type + if num_inference_steps is None: + num_inference_steps = config.num_inference_steps if model_type == "I2V": image = load_image(config.image_url) if model_key == WAN2_1: @@ -102,7 +104,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): height=config.height, width=config.width, num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, + num_inference_steps=num_inference_steps, guidance_scale=config.guidance_scale, use_magcache=config.use_magcache, magcache_thresh=config.magcache_thresh, @@ -118,7 +120,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): height=config.height, width=config.width, num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, + num_inference_steps=num_inference_steps, guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, use_cfg_cache=config.use_cfg_cache, @@ -135,7 +137,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): height=config.height, width=config.width, num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, + num_inference_steps=num_inference_steps, guidance_scale=config.guidance_scale, use_cfg_cache=config.use_cfg_cache, use_magcache=config.use_magcache, @@ -151,7 +153,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): height=config.height, width=config.width, num_frames=config.num_frames, - num_inference_steps=config.num_inference_steps, + num_inference_steps=num_inference_steps, guidance_scale_low=config.guidance_scale_low, guidance_scale_high=config.guidance_scale_high, use_cfg_cache=config.use_cfg_cache, @@ -282,9 +284,20 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log( f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" ) - videos = call_pipeline(config, pipeline, prompt, negative_prompt) + # Warmup with 2 denoising steps instead of a full run: step 0 runs the + # high-noise transformer and step 1 crosses the boundary to the low-noise + # one (WAN 2.2), so every executable of the full run (both transformers, + # text encoder, VAE decode) gets compiled at a fraction of the cost. The + # step count only changes the Python loop trip count, not traced shapes. + warmup_steps = min(2, config.num_inference_steps) + max_logging.log(f"Compile warmup: {warmup_steps} denoising steps") + videos = call_pipeline(config, pipeline, prompt, negative_prompt, num_inference_steps=warmup_steps) if isinstance(videos, tuple): - videos = videos[0] + videos, warmup_trace = videos + max_logging.log( + "Warmup breakdown: " + + ", ".join(f"{stage}={seconds:.1f}s" for stage, seconds in warmup_trace.items()) + ) max_logging.log("===================== Model details =======================") max_logging.log(f"model name: {config.model_name}") @@ -299,13 +312,6 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log(f"compile_time: {compile_time}") if writer and jax.process_index() == 0: writer.add_scalar("inference/compile_time", compile_time, global_step=0) - saved_video_path = [] - for i in range(len(videos)): - video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" - export_to_video(videos[i], video_path, fps=config.fps) - saved_video_path.append(video_path) - if config.output_dir.startswith("gs://"): - upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path) s0 = time.perf_counter() outputs = call_pipeline(config, pipeline, prompt, negative_prompt) @@ -315,6 +321,13 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): videos = outputs trace = {} generation_time = time.perf_counter() - s0 + saved_video_path = [] + for i in range(len(videos)): + video_path = f"{filename_prefix}wan_output_{config.seed}_{i}.mp4" + export_to_video(videos[i], video_path, fps=config.fps) + saved_video_path.append(video_path) + if config.output_dir.startswith("gs://"): + upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path) max_logging.log(f"generation_time: {generation_time}") if writer and jax.process_index() == 0: writer.add_scalar("inference/generation_time", generation_time, global_step=0) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 8182dee05..2904f15dc 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -14,8 +14,15 @@ limitations under the License. """ -import os +import concurrent.futures import json +import os +import threading +import time +from typing import Callable, Optional + +import ml_dtypes +import numpy as np import torch import jax import jax.numpy as jnp @@ -28,6 +35,11 @@ CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid" WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX" +# WAN 2.2 transformer and transformer_2 have byte-identical index.json files, +# i.e. ONE blob in the HF hub cache. hf_hub revalidates and rewrites cached +# blobs, so parallel transformer loads must not resolve metadata concurrently. +_HF_METADATA_LOCK = threading.Lock() + def _tuple_str_to_int(in_tuple): out_list = [] @@ -273,6 +285,7 @@ def load_wan_transformer( num_layers: int = 40, scan_layers: bool = True, subfolder: str = "", + cast_dtype_fn: Optional[Callable] = None, ): if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH: return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers) @@ -280,10 +293,21 @@ def load_wan_transformer( return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers) else: return load_base_wan_transformer( - pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder + pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder, cast_dtype_fn ) +def _torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: + """Converts a CPU torch tensor to numpy without copying or upcasting. + + bfloat16 has no native numpy dtype, so it is reinterpreted through uint16 + into ml_dtypes.bfloat16 (bit-identical, zero-copy). + """ + if tensor.dtype == torch.bfloat16: + return tensor.view(torch.uint16).numpy().view(ml_dtypes.bfloat16) + return tensor.numpy() + + def load_base_wan_transformer( pretrained_model_name_or_path: str, eval_shapes: dict, @@ -292,8 +316,25 @@ def load_base_wan_transformer( num_layers: int = 40, scan_layers: bool = True, subfolder: str = "", + cast_dtype_fn: Optional[Callable] = None, ): - device = jax.local_devices(backend=device)[0] + """Loads WAN transformer weights from diffusers safetensors shards. + + Fast path compared to the historical implementation: + - tensors are read zero-copy from the safetensors mmap (no bf16->f32 + round trip through torch.float()), + - scanned block weights are written in place into one preallocated + (num_layers, ...) numpy buffer per param (the old jnp + ``at[block].set`` rebuilt the full stacked array once per layer, + i.e. O(num_layers^2) copies), + - the optional ``cast_dtype_fn(flax_key) -> np.dtype`` casts each param + to its final dtype during this single copy, so no later full-tree + cast pass is needed, + - shard files are converted in parallel threads (numpy copies release + the GIL). + Returns a nested dict of numpy arrays (host memory). + """ + del device # weights stay in plain host numpy until device_put by the caller filename = "diffusion_pytorch_model.safetensors.index.json" local_files = False if os.path.isdir(pretrained_model_name_or_path): @@ -303,55 +344,94 @@ def load_base_wan_transformer( local_files = True elif hf_download: # download the index file for sharded models. - index_file_path = hf_hub_download( - pretrained_model_name_or_path, - subfolder=subfolder, - filename=filename, - ) - with jax.default_device(device): - # open the index file. - with open(index_file_path, "r") as f: - index_dict = json.load(f) - model_files = set() - for key in index_dict["weight_map"].keys(): - model_files.add(index_dict["weight_map"][key]) + with _HF_METADATA_LOCK: + index_file_path = hf_hub_download( + pretrained_model_name_or_path, + subfolder=subfolder, + filename=filename, + ) + t_start = time.perf_counter() + with open(index_file_path, "r") as f: + index_dict = json.load(f) + model_files = sorted(set(index_dict["weight_map"].values())) + + # turn all block numbers to strings just for matching weights. + # Later they will be turned back to ints. + random_flax_state_dict = _build_random_flax_state_dict(eval_shapes) + flax_state_dict = {} + dict_lock = threading.Lock() + + def resolve_shard_path(model_file): + if local_files: + return os.path.join(pretrained_model_name_or_path, subfolder, model_file) + return hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) + + def convert_chunk(ckpt_shard_path, chunk_keys): + # Each task opens its own handle: safetensors mmap open is cheap and + # per-thread handles avoid serializing get_tensor calls. + with safe_open(ckpt_shard_path, framework="pt") as f: + for pt_key in chunk_keys: + tensor = _torch_tensor_to_numpy(f.get_tensor(pt_key)) + renamed_pt_key = rename_key(pt_key) + renamed_pt_key = _rename_common_wan_transformer_key(renamed_pt_key) + pt_tuple_key = tuple(renamed_pt_key.split(".")) - model_files = list(model_files) - tensors = {} - for model_file in model_files: - if local_files: - ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file) - else: - ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) - # now get all the filenames for the model that need downloading - max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}") + block_index = None + if scan_layers and len(pt_tuple_key) >= 2 and pt_tuple_key[0] == "blocks": + block_index = int(pt_tuple_key[1]) + pt_tuple_key = ("blocks",) + pt_tuple_key[2:] - if ckpt_shard_path is not None: - with safe_open(ckpt_shard_path, framework="pt") as f: - for k in f.keys(): - tensors[k] = torch2jax(f.get_tensor(k)) - flax_state_dict = {} - cpu = jax.local_devices(backend="cpu")[0] - # turn all block numbers to strings just for matching weights. - # Later they will be turned back to ints. - random_flax_state_dict = _build_random_flax_state_dict(eval_shapes) - for pt_key, tensor in tensors.items(): - # The diffusers implementation explicitly describes this key in keys to be ignored. - if "norm_added_q" in pt_key: - continue - renamed_pt_key = rename_key(pt_key) - renamed_pt_key = _rename_common_wan_transformer_key(renamed_pt_key) - pt_tuple_key = tuple(renamed_pt_key.split(".")) - flax_key, flax_tensor = get_key_and_value( - pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers - ) - flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + # rename_key_and_reshape_tensor only reindexes/transposes views; the + # single real copy happens on assignment into the target buffer below. + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, tensor, random_flax_state_dict, scan_layers + ) + flax_key = rename_for_nnx(flax_key) + flax_key = _tuple_str_to_int(flax_key) - validate_flax_state_dict(eval_shapes, flax_state_dict) - flax_state_dict = unflatten_dict(flax_state_dict) - del tensors - jax.clear_caches() - return flax_state_dict + if block_index is not None: + with dict_lock: + stacked = flax_state_dict.get(flax_key) + if stacked is None: + stacked_dtype = cast_dtype_fn(flax_key) if cast_dtype_fn else flax_tensor.dtype + stacked = np.empty((num_layers,) + flax_tensor.shape, dtype=stacked_dtype) + flax_state_dict[flax_key] = stacked + # Rows are disjoint per block, so concurrent writes need no lock. + # This assignment fuses transpose + dtype cast (RTNE, matching XLA + # convert semantics) into one pass. + stacked[block_index] = flax_tensor + else: + target_dtype = cast_dtype_fn(flax_key) if cast_dtype_fn else flax_tensor.dtype + # Copy (never keep a view) so nothing references the shard mmap. + value = np.array(flax_tensor, dtype=target_dtype, copy=True, order="C") + with dict_lock: + flax_state_dict[flax_key] = value + + # Chunk keys per shard so conversion parallelizes across tensors, not just + # across the ~12 shard files. norm_added_q is explicitly ignored by the + # diffusers implementation. + chunk_size = 16 + tasks = [] + for model_file in model_files: + ckpt_shard_path = resolve_shard_path(model_file) + with safe_open(ckpt_shard_path, framework="pt") as f: + shard_keys = [k for k in f.keys() if "norm_added_q" not in k] + for i in range(0, len(shard_keys), chunk_size): + tasks.append((ckpt_shard_path, shard_keys[i : i + chunk_size])) + max_logging.log( + f"Load and port {pretrained_model_name_or_path} {subfolder}: {len(model_files)} shards, {len(tasks)} chunks" + ) + with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: + futures = [executor.submit(convert_chunk, path, keys) for path, keys in tasks] + for future in concurrent.futures.as_completed(futures): + future.result() # re-raise conversion errors + + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + max_logging.log( + f"Converted {subfolder or 'transformer'} weights to host arrays in {time.perf_counter() - t_start:.1f}s" + ) + return flax_state_dict def _is_motion_encoder_custom_weight(pt_key: str) -> bool: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 8b0493ed3..d56171ff6 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -20,6 +20,7 @@ import math import jax import jax.numpy as jnp +import threading import time from jax.sharding import Mesh, NamedSharding, PartitionSpec as P import flax @@ -61,25 +62,50 @@ } +# The two WAN 2.2 transformers share identical config.json contents, i.e. +# ONE blob file in the HF hub cache. hf_hub revalidates and rewrites cached +# blobs, so concurrent load_config calls from the parallel transformer loads +# can read a half-written file. Serialize metadata resolution. +_HF_METADATA_LOCK = threading.Lock() + +# Params whose path matches any of these keywords are kept in float32 by +# cast_with_exclusion / _final_param_dtype regardless of weights_dtype. +_CAST_EXCLUSION_KEYWORDS = ( + "norm", # For all LayerNorm/GroupNorm layers + "condition_embedder", # The entire time/text conditioning module + "scale_shift_table", # Catches both the final and the AdaLN tables +) + + +def _is_cast_excluded(path_str: str) -> bool: + return any(keyword in path_str.lower() for keyword in _CAST_EXCLUSION_KEYWORDS) + + +def _final_param_dtype(flax_key: tuple, dtype_to_cast) -> np.dtype: + """Final dtype for a param addressed by a flat key tuple (loader-side twin + of cast_with_exclusion, so weights are cast once at read time).""" + path_str = ".".join(str(k) for k in flax_key) + if _is_cast_excluded(path_str): + return np.dtype(jnp.float32) + return np.dtype(dtype_to_cast) + + def cast_with_exclusion(path, x, dtype_to_cast): """ Casts arrays to dtype_to_cast, but keeps params from any 'norm' layer in float32. """ - - exclusion_keywords = [ - "norm", # For all LayerNorm/GroupNorm layers - "condition_embedder", # The entire time/text conditioning module - "scale_shift_table", # Catches both the final and the AdaLN tables - ] - path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path) - if any(keyword in path_str.lower() for keyword in exclusion_keywords): + if _is_cast_excluded(path_str): # Keep LayerNorm/GroupNorm weights and biases in full precision - return x.astype(jnp.float32) + target_dtype = jnp.float32 else: # Cast everything else to dtype_to_cast - return x.astype(dtype_to_cast) + target_dtype = dtype_to_cast + if x.dtype == np.dtype(target_dtype): + # Already final (e.g. cast during weight loading) - avoid a full copy. + return x + return x.astype(target_dtype) def basic_clean(text): @@ -148,7 +174,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): if restored_checkpoint: wan_config = restored_checkpoint["wan_config"] else: - wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) + with _HF_METADATA_LOCK: + wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) if config.model_type == "I2V": # WAN 2.1 I2V uses image embeddings via CLIP encoder (image_dim and added_kv_proj_dim are set) # WAN 2.2 I2V uses VAE-encoded latent conditioning (image_dim and added_kv_proj_dim are None in the transformer config) @@ -206,12 +233,17 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): num_layers=wan_config["num_layers"], scan_layers=config.scan_layers, subfolder=subfolder, + cast_dtype_fn=partial(_final_param_dtype, dtype_to_cast=config.weights_dtype), ) + # No-op (returns leaves unchanged) when the loader already cast to the + # final dtypes; still needed for restored orbax checkpoints. params = jax.tree_util.tree_map_with_path( lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params, ) + t_put_start = time.perf_counter() + put_specs = [] for path, val in flax.traverse_util.flatten_dict(params).items(): if restored_checkpoint: if path[-1] == "value": @@ -223,15 +255,61 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): except Exception: pass - sharding = logical_state_sharding[path].value - try: - state[path].value = device_put_replicated(val, sharding) - except Exception as e: - max_logging.log(f"Failed to device_put_replicated {path}: {e}") - max_logging.log(f"Trying to use process_allgather for {path}") - val_on_host = jax.experimental.multihost_utils.process_allgather(val, tiled=True) - state[path].value = device_put_replicated(val_on_host, sharding) - del val_on_host + put_specs.append((path, val, logical_state_sharding[path].value)) + + if jax.process_count() == 1: + # Replicated params are the bulk of the bytes; a direct device_put + # broadcasts the same bytes over every device's PCIe stream (~2GB/s + # each). Instead, stage them sharded along dim0 (each device receives + # only 1/n of the bytes over PCIe) and replicate on-device through ICI, + # which is an order of magnitude faster than host links. + n_devices = mesh.devices.size + dim0_sharding = NamedSharding(mesh, P(mesh.axis_names)) + + def stages_via_ici(val, sharding) -> bool: + return ( + sharding.is_fully_replicated + and val.ndim > 0 + and val.shape[0] % n_devices == 0 + and val.nbytes >= 1 << 26 # 64MB: below this, staging overhead wins + ) + + staged_ids = [i for i, (_, val, sharding) in enumerate(put_specs) if stages_via_ici(val, sharding)] + direct_ids = [i for i in range(len(put_specs)) if i not in set(staged_ids)] + + put_arrays = [None] * len(put_specs) + if staged_ids: + staged = jax.device_put( + [put_specs[i][1] for i in staged_ids], [dim0_sharding] * len(staged_ids) + ) + # out_shardings must be the exact target sharding objects (not an + # equivalent P()): downstream jit cache keys include arg shardings, so + # a different-but-equivalent spec would force a full recompile. + replicate_fn = jax.jit( + lambda xs: xs, out_shardings=[put_specs[i][2] for i in staged_ids] + ) + for i, replicated in zip(staged_ids, replicate_fn(staged)): + put_arrays[i] = replicated + if direct_ids: + for i, put_array in zip( + direct_ids, + jax.device_put([put_specs[i][1] for i in direct_ids], [put_specs[i][2] for i in direct_ids]), + ): + put_arrays[i] = put_array + for (path, _, _), put_array in zip(put_specs, put_arrays): + state[path].value = put_array + else: + for path, val, sharding in put_specs: + try: + state[path].value = device_put_replicated(val, sharding) + except Exception as e: + max_logging.log(f"Failed to device_put_replicated {path}: {e}") + max_logging.log(f"Trying to use process_allgather for {path}") + val_on_host = jax.experimental.multihost_utils.process_allgather(val, tiled=True) + state[path].value = device_put_replicated(val_on_host, sharding) + del val_on_host + jax.block_until_ready([state[path].value for path, _, _ in put_specs]) + max_logging.log(f"Transformer {subfolder or 'transformer'} weights on device in {time.perf_counter() - t_put_start:.1f}s") state = nnx.from_flat_state(state) wan_transformer = nnx.merge(graphdef, state, rest_of_state) @@ -794,6 +872,8 @@ def _create_common_components( max_logging.log("Loading Tokenizer and Text Encoder") components["tokenizer"] = cls.load_tokenizer(config=config) components["text_encoder"] = cls.load_text_encoder(config=config) + if getattr(config, "compile_text_encoder", False): + cls._warm_text_encoder(config, components["tokenizer"], components["text_encoder"]) if cls._needs_image_encoder(config, i2v=i2v): ( components["image_processor"], @@ -805,6 +885,37 @@ def _create_common_components( return components + @classmethod + def _warm_text_encoder(cls, config, tokenizer, text_encoder) -> None: + """Runs one dummy forward through the torch.compile'd text encoder. + + torch.compile pays its (~30s CPU) inductor compilation on the first + call; doing it here means it happens during weight loading (hidden + behind the transformer conversion when loading runs in a background + thread) instead of inside the first pipeline call. The dummy batch + matches the shapes encode_prompt will use, so no recompilation later. + """ + t_start = time.perf_counter() + batch_size = int(getattr(config, "global_batch_size_to_train_on", 1)) + if getattr(config, "use_batched_text_encoder", False): + # encode_prompt batches prompt + negative prompt into one call. + batch_size *= 2 + dummy_inputs = tokenizer( + [""] * batch_size, + padding="max_length", + max_length=getattr(config, "max_sequence_length", 512), + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + # Deliberately NOT under torch.no_grad(): grad mode is a dynamo guard, + # and the pipeline's encode call runs with grad enabled. The warmup must + # compile the exact same graph (also keeps numerics identical to the + # historical encode path). + text_encoder(dummy_inputs.input_ids, dummy_inputs.attention_mask) + max_logging.log(f"Text encoder compile warmup in {time.perf_counter() - t_start:.1f}s") + @classmethod @abstractmethod def _load_and_init( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 00d11f961..4124cf74e 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -16,12 +16,14 @@ from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional from ...pyconfig import HyperParameters +import concurrent.futures from functools import partial import time from flax import nnx from flax.linen import partitioning as nn_partitioning import jax import jax.numpy as jnp +from jax.sharding import Mesh import numpy as np from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler from ... import max_utils @@ -55,30 +57,53 @@ def _load_and_init( load_transformer=True, load_scheduler=True, ): - common_components = cls._create_common_components( + # Load VAE/tokenizer/text-encoder/scheduler in a background thread while + # the main thread converts the two 14B transformers: the small components + # are fully hidden behind the transformer conversion time. The mesh/rngs + # built here for the transformers are deterministic duplicates of the + # ones _create_common_components builds (same devices, same seed). + common_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + common_future = common_executor.submit( + cls._create_common_components, config, load_vae=load_vae, load_text_encoder=load_text_encoder, load_scheduler=load_scheduler, ) low_noise_transformer, high_noise_transformer = None, None - if load_transformer: - low_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer_2", - ) - high_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer", - ) + transformer_executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + try: + if load_transformer: + devices_array = max_utils.create_device_mesh(config) + transformer_mesh = Mesh(devices_array, config.mesh_axes) + rngs = nnx.Rngs(jax.random.key(config.seed)) + load_transformer_fn = super().load_transformer + # The two 14B transformers load concurrently: host-side conversion + # shares CPU cores while their device transfers interleave on PCIe. + low_future = transformer_executor.submit( + load_transformer_fn, + devices_array=devices_array, + mesh=transformer_mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2", + ) + high_future = transformer_executor.submit( + load_transformer_fn, + devices_array=devices_array, + mesh=transformer_mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) + low_noise_transformer = low_future.result() + high_noise_transformer = high_future.result() + common_components = common_future.result() + finally: + transformer_executor.shutdown(wait=True) + common_executor.shutdown(wait=True) pipeline = cls( tokenizer=common_components["tokenizer"], diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index f071c231f..503f8f78b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -18,6 +18,7 @@ from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional, Tuple from ...pyconfig import HyperParameters +import concurrent.futures from functools import partial from flax import nnx from flax.linen import partitioning as nn_partitioning @@ -25,7 +26,7 @@ import jax.numpy as jnp import numpy as np import time -from jax.sharding import NamedSharding, PartitionSpec as P +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler from ... import max_utils @@ -57,7 +58,12 @@ def _load_and_init( load_transformer=True, load_scheduler=True, ): - common_components = cls._create_common_components( + # Same overlap as WanPipeline2_2._load_and_init: VAE/text-encoder/ + # scheduler load in a background thread while the main thread converts + # the two 14B transformers (mesh/rngs are deterministic duplicates). + common_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + common_future = common_executor.submit( + cls._create_common_components, config, load_vae=load_vae, load_text_encoder=load_text_encoder, @@ -65,23 +71,39 @@ def _load_and_init( i2v=True, ) low_noise_transformer, high_noise_transformer = None, None - if load_transformer: - high_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer", - ) - low_noise_transformer = super().load_transformer( - devices_array=common_components["devices_array"], - mesh=common_components["mesh"], - rngs=common_components["rngs"], - config=config, - restored_checkpoint=restored_checkpoint, - subfolder="transformer_2", - ) + transformer_executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + try: + if load_transformer: + devices_array = max_utils.create_device_mesh(config) + transformer_mesh = Mesh(devices_array, config.mesh_axes) + rngs = nnx.Rngs(jax.random.key(config.seed)) + load_transformer_fn = super().load_transformer + # The two 14B transformers load concurrently: host-side conversion + # shares CPU cores while their device transfers interleave on PCIe. + high_future = transformer_executor.submit( + load_transformer_fn, + devices_array=devices_array, + mesh=transformer_mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer", + ) + low_future = transformer_executor.submit( + load_transformer_fn, + devices_array=devices_array, + mesh=transformer_mesh, + rngs=rngs, + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2", + ) + high_noise_transformer = high_future.result() + low_noise_transformer = low_future.result() + common_components = common_future.result() + finally: + transformer_executor.shutdown(wait=True) + common_executor.shutdown(wait=True) pipeline = cls( tokenizer=common_components["tokenizer"],