Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/diffusers/loaders/lora_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.peft_utils import _create_lora_config
from ..utils.peft_utils import _create_lora_config, _validate_lora_weight_compatibility
from ..utils.state_dict_utils import _load_sft_state_dict_metadata


Expand Down Expand Up @@ -393,6 +393,8 @@ def _load_lora_into_text_encoder(
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)

_validate_lora_weight_compatibility(text_encoder, state_dict, adapter_name=adapter_name)

# <Unsafe code
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
_pipeline
Expand Down
8 changes: 7 additions & 1 deletion src/diffusers/loaders/peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.peft_utils import _create_lora_config, _maybe_warn_for_unhandled_keys
from ..utils.peft_utils import (
_create_lora_config,
_maybe_warn_for_unhandled_keys,
_validate_lora_weight_compatibility,
)
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
from .unet_loader_utils import _maybe_expand_lora_scales

Expand Down Expand Up @@ -264,6 +268,8 @@ def load_lora_adapter(
lora_config.modules_to_save = lora_config.exclude_modules
lora_config.exclude_modules = None

_validate_lora_weight_compatibility(self, state_dict, adapter_name=adapter_name)

# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
Expand Down
103 changes: 103 additions & 0 deletions src/diffusers/utils/peft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,109 @@ def _maybe_raise_error_for_ambiguous_keys(config):
)


def _get_lora_base_module(module):
if hasattr(module, "base_layer"):
return module.base_layer
return module


def _get_lora_base_layer_info(module):
base = _get_lora_base_module(module)
if isinstance(base, torch.nn.Linear):
return "linear", base.in_features, base.out_features, base.weight.ndim
if isinstance(base, torch.nn.Conv2d):
return "conv2d", base.in_channels, base.out_channels, base.weight.ndim
return None, None, None, None


def _validate_lora_weight_compatibility(model, state_dict, adapter_name=None, max_mismatches=5):
"""
Validate that LoRA checkpoint tensors match the target model's base layer shapes before adapter injection.

Raises:
ValueError: If any LoRA weight has incompatible rank, feature dimensions, or tensor rank (ndim) with the
corresponding base module (e.g. SD 1.5 Conv LoRA loaded into an SDXL Linear target).
"""
mismatches = []

for key, tensor in state_dict.items():
if ".lora_A." not in key and ".lora_B." not in key:
continue
if not key.endswith(".weight"):
continue

if ".lora_A." in key:
module_name = key.split(".lora_A.")[0]
lora_side = "lora_A"
else:
module_name = key.split(".lora_B.")[0]
lora_side = "lora_B"

try:
module = model.get_submodule(module_name)
except AttributeError:
continue

layer_type, in_dim, out_dim, expected_ndim = _get_lora_base_layer_info(module)
if layer_type is None:
continue

if lora_side == "lora_A":
if tensor.ndim != expected_ndim:
mismatches.append(
(
key,
f"checkpoint ndim={tensor.ndim} {tuple(tensor.shape)} vs model expects "
f"{layer_type} ndim={expected_ndim}",
)
)
continue
checkpoint_in_dim = tensor.shape[1]
if checkpoint_in_dim != in_dim:
mismatches.append(
(
key,
f"checkpoint in-dim={checkpoint_in_dim} {tuple(tensor.shape)} vs model expects "
f"{layer_type} in-dim={in_dim}",
)
)
else:
if tensor.ndim != expected_ndim:
mismatches.append(
(
key,
f"checkpoint ndim={tensor.ndim} {tuple(tensor.shape)} vs model expects "
f"{layer_type} ndim={expected_ndim}",
)
)
continue
checkpoint_out_dim = tensor.shape[0]
if checkpoint_out_dim != out_dim:
mismatches.append(
(
key,
f"checkpoint out-dim={checkpoint_out_dim} {tuple(tensor.shape)} vs model expects "
f"{layer_type} out-dim={out_dim}",
)
)

if len(mismatches) >= max_mismatches:
break

if mismatches:
details = "\n".join(f" - {key}: {reason}" for key, reason in mismatches)
extra = ""
if len(mismatches) >= max_mismatches:
extra = f"\n ... (showing first {max_mismatches} mismatches)"
adapter_msg = f" for adapter '{adapter_name}'" if adapter_name is not None else ""
raise ValueError(
"The loaded LoRA weights are incompatible with the current model dimensions"
f"{adapter_msg}. Please check the model architecture version "
"(for example, an SD 1.5 LoRA cannot be loaded into an SDXL pipeline).\n"
f"Found shape mismatch(es):\n{details}{extra}"
)


def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
warn_msg = ""
if incompatible_keys is not None:
Expand Down
57 changes: 57 additions & 0 deletions tests/lora/test_lora_layers_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import numpy as np
import torch
import torch.nn as nn
from packaging import version
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

Expand Down Expand Up @@ -55,6 +56,34 @@
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set, state_dicts_almost_equal # noqa: E402


def _build_sd15_incompatible_lora_state_dict(unet, rank=32):
"""
Build a minimal LoRA state dict that mimics SD 1.5 checkpoint shapes when loaded into an SDXL UNet.

Includes:
- 4D Conv-style LoRA tensors targeting a Linear `proj_in` layer (ndim mismatch).
- 768-dim cross-attention LoRA tensors targeting SDXL's `attn2.to_k` (feature dim mismatch).
"""
state_dict = {}

for name, module in unet.named_modules():
if name.endswith("proj_in") and isinstance(module, nn.Linear):
state_dict[f"unet.{name}.lora_A.weight"] = torch.zeros(rank, 640, 1, 1)
state_dict[f"unet.{name}.lora_B.weight"] = torch.zeros(module.out_features, rank, 1, 1)
break

for name, module in unet.named_modules():
if name.endswith("attn2.to_k") and isinstance(module, nn.Linear):
state_dict[f"unet.{name}.lora_A.weight"] = torch.zeros(rank, 768)
state_dict[f"unet.{name}.lora_B.weight"] = torch.zeros(module.out_features, rank)
break

if len(state_dict) < 4:
raise RuntimeError("Could not find SDXL UNet modules needed for incompatible LoRA regression test.")

return state_dict


if is_accelerate_available():
from accelerate.utils import release_memory

Expand Down Expand Up @@ -152,6 +181,34 @@ def test_lora_scale_kwargs_match_fusion(self):

super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)

def test_load_sd15_lora_into_sdxl_raises_incompatibility_error(self):
"""
Regression test for https://github.com/huggingface/diffusers/issues/11286

Loading an SD 1.5 LoRA into an SDXL pipeline must fail early with a clear ValueError before PEFT injects
adapter shells into the UNet.
"""
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)

incompatible_lora_state_dict = _build_sd15_incompatible_lora_state_dict(pipe.unet)

with self.assertRaises(ValueError) as err_context:
pipe.load_lora_weights(incompatible_lora_state_dict)

error_message = str(err_context.exception)
self.assertIn(
"The loaded LoRA weights are incompatible with the current model dimensions",
error_message,
)
self.assertIn("Please check the model architecture version", error_message)
self.assertIn("Found shape mismatch(es):", error_message)
self.assertIn("proj_in.lora_A.weight", error_message)
self.assertIn("attn2.to_k.lora_A.weight", error_message)

self.assertFalse(check_if_lora_correctly_set(pipe.unet))
self.assertFalse(getattr(pipe.unet, "_hf_peft_config_loaded", False))


@slow
@nightly
Expand Down
Loading