# Hugging Face Space "zhijie3/think-then-generate" — app.py (commit b788255, verified)
import gradio as gr
import spaces
import torch
import json
import re
import os
from diffusers import QwenImagePipeline
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
# Instruction block prepended to every user prompt before it is sent to the
# chat model. It asks the model to reason step by step and finish with a
# "Revised Prompt:\n<text>" section, which extract_prompt() below parses out.
SYSTEM_PROMPT = """You are a Prompt Optimizer specializing in image generation models (e.g., MidJourney, Stable Diffusion). Your core task is to rewrite user-provided prompts into highly clear, easy-to-render versions.
When rewriting, prioritize the following principles:
1. Start from the user's prompt, do reasoning step by step to analyze the object or scene they want to generate.
2. Focus on describing the final visual appearance of the scene. Clarify elements like the main subject’s shape, color, and state.
3. If you are confident about what the user wants to generate, directly point it out in your explanation and the final revised prompt.
4. If technical concepts are necessary but difficult for ordinary users to understand, translate them into intuitive visual descriptions.
5. Ensure the final revised prompt is consistent with the user's intent.
After receiving the user’s prompt that needs rewriting, first explain your reasoning for optimization. Then, output the final revised prompt in the fixed format of "Revised Prompt:\n". Where the specific revised content is filled in the next line.
Prompt:
"""
# Token for gated/private repo access (configured as a Space secret).
hf_token = os.environ.get("HF_TOKEN")
device = "cuda"
dtype = torch.bfloat16
# Processor (tokenizer + chat template) for the Qwen2.5-VL chat model.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
# Fine-tuned VL model loaded from the pipeline repo's text_encoder subfolder.
# NOTE(review): `qwen` is never referenced again — predict() calls
# pipe.text_encoder.generate instead. Confirm whether this extra load (and its
# memory cost) is actually needed.
qwen = Qwen2_5_VLForConditionalGeneration.from_pretrained("zhijie3/think-then-generate", subfolder="text_encoder", token=hf_token)
# Full text-to-image pipeline (includes its own copy of the text encoder).
pipe = QwenImagePipeline.from_pretrained("zhijie3/think-then-generate", token=hf_token, torch_dtype=dtype)
pipe = pipe.to(device)
# Negative prompt is a single space (effectively empty).
negative_prompt = " "
def extract_prompt(text: str) -> str:
    """Pull the rewritten prompt out of the model's reasoning output.

    Tries the strict marker ("Revised Prompt:" followed by a newline) first,
    then a looser variant without the newline. If neither marker is present,
    the whole text (stripped) is returned unchanged.
    """
    patterns = (r"Revised Prompt:\n(.*)", r"Revised Prompt:(.*)")
    for pattern in patterns:
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return match.group(1).strip()
    return text.strip()
@spaces.GPU
def predict(prompt):
    """Two-stage think-then-generate inference.

    Stage 1 rewrites the raw user prompt with the pipeline's chat-capable
    text encoder; stage 2 renders the rewritten prompt with the diffusion
    pipeline.

    Args:
        prompt: Raw user prompt string.

    Returns:
        The first generated PIL image from the pipeline.
    """
    with torch.inference_mode():
        # Stage 1: build a single user turn; the optimizer instructions are
        # prepended to the user's text (no separate system message).
        messages = [
            {
                "role": "user",
                "content": SYSTEM_PROMPT + prompt
            }
        ]
        text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = processor(
            text=[text],
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)
        # Greedy decoding via the pipeline's own text encoder (Qwen2.5-VL).
        generated_ids = pipe.text_encoder.generate(**inputs, do_sample=False, max_new_tokens=1024)
        # Drop the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        # Extract the "Revised Prompt:" payload; falls back to the raw output.
        true_prompt = extract_prompt(output_text[0])
        # Stage 2: diffusion rendering of the rewritten prompt.
        image = pipe(
            prompt = true_prompt,
            negative_prompt = negative_prompt,
            width = 1024,
            height = 1024,
            num_inference_steps = 50,
            true_cfg_scale = 4.0,
        ).images[0]
        return image
# 3. Build Gradio Interface
# Example prompts: classic Chinese idiom stories rendered as multi-panel art.
examples = [
    ["A multi-panel illustration showing the story of playing the lute to a cow, with clear steps from performing music passionately to the cow remaining completely uninterested."],
    ["A multi-panel illustration showing the story of marking the boat to find a sword, with clear steps from dropping the sword to carving a mark on the boat."],
    ["A multi-panel illustration showing the story of fixing the sheep pen after losing sheep, with clear steps from discovering the loss to repairing the fence too late."]
]
# UI layout: explanatory markdown on top, prompt input + examples on the left,
# generated image on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="Think-Then-Generate") as demo:
    gr.Markdown("""
# ✨ Think-Then-Generate: Evolutionary Reasoning for Image Generation
### 🚀 The Paradigm Shift
Traditional text-to-image models often treat the text encoder as a frozen dictionary, mapping words to pixels without truly understanding the **intent** or **implicit semantics** behind a prompt.
### 🧠 How it Works
To break this limitation, we introduce the **Think-Then-Generate** paradigm. Before a single pixel is drawn, the model "thinks" through the instruction:
1. **Chain-of-Thought (CoT) Reasoning**: By fine-tuning **Qwen2.5-VL**, we activate its latent world knowledge. The model reasons about the scene and objects before generating an "optimized prompt."
2. **Dual-GRPO Reinforcement Learning**: A collaborative RL strategy where the LLM encoder and DiT generator evolve together. The LLM learns to produce better instructions, while the DiT enhances its rendering capability based on visual feedback.
3. **Bridging Logic and Vision**: The optimized prompt serves as a semantic bridge, ensuring the final generation is a deep realization of user intent.
""")
    with gr.Row():
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                lines=5,
                placeholder="Describe what you want to create...",
                label="Input Prompt ✍️"
            )
            generate_btn = gr.Button("Generate ✨", variant="primary")
            gr.Examples(
                examples=examples,
                inputs=prompt_input,
                label="Try these examples 👇"
            )
        with gr.Column(scale=1):
            image_output = gr.Image(label="Generated Result 🖼️", interactive=False)
            gr.Markdown("""
> **Note**: This is a research preview. The model first reasons about your prompt to optimize the visual description.
""")

    # When the model is not fully loaded, predict returns the prompt string.
    # To avoid Gradio Image component errors, we handle the output.
    def ui_predict(prompt):
        # Thin wrapper guarding the Image component against non-image values.
        # NOTE(review): predict() as written above always returns an image, so
        # the isinstance branch appears to be a dead-code safety net — confirm.
        result = predict(prompt)
        if isinstance(result, str):
            # Return a simple message or placeholder if it's just text
            return None
        return result

    # Wire the button: prompt text in, rendered image out.
    generate_btn.click(
        fn=ui_predict,
        inputs=prompt_input,
        outputs=image_output
    )
# 4. Launch Service
if __name__ == "__main__":
    # Start the Gradio server (blocking call).
    demo.launch()