diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index ac7196735ad8..3866342d9be6 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -144,6 +144,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi ## KandinskyLoraLoaderMixin [[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin +## Ideogram4LoraLoaderMixin + +[[autodoc]] loaders.lora_pipeline.Ideogram4LoraLoaderMixin + ## LoraBaseMixin [[autodoc]] loaders.lora_base.LoraBaseMixin diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 33eeba673a98..2eb1f5cc7a44 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -86,6 +86,7 @@ def text_encoder_attn_modules(text_encoder): "QwenImageLoraLoaderMixin", "ZImageLoraLoaderMixin", "Flux2LoraLoaderMixin", + "Ideogram4LoraLoaderMixin", "ErnieImageLoraLoaderMixin", "CosmosLoraLoaderMixin", ] @@ -128,6 +129,7 @@ def text_encoder_attn_modules(text_encoder): HeliosLoraLoaderMixin, HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, + Ideogram4LoraLoaderMixin, KandinskyLoraLoaderMixin, LoraLoaderMixin, LTX2LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index bf516abc825f..a29d74024c18 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2883,3 +2883,88 @@ def get_alpha_scales(down_weight, alpha_key): converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()} return converted_state_dict + + +def _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict): + """ + Convert non-diffusers Ideogram4 LoRA state dict to diffusers format. + + Handles: + - `diffusion_model.` / `conditional_transformer.` prefix removal + - `lora_down`/`lora_up` (kohya) -> `lora_A`/`lora_B`, with `.alpha` folded into the weights + - fused `attention.qkv` -> split `to_q`/`to_k`/`to_v`; `attention.o` -> `to_out.0` + - `feed_forward.w1`/`w2`/`w3` and `adaln_modulation` map one-to-one + """ + for prefix in ("diffusion_model.", "conditional_transformer."): + if any(k.startswith(prefix) for k in state_dict): + state_dict = {k.removeprefix(prefix): v for k, v in state_dict.items()} + break + + is_kohya = any(".lora_down.weight" in k for k in state_dict) + down_suffix = ".lora_down.weight" if is_kohya else ".lora_A.weight" + up_suffix = ".lora_up.weight" if is_kohya else ".lora_B.weight" + + def get_alpha_scales(down_weight, alpha_key): + rank = down_weight.shape[0] + alpha_tensor = state_dict.pop(alpha_key, None) + if alpha_tensor is None: + return 1.0, 1.0 + # LoRA is scaled by `alpha / rank` in the forward pass; split the factor between down and up. + scale = alpha_tensor.item() / rank + scale_down, scale_up = scale, 1.0 + while scale_down * 2 < scale_up: + scale_down *= 2 + scale_up /= 2 + return scale_down, scale_up + + def pull(base): + """Pop the scaled (lora_A, lora_B) pair for a module path, or return None if absent.""" + down_key = base + down_suffix + if down_key not in state_dict: + return None + down = state_dict.pop(down_key) + up = state_dict.pop(base + up_suffix) + scale_down, scale_up = get_alpha_scales(down, base + ".alpha") + return down * scale_down, up * scale_up + + num_layers = 0 + for k in state_dict: + match = re.match(r"layers\.(\d+)\.", k) + if match: + num_layers = max(num_layers, int(match.group(1)) + 1) + + converted_state_dict = {} + for i in range(num_layers): + layer_prefix = f"layers.{i}" + + # Fused qkv -> split to_q / to_k / to_v (shared down/lora_A, chunk up/lora_B in thirds). + qkv = pull(f"{layer_prefix}.attention.qkv") + if qkv is not None: + down, up = qkv + up_q, up_k, up_v = torch.chunk(up, 3, dim=0) + for proj, up_proj in (("to_q", up_q), ("to_k", up_k), ("to_v", up_v)): + converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_A.weight"] = down.clone() + converted_state_dict[f"{layer_prefix}.attention.{proj}.lora_B.weight"] = up_proj.contiguous() + + # attention.o -> attention.to_out.0 + out = pull(f"{layer_prefix}.attention.o") + if out is not None: + down, up = out + converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_A.weight"] = down + converted_state_dict[f"{layer_prefix}.attention.to_out.0.lora_B.weight"] = up + + # feed_forward.{w1,w2,w3} and adaln_modulation map one-to-one. + for module in ("feed_forward.w1", "feed_forward.w2", "feed_forward.w3", "adaln_modulation"): + pair = pull(f"{layer_prefix}.{module}") + if pair is not None: + down, up = pair + converted_state_dict[f"{layer_prefix}.{module}.lora_A.weight"] = down + converted_state_dict[f"{layer_prefix}.{module}.lora_B.weight"] = up + + if len(state_dict) > 0: + raise ValueError( + f"`state_dict` should be empty at this point but has {sorted(state_dict.keys())}. " + "This may be an unsupported Ideogram4 LoRA layout." + ) + + return {f"transformer.{k}": v for k, v in converted_state_dict.items()} diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 52b2aad174be..0abeba91e983 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -49,6 +49,7 @@ _convert_non_diffusers_anima_lora_to_diffusers, _convert_non_diffusers_flux2_lora_to_diffusers, _convert_non_diffusers_hidream_lora_to_diffusers, + _convert_non_diffusers_ideogram4_lora_to_diffusers, _convert_non_diffusers_lora_to_diffusers, _convert_non_diffusers_ltx2_lora_to_diffusers, _convert_non_diffusers_ltxv_lora_to_diffusers, @@ -6018,6 +6019,213 @@ def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class Ideogram4LoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Ideogram4Transformer2DModel`]. Specific to [`Ideogram4Pipeline`]. + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details. + """ + # Load the main state dict first which has the LoRA layers for either of + # transformer and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + # ai-toolkit (ostris) saves Ideogram4 LoRAs under a `diffusion_model.` prefix with a fused + # `attention.qkv` projection; convert those to the diffusers layout before loading. + is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict) or any( + ".attention.qkv." in k for k in state_dict + ) + if is_non_diffusers_format: + state_dict = _convert_non_diffusers_ideogram4_lora_to_diffusers(state_dict) + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor], + adapter_name: str | None = None, + hotswap: bool = False, + **kwargs, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.") + + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details. + """ + if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights + def save_lora_weights( + cls, + save_directory: str | os.PathLike, + transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata: dict | None = None, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora + def fuse_lora( + self, + components: list[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: list[str] | None = None, + **kwargs, + ): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details. + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora + def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs): + r""" + See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details. + """ + super().unfuse_lora(components=components, **kwargs) + + class ErnieImageLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`ErnieImageTransformer2DModel`]. Specific to [`ErnieImagePipeline`]. diff --git a/src/diffusers/models/transformers/transformer_ideogram4.py b/src/diffusers/models/transformers/transformer_ideogram4.py index 121118e3bd80..595873f06f75 100644 --- a/src/diffusers/models/transformers/transformer_ideogram4.py +++ b/src/diffusers/models/transformers/transformer_ideogram4.py @@ -21,7 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import logging +from ...utils import apply_lora_scale, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn @@ -365,6 +365,7 @@ def __init__( adaln_dim=adaln_dim, ) + @apply_lora_scale("attention_kwargs") def forward( self, hidden_states: torch.Tensor, @@ -373,6 +374,7 @@ def forward( position_ids: torch.Tensor, segment_ids: torch.Tensor, indicator: torch.Tensor, + attention_kwargs: dict | None = None, return_dict: bool = True, ) -> Transformer2DModelOutput | tuple[torch.Tensor]: r""" @@ -391,6 +393,9 @@ def forward( Per-token sample id within a packed batch. Positions sharing a `segment_id` attend to each other. indicator (`torch.Tensor` of shape `(batch_size, sequence_length)`): Per-token role: `LLM_TOKEN_INDICATOR` (text) or `OUTPUT_IMAGE_INDICATOR` (image). + attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed along to the attention processor. A `"scale"` entry scales the LoRA weights + (when the PEFT backend is active). return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.modeling_outputs.Transformer2DModelOutput`] instead of a plain tuple. diff --git a/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py b/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py index 61ba4fa43a62..f99bf1335b67 100644 --- a/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py +++ b/src/diffusers/pipelines/ideogram4/pipeline_ideogram4.py @@ -20,6 +20,7 @@ from transformers.masking_utils import create_causal_mask from ...image_processor import VaeImageProcessor +from ...loaders import Ideogram4LoraLoaderMixin from ...models.autoencoders import AutoencoderKLFlux2 from ...models.transformers.transformer_ideogram4 import ( IMAGE_POSITION_OFFSET, @@ -137,7 +138,7 @@ def _expand_tensor_to_effective_batch( return torch.repeat_interleave(tensor, repeats=repeat_by, dim=0, output_size=tensor.shape[0] * repeat_by) -class Ideogram4Pipeline(DiffusionPipeline): +class Ideogram4Pipeline(DiffusionPipeline, Ideogram4LoraLoaderMixin): r""" Text-to-image pipeline for Ideogram4. @@ -367,9 +368,14 @@ def encode_prompt( attention_mask[b, offset:] = 1 text_position_ids[b, offset:] = torch.arange(n) - token_ids = token_ids.to(device) - attention_mask = attention_mask.to(device) - text_position_ids = text_position_ids.to(device) + # Run the encoder on the device its parameters currently live on, then move the features to the + # pipeline device. encode_prompt calls the text encoder's submodules directly, so under + # enable_model_cpu_offload the onload hook never fires and the weights stay on CPU; honoring their + # actual device avoids a device mismatch on the token embedding. + te_device = self.text_encoder.device + token_ids = token_ids.to(te_device) + attention_mask = attention_mask.to(te_device) + text_position_ids = text_position_ids.to(te_device) # Concatenate the tapped activation-layer hidden states into per-token text features, zeroing padding. selected = self._get_text_encoder_hidden_states( @@ -377,6 +383,7 @@ def encode_prompt( ) text_features = torch.stack(selected, dim=0).permute(1, 2, 3, 0).reshape(batch_size, max_sequence_length, -1) text_features = (text_features * attention_mask.to(text_features.dtype).unsqueeze(-1)).to(torch.float32) + text_features = text_features.to(device) position_ids, segment_ids, indicator = self._prepare_ids( text_lengths, grid_h, grid_w, max_sequence_length, device @@ -417,6 +424,10 @@ def guidance_scale(self) -> float | None: def num_timesteps(self) -> int: return self._num_timesteps + @property + def attention_kwargs(self) -> dict[str, Any] | None: + return self._attention_kwargs + @property def interrupt(self) -> bool: return self._interrupt @@ -485,6 +496,7 @@ def __call__( latents: torch.Tensor | None = None, output_type: str = "pil", return_dict: bool = True, + attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[["Ideogram4Pipeline", int, int, dict[str, Any]], dict[str, Any]] | None = None, callback_on_step_end_tensor_inputs: list[str] = ["latents"], ) -> Ideogram4PipelineOutput | tuple[Any]: @@ -533,6 +545,9 @@ def __call__( One of `"pil"`, `"np"`, `"pt"`, or `"latent"`. return_dict (`bool`, *optional*, defaults to `True`): Whether to return an [`~pipelines.ideogram4.Ideogram4PipelineOutput`]. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed along to the attention processor of each transformer. A `"scale"` entry + scales the loaded LoRA weights (e.g. `{"scale": 0.7}`) when the PEFT backend is active. callback_on_step_end (`Callable`, *optional*): Callback invoked at the end of every denoising step. callback_on_step_end_tensor_inputs (`list[str]`, *optional*): @@ -560,6 +575,7 @@ def __call__( device = self._execution_device self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs self._interrupt = False # 0. Optionally rewrite the prompt(s) into Ideogram4's native structured JSON caption. @@ -669,6 +685,7 @@ def __call__( position_ids=position_ids, segment_ids=segment_ids, indicator=indicator, + attention_kwargs=self.attention_kwargs, return_dict=False, )[0] # Velocity (and guidance) is computed in float32 for scheduler precision; the transformers @@ -683,6 +700,7 @@ def __call__( position_ids=neg_position_ids, segment_ids=neg_segment_ids, indicator=neg_indicator, + attention_kwargs=self.attention_kwargs, return_dict=False, )[0].to(torch.float32) diff --git a/tests/lora/test_lora_layers_ideogram4.py b/tests/lora/test_lora_layers_ideogram4.py new file mode 100644 index 000000000000..e38b51eaf8e2 --- /dev/null +++ b/tests/lora/test_lora_layers_ideogram4.py @@ -0,0 +1,274 @@ +# coding=utf-8 +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import unittest + +import numpy as np +import torch +from transformers import Qwen2Tokenizer, Qwen3VLConfig, Qwen3VLModel + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + Ideogram4Pipeline, + Ideogram4Transformer2DModel, +) +from diffusers.pipelines.ideogram4.pipeline_ideogram4 import QWEN3_VL_ACTIVATION_LAYERS + +from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device + + +if is_peft_available(): + from peft import LoraConfig + + +sys.path.append(".") + +from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402 + + +# The text conditioning concatenates the hidden states of these Qwen3-VL decoder layers, so the dummy text +# encoder must be deep enough to expose the last tapped layer, and `llm_features_dim` must match the product. +_TEXT_HIDDEN_SIZE = 8 +_NUM_TEXT_LAYERS = max(QWEN3_VL_ACTIVATION_LAYERS) + 1 +_LLM_FEATURES_DIM = len(QWEN3_VL_ACTIVATION_LAYERS) * _TEXT_HIDDEN_SIZE + + +@require_peft_backend +class Ideogram4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): + pipeline_class = Ideogram4Pipeline + scheduler_cls = FlowMatchEulerDiscreteScheduler + scheduler_kwargs = {} + + transformer_kwargs = { + "in_channels": 16, + "num_layers": 2, + "attention_head_dim": 8, + "num_attention_heads": 4, + "intermediate_size": 32, + "adaln_dim": 16, + "llm_features_dim": _LLM_FEATURES_DIM, + "rope_theta": 10_000, + "mrope_section": (2, 1, 1), + "norm_eps": 1e-5, + } + transformer_cls = Ideogram4Transformer2DModel + + vae_kwargs = { + "in_channels": 3, + "out_channels": 3, + "down_block_types": ("DownEncoderBlock2D",), + "up_block_types": ("UpDecoderBlock2D",), + "block_out_channels": (8,), + "layers_per_block": 1, + "latent_channels": 4, + "norm_num_groups": 1, + "sample_size": 32, + "patch_size": (2, 2), + "use_quant_conv": False, + "use_post_quant_conv": False, + } + vae_cls = AutoencoderKLFlux2 + + tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + + # Ideogram4's attention uses split q/k/v/out projections in the diffusers transformer. + denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] + # The text encoder (Qwen3-VL) is frozen and not LoRA-adapted by the Ideogram4 loader. + supports_text_encoder_loras = False + + @property + def output_shape(self): + return (1, 16, 16, 3) + + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): + # The Ideogram4 pipeline takes a second (unconditional) transformer and a Qwen3-VL text encoder for + # which there is no tiny pretrained checkpoint, so build the components inline rather than relying on + # the base implementation. + scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls + rank = 4 + lora_alpha = rank if lora_alpha is None else lora_alpha + + torch.manual_seed(0) + transformer = self.transformer_cls(**self.transformer_kwargs) + unconditional_transformer = self.transformer_cls(**self.transformer_kwargs) + + torch.manual_seed(0) + vae = self.vae_cls(**self.vae_kwargs) + + torch.manual_seed(0) + text_config = { + "hidden_size": _TEXT_HIDDEN_SIZE, + "num_hidden_layers": _NUM_TEXT_LAYERS, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "intermediate_size": 16, + "head_dim": 8, + "vocab_size": 151936, + "max_position_embeddings": 256, + "rope_theta": 10_000.0, + } + vision_config = { + "hidden_size": 8, + "depth": 2, + "num_heads": 2, + "intermediate_size": 16, + "out_hidden_size": _TEXT_HIDDEN_SIZE, + "patch_size": 14, + } + text_encoder = Qwen3VLModel(Qwen3VLConfig(text_config=text_config, vision_config=vision_config)) + tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id) + + scheduler = scheduler_cls(**self.scheduler_kwargs) + + text_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + init_lora_weights=False, + use_dora=use_dora, + ) + denoiser_lora_config = LoraConfig( + r=rank, + lora_alpha=lora_alpha, + target_modules=self.denoiser_target_modules, + init_lora_weights=False, + use_dora=use_dora, + ) + + pipeline_components = { + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "unconditional_transformer": unconditional_transformer, + } + + return pipeline_components, text_lora_config, denoiser_lora_config + + def get_dummy_inputs(self, with_generator=True): + batch_size = 1 + sequence_length = 32 + num_channels = 4 + sizes = (16, 16) + + generator = torch.manual_seed(0) + noise = floats_tensor((batch_size, num_channels) + sizes) + input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) + + pipeline_inputs = { + "prompt": "a dog is dancing", + "num_inference_steps": 2, + "guidance_schedule": [1.0, 1.0], + "height": 16, + "width": 16, + "max_sequence_length": sequence_length, + "output_type": "np", + } + if with_generator: + pipeline_inputs.update({"generator": generator}) + + return noise, input_ids, pipeline_inputs + + # Overridden because the base test's rank-pattern module finder doesn't resolve a module on Ideogram4's + # attention naming; this mirrors the same override other DiT LoRA tests use (e.g. Z-Image). + def test_correct_lora_configs_with_different_ranks(self): + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + original_output = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + + lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + + pipe.transformer.delete_adapters("adapter-1") + + denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer + for name, _ in denoiser.named_modules(): + if "to_k" in name and "attention" in name and "lora" not in name: + module_name_to_rank_update = name.replace(".base_layer.", ".") + break + + # change the rank_pattern + updated_rank = denoiser_lora_config.r * 2 + denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank} + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern + + self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank}) + + lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + pipe.transformer.delete_adapters("adapter-1") + + # similarly change the alpha_pattern + updated_alpha = denoiser_lora_config.lora_alpha * 2 + denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha} + + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue( + pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha} + ) + + lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0] + self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3)) + self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3)) + + @unittest.skip("Not supported in Ideogram4.") + def test_simple_inference_with_text_denoiser_block_scale(self): + pass + + @unittest.skip("Not supported in Ideogram4.") + def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self): + pass + + @unittest.skip("Not supported in Ideogram4.") + def test_modify_padding_mode(self): + pass + + # Overridden because the base test probes for `transformer_blocks`/`blocks`/etc. to corrupt a weight, + # but Ideogram4's transformer tower is named `layers` (with `attention.to_q` projections). + def test_lora_fuse_nan(self): + components, _, denoiser_lora_config = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet + denoiser.add_adapter(denoiser_lora_config, "adapter-1") + self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") + + # corrupt one LoRA weight with `inf` values + with torch.no_grad(): + pipe.transformer.layers[0].attention.to_q.lora_A["adapter-1"].weight += float("inf") + + # with `safe_fusing=True` we should see an Error + with self.assertRaises(ValueError): + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True) + + # without we should not see an error, but every image will be black + pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False) + out = pipe(**inputs)[0] + + self.assertTrue(np.isnan(out).all())