diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 41b0f689d9a4..0b9fb2505a2f 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -198,15 +198,21 @@ def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]: return tuples gen = parameter._named_members(get_members_fn=find_tensor_attributes) - last_tuple = None - for tuple in gen: - last_tuple = tuple - if tuple[1].is_floating_point(): - return tuple[1].dtype - - if last_tuple is not None: - # fallback to the last dtype - return last_tuple[1].dtype + last_t = None + for t in gen: + last_t = t + if t[1].is_floating_point(): + return t[1].dtype + + if last_t is not None: + # fallback to the last dtype found via __dict__ inspection + return last_t[1].dtype + + raise ValueError( + f"Could not determine the dtype of {parameter.__class__.__name__}: no parameters, buffers, or tensor " + "attributes were found. If you are using nn.DataParallel, make sure the module is moved to a device " + "before wrapping it (e.g. model.to('cuda') before DataParallel(model))." + ) @contextmanager diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index 7b207f782079..1ffdcfdab064 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -128,6 +128,15 @@ def __init__( timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps + + # Store sigma_min / sigma_max from the *unshifted* linear schedule so that + # set_timesteps can use them as the raw [0, 1] bounds when regenerating the + # sigma grid. If they were stored after shifting, set_timesteps would feed + # already-shifted values back through the shift formula a second time, + # producing a doubly-shifted (and therefore wrong) sigma schedule (#13243). + self.sigma_min = sigmas[-1].item() + self.sigma_max = sigmas[0].item() + if not use_dynamic_shifting: # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) @@ -140,8 +149,6 @@ def __init__( self._shift = shift self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication - self.sigma_min = self.sigmas[-1].item() - self.sigma_max = self.sigmas[0].item() @property def shift(self): diff --git a/tests/schedulers/test_scheduler_flow_map_euler_discrete.py b/tests/schedulers/test_scheduler_flow_map_euler_discrete.py index aca680746a1f..8c6dc686e257 100644 --- a/tests/schedulers/test_scheduler_flow_map_euler_discrete.py +++ b/tests/schedulers/test_scheduler_flow_map_euler_discrete.py @@ -187,3 +187,30 @@ def test_scale_noise_endpoints(self): torch.testing.assert_close(scheduler.scale_noise(sample, zero_t, noise), sample) full_t = torch.tensor([float(scheduler.config.num_train_timesteps)]) torch.testing.assert_close(scheduler.scale_noise(sample, full_t, noise), noise) + + def test_set_timesteps_no_double_shift(self): + """set_timesteps must not apply the shift formula twice (regression #13243). + + When sigma_min/sigma_max were stored *after* shifting in __init__, calling + set_timesteps fed already-shifted values back through the shift formula a + second time. After the fix the schedule produced by set_timesteps must be + identical to the one built in __init__ for the same number of steps. + """ + shift = 3.0 + n = 1000 + scheduler = self.scheduler_class(**self.get_default_config(shift=shift)) + + # The sigmas stored in __init__ — these are the ground-truth shifted values. + init_sigmas = scheduler.sigmas[:-1] # drop terminal 0 added by set_timesteps + + scheduler.set_timesteps(num_inference_steps=n) + inferred_sigmas = scheduler.sigmas[:-1] + + self.assertEqual(len(init_sigmas), len(inferred_sigmas)) + for i, (s_init, s_infer) in enumerate(zip(init_sigmas, inferred_sigmas)): + self.assertAlmostEqual( + s_init.item(), + s_infer.item(), + places=5, + msg=f"sigma mismatch at index {i}: init={s_init:.6f} vs set_timesteps={s_infer:.6f}", + )