diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 5b5579664b55..8b8fe971e156 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -46,7 +46,7 @@ set_adapter_layers, set_weights_and_activate_adapters, ) -from ..utils.peft_utils import _create_lora_config +from ..utils.peft_utils import _create_lora_config, _validate_lora_weight_compatibility from ..utils.state_dict_utils import _load_sft_state_dict_metadata @@ -393,6 +393,8 @@ def _load_lora_into_text_encoder( if adapter_name is None: adapter_name = get_adapter_name(text_encoder) + _validate_lora_weight_compatibility(text_encoder, state_dict, adapter_name=adapter_name) + # = max_mismatches: + break + + if mismatches: + details = "\n".join(f" - {key}: {reason}" for key, reason in mismatches) + extra = "" + if len(mismatches) >= max_mismatches: + extra = f"\n ... (showing first {max_mismatches} mismatches)" + adapter_msg = f" for adapter '{adapter_name}'" if adapter_name is not None else "" + raise ValueError( + "The loaded LoRA weights are incompatible with the current model dimensions" + f"{adapter_msg}. Please check the model architecture version " + "(for example, an SD 1.5 LoRA cannot be loaded into an SDXL pipeline).\n" + f"Found shape mismatch(es):\n{details}{extra}" + ) + + def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name): warn_msg = "" if incompatible_keys is not None: diff --git a/tests/lora/test_lora_layers_sdxl.py b/tests/lora/test_lora_layers_sdxl.py index ac1d65abdaa7..678162621e5a 100644 --- a/tests/lora/test_lora_layers_sdxl.py +++ b/tests/lora/test_lora_layers_sdxl.py @@ -21,6 +21,7 @@ import numpy as np import torch +import torch.nn as nn from packaging import version from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer @@ -55,6 +56,34 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set, state_dicts_almost_equal # noqa: E402 +def _build_sd15_incompatible_lora_state_dict(unet, rank=32): + """ + Build a minimal LoRA state dict that mimics SD 1.5 checkpoint shapes when loaded into an SDXL UNet. + + Includes: + - 4D Conv-style LoRA tensors targeting a Linear `proj_in` layer (ndim mismatch). + - 768-dim cross-attention LoRA tensors targeting SDXL's `attn2.to_k` (feature dim mismatch). + """ + state_dict = {} + + for name, module in unet.named_modules(): + if name.endswith("proj_in") and isinstance(module, nn.Linear): + state_dict[f"unet.{name}.lora_A.weight"] = torch.zeros(rank, 640, 1, 1) + state_dict[f"unet.{name}.lora_B.weight"] = torch.zeros(module.out_features, rank, 1, 1) + break + + for name, module in unet.named_modules(): + if name.endswith("attn2.to_k") and isinstance(module, nn.Linear): + state_dict[f"unet.{name}.lora_A.weight"] = torch.zeros(rank, 768) + state_dict[f"unet.{name}.lora_B.weight"] = torch.zeros(module.out_features, rank) + break + + if len(state_dict) < 4: + raise RuntimeError("Could not find SDXL UNet modules needed for incompatible LoRA regression test.") + + return state_dict + + if is_accelerate_available(): from accelerate.utils import release_memory @@ -152,6 +181,34 @@ def test_lora_scale_kwargs_match_fusion(self): super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol) + def test_load_sd15_lora_into_sdxl_raises_incompatibility_error(self): + """ + Regression test for https://github.com/huggingface/diffusers/issues/11286 + + Loading an SD 1.5 LoRA into an SDXL pipeline must fail early with a clear ValueError before PEFT injects + adapter shells into the UNet. + """ + components, _, _ = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + incompatible_lora_state_dict = _build_sd15_incompatible_lora_state_dict(pipe.unet) + + with self.assertRaises(ValueError) as err_context: + pipe.load_lora_weights(incompatible_lora_state_dict) + + error_message = str(err_context.exception) + self.assertIn( + "The loaded LoRA weights are incompatible with the current model dimensions", + error_message, + ) + self.assertIn("Please check the model architecture version", error_message) + self.assertIn("Found shape mismatch(es):", error_message) + self.assertIn("proj_in.lora_A.weight", error_message) + self.assertIn("attn2.to_k.lora_A.weight", error_message) + + self.assertFalse(check_if_lora_correctly_set(pipe.unet)) + self.assertFalse(getattr(pipe.unet, "_hf_peft_config_loaded", False)) + @slow @nightly