diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 5fb666a4d42c..219ff3587df4 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1512,7 +1512,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Handle class prompt for prior-preservation. if args.with_prior_preservation: if not args.train_text_encoder: - class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings( + class_prompt_hidden_states, class_pooled_prompt_embeds, _ = compute_text_embeddings( args.class_prompt, text_encoders, tokenizers ) @@ -1533,7 +1533,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0) - text_ids = torch.cat([text_ids, class_text_ids], dim=0) # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) # we need to tokenize and encode the batch prompts on all training steps else: