From 37d10d83085f3691e06fc79592e7bb163b2bdcd0 Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Sun, 14 Jun 2026 15:12:56 -0700 Subject: [PATCH 1/3] fix(flux): tighten check_inputs validation for negative embeds, controlnet img2img dims, redux scales --- src/diffusers/pipelines/flux/pipeline_flux.py | 8 ++++++++ .../pipeline_flux_controlnet_image_to_image.py | 2 +- .../pipelines/flux/pipeline_flux_kontext.py | 8 ++++++++ .../pipelines/flux/pipeline_flux_prior_redux.py | 16 +++++++++++----- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index e125924adf7f..b57e4d44dd18 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -491,6 +491,14 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 65b2072a7746..61c9da0c9496 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -454,7 +454,7 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0: + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: logger.warning( f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index efddc6cea139..d94003ae944c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -542,6 +542,14 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." diff --git a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py index 94c7bcc80782..f173fdef88c6 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_prior_redux.py @@ -172,12 +172,18 @@ def check_inputs( raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." ) - if isinstance(prompt_embeds_scale, list) and ( - isinstance(image, list) and len(prompt_embeds_scale) != len(image) + image_batch_size = ( + image.shape[0] if isinstance(image, torch.Tensor) else len(image) if isinstance(image, list) else 1 + ) + for scale_name, scale in ( + ("prompt_embeds_scale", prompt_embeds_scale), + ("pooled_prompt_embeds_scale", pooled_prompt_embeds_scale), ): - raise ValueError( - f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images" - ) + if isinstance(scale, list) and len(scale) != image_batch_size: + raise ValueError( + f"number of weights in `{scale_name}` must be equal to number of images, but " + f"{len(scale)} weights were provided and {image_batch_size} images" + ) def encode_image(self, image, device, num_images_per_prompt): dtype = next(self.image_encoder.parameters()).dtype From aa4be1c6dd55f332f735bf12a07fa4f5f0793c2a Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Mon, 15 Jun 2026 10:47:15 -0700 Subject: [PATCH 2/3] fix(flux): gate negative prompt embeds shape check on do_true_cfg --- src/diffusers/pipelines/flux/pipeline_flux.py | 15 +++++++-------- .../pipelines/flux/pipeline_flux_kontext.py | 15 +++++++-------- tests/pipelines/flux/test_pipeline_flux.py | 19 +++++++++++++++++++ 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index b57e4d44dd18..34cbf0faa667 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -491,14 +491,6 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." @@ -829,6 +821,13 @@ def __call__( negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + if do_true_cfg and prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) ( prompt_embeds, pooled_prompt_embeds, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index d94003ae944c..e32bfecfcdad 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -542,14 +542,6 @@ def check_inputs( f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." @@ -957,6 +949,13 @@ def __call__( negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + if do_true_cfg and prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) ( prompt_embeds, pooled_prompt_embeds, diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 13336f0cde9b..ef51553b97de 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -234,6 +234,25 @@ def test_flux_true_cfg(self): np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set." ) + def test_flux_negative_embeds_shape_check(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + base_inputs = { + "prompt_embeds": torch.randn(1, 4, 32, device=torch_device), + "pooled_prompt_embeds": torch.randn(1, 32, device=torch_device), + "negative_prompt_embeds": torch.randn(1, 5, 32, device=torch_device), + "negative_pooled_prompt_embeds": torch.randn(1, 32, device=torch_device), + "height": 16, + "width": 16, + "num_inference_steps": 1, + "output_type": "latent", + } + + with self.assertRaises(ValueError): + pipe(**base_inputs, true_cfg_scale=2.0, generator=torch.manual_seed(0)) + + pipe(**base_inputs, true_cfg_scale=1.0, generator=torch.manual_seed(0)) + @nightly @require_big_accelerator From 4c0f0e3ffc19b1ee25be46afc1f540767b38edbc Mon Sep 17 00:00:00 2001 From: Akshan Krithick Date: Mon, 15 Jun 2026 21:10:43 -0700 Subject: [PATCH 3/3] fix(flux): assert error message in negative-embed shape check test --- tests/pipelines/flux/test_pipeline_flux.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index ef51553b97de..cdc2974b2b54 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -248,7 +248,7 @@ def test_flux_negative_embeds_shape_check(self): "output_type": "latent", } - with self.assertRaises(ValueError): + with self.assertRaisesRegex(ValueError, "must have the same shape when passed directly"): pipe(**base_inputs, true_cfg_scale=2.0, generator=torch.manual_seed(0)) pipe(**base_inputs, true_cfg_scale=1.0, generator=torch.manual_seed(0))