# app.py
"""PDF → Markdown OCR service.

Wraps the Nanonets-OCR2-3B image-text-to-text model behind both a Gradio UI
and a FastAPI endpoint: PDFs are rasterized page-by-page with pdf2image and
each page image is OCR'd into structured Markdown.
"""
import io
import os
import tempfile

import gradio as gr
import spaces
import torch
from fastapi import FastAPI, File, UploadFile
from pdf2image import convert_from_path
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer

# -----------------------
# Model + processor setup
# -----------------------
MODEL_ID = "nanonets/Nanonets-OCR2-3B"

model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    dtype="auto",  # <- replaces deprecated torch_dtype
    device_map="auto",
)
model.eval()

# NOTE(review): `tokenizer` is never used directly (the processor carries its
# own tokenizer), but it is kept as a module attribute in case external code
# imports it from here.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# A gentle prompt that nudges markdown structure
OCR_PROMPT = (
    "Extract the text from the document image and return clean Markdown.\n"
    "- Use # ## ### for headings\n"
    "- Use |A|B| tables\n"
    "- Bullets with '-' or '*'\n"
    "- **bold** and *italic*\n"
    "- Equations: readable inline text\n"
    "- For images: [Image: short description]\n"
    "- If page number is visible, append [Page X] at the end."
)


def _poppler_path() -> dict:
    """Return pdf2image keyword overrides for a custom Poppler install.

    Optional: on some environments (e.g., Windows or custom containers) you
    must point pdf2image to Poppler's binaries. Set the POPPLER_PATH env var
    if needed; otherwise an empty dict is returned and pdf2image uses PATH.
    """
    path = os.getenv("POPPLER_PATH")
    return {"poppler_path": path} if path else {}


def _file_to_path(v) -> str | None:
    """Normalize any of Gradio's file-input shapes to a filesystem path.

    Gradio v4's File component may hand back a FileData object with `.path`,
    a dict with a "path" key, a NamedTemporaryFile-like object with `.name`,
    or a plain string. Returns None when the input itself is None.

    Raises:
        ValueError: if the input matches none of the known shapes.
    """
    if v is None:
        return None
    # gradio.FileData
    if hasattr(v, "path") and v.path:
        return v.path
    # dict style
    if isinstance(v, dict) and v.get("path"):
        return v["path"]
    # tempfile wrapper
    if hasattr(v, "name"):
        return v.name
    # already a string path
    if isinstance(v, str):
        return v
    raise ValueError("Unsupported file input type for PDF.")


@spaces.GPU(duration=120)
def ocr_page_with_nanonets(pil_image: Image.Image, max_new_tokens: int = 4096) -> str:
    """Run Nanonets-OCR2-3B on a single PIL image page and return Markdown text.

    Args:
        pil_image: the rendered page image.
        max_new_tokens: generation budget for the page's transcription.

    Returns:
        The model's decoded Markdown output, stripped of surrounding whitespace.
    """
    messages = [
        {
            "role": "system",
            "content": "You extract text from documents and format it in clean Markdown.",
        },
        {
            "role": "user",
            "content": [
                # pass the PIL image, not a file:// URI
                {"type": "image", "image": pil_image},
                {"type": "text", "text": OCR_PROMPT},
            ],
        },
    ]

    # Build inputs using the chat template
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text],
        images=[pil_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = {
        k: v.to(model.device) if isinstance(v, torch.Tensor) else v
        for k, v in inputs.items()
    }

    # Deterministic, long generations are fine for OCR-style tasks.
    # inference_mode() avoids autograd bookkeeping during generation.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            repetition_penalty=1.0,  # helps table fidelity per community tips
        )

    # Slice off the prompt tokens before decoding
    prompt_len = inputs["input_ids"].shape[-1]
    gen_only = outputs[:, prompt_len:]
    decoded = processor.batch_decode(
        gen_only, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return decoded[0].strip()


def process_pdf_impl(pdf_path: str, max_tokens: int) -> str:
    """Convert the PDF at `pdf_path` into Markdown by running OCR per page.

    Multi-page PDFs get a `# Page N` heading per page and pages are joined
    with a horizontal rule; single-page PDFs return the page text alone.
    """
    if not pdf_path:
        return "Please upload a PDF file first."

    # Convert PDF -> list[PIL.Image]
    pages = convert_from_path(pdf_path, dpi=220, **_poppler_path())
    if not pages:
        return "No pages found in the PDF."

    all_md = []
    for i, page_img in enumerate(pages, start=1):
        page_text = ocr_page_with_nanonets(page_img, max_new_tokens=max_tokens)
        if len(pages) > 1:
            all_md.append(f"# Page {i}\n\n{page_text}")
        else:
            all_md.append(page_text)
    return "\n\n---\n\n".join(all_md)


# -----------------------
# Gradio UI (optional)
# -----------------------
with gr.Blocks(title="PDF → Markdown (Nanonets OCR2)") as demo:
    gr.Markdown(
        """ # 📄 PDF → Markdown (Nanonets OCR2) Upload a PDF and get back structured Markdown (tables, lists, headings). """
    )
    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Upload PDF", file_types=[".pdf"], file_count="single"
            )
            max_tokens_slider = gr.Slider(
                label="Maximum tokens", minimum=500, maximum=15000, value=4096, step=500
            )
            run_btn = gr.Button("🚀 Convert to Markdown", variant="primary")
        with gr.Column():
            output_text = gr.Textbox(
                label="Markdown Output",
                lines=28,
                max_lines=36,
                interactive=False,
                show_copy_button=True,
            )

    def _gr_process(file_obj, max_tokens):
        """UI callback: normalize the upload, run OCR, surface errors as text."""
        try:
            path = _file_to_path(file_obj)
            return process_pdf_impl(path, int(max_tokens))
        except Exception as e:
            # Surface the failure in the output box rather than crashing the UI.
            return f"Error processing PDF: {e}"

    run_btn.click(
        _gr_process, inputs=[file_input, max_tokens_slider], outputs=output_text
    )

# -----------------------
# FastAPI endpoint
# -----------------------
app = FastAPI()


@app.post("/ocr")
async def ocr_endpoint(file: UploadFile = File(...), max_tokens: int = 4096):
    """
    POST /ocr with multipart/form-data:
      - file: PDF file
      - max_tokens (optional): int
    Returns: {"markdown": "..."}
    """
    # Save to a unique temp path: the previous fixed "/tmp/input.pdf" raced
    # under concurrent requests and was never cleaned up.
    fd, tmp_path = tempfile.mkstemp(suffix=".pdf")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())
        md = process_pdf_impl(tmp_path, max_tokens=max_tokens)
    finally:
        os.unlink(tmp_path)
    return {"markdown": md}


# Mount Gradio UI at root (optional)
app = gr.mount_gradio_app(app, demo, path="/")