Skip to content

fix(pipeline): preserve original pipeline dtype in from_pipe() by default#13964

Open
Liauuu wants to merge 1 commit into
huggingface:mainfrom
Liauuu:pr5
Open

fix(pipeline): preserve original pipeline dtype in from_pipe() by default#13964
Liauuu wants to merge 1 commit into
huggingface:mainfrom
Liauuu:pr5

Conversation

@Liauuu

@Liauuu Liauuu commented Jun 15, 2026

Copy link
Copy Markdown

Description

Fixes #12754

Currently, from_pipe() forces the newly created pipeline (and its shared components) to float32 by default unless torch_dtype is explicitly provided. This causes unexpected VRAM spikes, CUDA OOMs, and introduces a regression in developer experience (DX) when working with float16 or bfloat16 pipelines.

This PR modifies from_pipe() to automatically infer and inherit the source pipeline's dtype if the user leaves torch_dtype=None.

Safety & Edge Case Defense

To perfectly guard against the edge case where shared components might have mixed dtypes:

  1. It scans all torch.nn.Module components inside pipeline.components.values().
  2. It only inherits and overrides torch_dtype if and only if all components share the exact same dtype (len(dtypes) == 1).
  3. If there is a rare mixture of dtypes, it safely leaves torch_dtype = None and skips casting, maintaining the exact original state of the components without breaking any shared weight links.

Testing

  • Added a comprehensive mock unit test in tests/pipelines/test_pipeline_utils.py checking that from_pipe correctly inherits float16 and shares the same object reference without unwanted mutations.
  • The unit test passes successfully locally:
tests/pipelines/test_pipeline_utils.py::FromPipeDtypeTests::test_from_pipe_preserves_dtype_by_default PASSED [100%]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

from_pipe converts pipelines to float32 by default

1 participant