# Hugging Face Space — runs on ZeroGPU ("Zero") hardware.
import gradio as gr
import spaces
import torch
import json
import re
import os
from diffusers import QwenImagePipeline
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# Instruction prepended to every user prompt. It asks the VLM text encoder to
# reason step by step and then emit an optimized prompt after a fixed
# "Revised Prompt:" marker — extract_prompt() parses the other side of this
# contract. Content is runtime data fed to the model: keep it byte-for-byte.
SYSTEM_PROMPT = """You are a Prompt Optimizer specializing in image generation models (e.g., MidJourney, Stable Diffusion). Your core task is to rewrite user-provided prompts into highly clear, easy-to-render versions.
When rewriting, prioritize the following principles:
1. Start from the user's prompt, do reasoning step by step to analyze the object or scene they want to generate.
2. Focus on describing the final visual appearance of the scene. Clarify elements like the main subject’s shape, color, and state.
3. If you are confident about what the user wants to generate, directly point it out in your explanation and the final revised prompt.
4. If technical concepts are necessary but difficult for ordinary users to understand, translate them into intuitive visual descriptions.
5. Ensure the final revised prompt is consistent with the user's intent.
After receiving the user’s prompt that needs rewriting, first explain your reasoning for optimization. Then, output the final revised prompt in the fixed format of "Revised Prompt:\n". Where the specific revised content is filled in the next line.
Prompt:
"""

# Access token for the gated model repo (set as a Space secret); None if unset.
hf_token = os.environ.get("HF_TOKEN")
device = "cuda"  # ZeroGPU exposes a CUDA device to GPU-decorated functions
dtype = torch.bfloat16

# Processor used to chat-format and tokenize inputs for the Qwen2.5-VL encoder.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
# NOTE(review): `qwen` is loaded here but never referenced again — predict()
# generates through pipe.text_encoder instead. Looks redundant (and it is
# neither moved to `device` nor cast to `dtype`); confirm before removing.
qwen = Qwen2_5_VLForConditionalGeneration.from_pretrained("zhijie3/think-then-generate", subfolder="text_encoder", token=hf_token)
# Full think-then-generate image pipeline (includes the fine-tuned text encoder).
pipe = QwenImagePipeline.from_pretrained("zhijie3/think-then-generate", token=hf_token, torch_dtype=dtype)
pipe = pipe.to(device)
# Negative prompt passed to every generation; a single space, i.e. effectively none.
negative_prompt = " "
def extract_prompt(text: str) -> str:
    """Pull the revised prompt out of the model's full response.

    Searches for the "Revised Prompt:" marker — first the strict form with a
    trailing newline, then a looser form without it — and returns everything
    after the marker. When no marker is present, the whole response is
    returned, stripped of surrounding whitespace.
    """
    for pattern in (r"Revised Prompt:\n(.*)", r"Revised Prompt:(.*)"):
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
    return text.strip()
# 2. Inference: reason about the prompt with the VLM, then render the image.
# FIX: this Space runs on ZeroGPU ("Running on Zero") and `spaces` is imported,
# but the GPU-using function was never decorated — without @spaces.GPU no GPU
# is attached to the call and the CUDA pipeline fails on Zero hardware.
@spaces.GPU(duration=120)  # 1024-token decode + 50 diffusion steps needs headroom
def predict(prompt):
    """Generate a 1024x1024 image for `prompt` via think-then-generate.

    Two stages, both under inference mode:
      1. The pipeline's Qwen2.5-VL text encoder is prompted with
         SYSTEM_PROMPT + `prompt` and asked to reason, then emit an
         optimized prompt after the "Revised Prompt:" marker.
      2. The optimized prompt (parsed by extract_prompt) drives the
         diffusion pipeline.

    Args:
        prompt: Free-form user description of the desired image.

    Returns:
        The generated PIL image (first image of the pipeline output).
    """
    with torch.inference_mode():
        # Stage 1: chat-format the instruction + user prompt for the encoder.
        messages = [{"role": "user", "content": SYSTEM_PROMPT + prompt}]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(
            text=[text],
            padding=True,
            return_tensors="pt",
        ).to(device)
        # Greedy decoding keeps the rewritten prompt deterministic.
        generated_ids = pipe.text_encoder.generate(
            **inputs, do_sample=False, max_new_tokens=1024
        )
        # Drop the echoed input tokens; keep only the newly generated tail.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        true_prompt = extract_prompt(output_text[0])

        # Stage 2: render the optimized prompt.
        image = pipe(
            prompt=true_prompt,
            negative_prompt=negative_prompt,
            width=1024,
            height=1024,
            num_inference_steps=50,
            true_cfg_scale=4.0,
        ).images[0]
    return image
# 3. Build Gradio Interface
# Canned example prompts: classic Chinese idiom stories rendered as
# multi-panel illustrations (each entry is one input row for gr.Examples).
examples = [
    ["A multi-panel illustration showing the story of playing the lute to a cow, with clear steps from performing music passionately to the cow remaining completely uninterested."],
    ["A multi-panel illustration showing the story of marking the boat to find a sword, with clear steps from dropping the sword to carving a mark on the boat."],
    ["A multi-panel illustration showing the story of fixing the sheep pen after losing sheep, with clear steps from discovering the loss to repairing the fence too late."]
]
with gr.Blocks(theme=gr.themes.Soft(), title="Think-Then-Generate") as demo:
    # Page header describing the paradigm. Markdown content is kept at column 0
    # so leading spaces are not interpreted as a markdown code block.
    gr.Markdown("""
# ✨ Think-Then-Generate: Evolutionary Reasoning for Image Generation
### 🚀 The Paradigm Shift
Traditional text-to-image models often treat the text encoder as a frozen dictionary, mapping words to pixels without truly understanding the **intent** or **implicit semantics** behind a prompt.
### 🧠 How it Works
To break this limitation, we introduce the **Think-Then-Generate** paradigm. Before a single pixel is drawn, the model "thinks" through the instruction:
1. **Chain-of-Thought (CoT) Reasoning**: By fine-tuning **Qwen2.5-VL**, we activate its latent world knowledge. The model reasons about the scene and objects before generating an "optimized prompt."
2. **Dual-GRPO Reinforcement Learning**: A collaborative RL strategy where the LLM encoder and DiT generator evolve together. The LLM learns to produce better instructions, while the DiT enhances its rendering capability based on visual feedback.
3. **Bridging Logic and Vision**: The optimized prompt serves as a semantic bridge, ensuring the final generation is a deep realization of user intent.
""")
    with gr.Row():
        # Left column: prompt input, generate button, canned examples.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                lines=5,
                placeholder="Describe what you want to create...",
                label="Input Prompt ✍️"
            )
            generate_btn = gr.Button("Generate ✨", variant="primary")
            gr.Examples(
                examples=examples,
                inputs=prompt_input,
                label="Try these examples 👇"
            )
        # Right column: the generated image (display-only).
        with gr.Column(scale=1):
            image_output = gr.Image(label="Generated Result 🖼️", interactive=False)
    gr.Markdown("""
> **Note**: This is a research preview. The model first reasons about your prompt to optimize the visual description.
""")

    # When the model is not fully loaded, predict returns the prompt string.
    # To avoid Gradio Image component errors, we handle the output.
    # NOTE(review): predict() as currently written always returns an image, so
    # the isinstance(str) branch below looks unreachable — confirm before
    # removing this defensive wrapper.
    def ui_predict(prompt):
        # Shield the gr.Image output from ever receiving a plain string.
        result = predict(prompt)
        if isinstance(result, str):
            # Return a simple message or placeholder if it's just text
            return None
        return result

    # Wire the button: user prompt in, generated image out.
    generate_btn.click(
        fn=ui_predict,
        inputs=prompt_input,
        outputs=image_output
    )
# 4. Launch Service
# Start the Gradio server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()