import gradio as gr
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# Load the model and tokenizer once at import time so every request reuses them.
model_name = "InstaDeepAI/agro-nucleotide-transformer-1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)


def predict(sequence: str) -> str:
    """Run the masked LM over a DNA sequence and decode its top prediction per token.

    The input is normalized (all whitespace removed, upper-cased) before
    tokenization. For every token position the highest-scoring vocabulary id
    is taken (argmax over the logits), and the resulting ids are decoded back
    to text with special tokens dropped.

    Args:
        sequence: Raw text from the Gradio textbox.

    Returns:
        The decoded argmax prediction, or "" when the input is empty,
        whitespace-only, or None.
    """
    # Strip line breaks/spaces from pasted sequences and upper-case them.
    # NOTE(review): assumes the tokenizer vocabulary uses upper-case
    # nucleotides (as in the placeholder example) — confirm against the vocab.
    cleaned = "".join((sequence or "").split()).upper()
    if not cleaned:
        # Nothing to score; avoid running the 1B-parameter model on no input.
        return ""

    # Encode the single sequence as a PyTorch batch of size 1.
    tokens = tokenizer(cleaned, return_tensors="pt")
    with torch.no_grad():
        output = model(**tokens)

    # Logits for the one (and only) sequence in the batch: a row of
    # vocabulary scores per token position; argmax picks the most likely id.
    logits = output.logits[0]
    predicted_token_ids = torch.argmax(logits, dim=-1)
    # Drop any special tokens so only sequence tokens appear in the output.
    predicted_sequence = tokenizer.decode(predicted_token_ids, skip_special_tokens=True)
    return predicted_sequence


# Gradio web interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=2, placeholder="输入DNA序列,例如:ATATACGGCCGNC"),
    outputs="text",
    title="AgroNT 植物DNA语言模型",
    description="使用 AgroNT 模型(1B参数)对DNA序列进行语言建模预测",
)

demo.launch()