GRPO学習済み Gemma-3 1B 数学推論モデル

このモデルはunsloth/gemma-3-1b-itをGSM8Kデータセットで数学推論タスク用にGRPO学習したLoRAアダプタです。

モデル概要

  • ベースモデル: unsloth/gemma-3-1b-it
  • 学習手法: GRPO (Generalized Reward-based Policy Optimization)
  • データセット: GSM8K (小学校レベルの数学問題)
  • パラメータ数: 約6.5M (ベースモデルの0.65%)
  • 開発者: shimokawatoko
  • ライセンス: apache-2.0

特徴

数学問題に対して以下の形式で構造化された回答を生成します:

<start_working_out>
[段階的な思考プロセス]
<end_working_out>

<SOLUTION>[最終的な数値回答]</SOLUTION>

使用方法

Google Colabで実行(GPU設定をT4にしてください)

!pip install unsloth
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from datasets import load_dataset
import random
from unsloth import FastModel

# デバイス設定
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用デバイス: {device}")

# システムプロンプトと特殊トークンの定義
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

system_prompt = f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""

def setup_model():
    """GRPO学習済みLoRAアダプタを読み込む"""
    print("GRPO学習済みモデルを読み込み中...")
    base_model_name = "unsloth/gemma-3-1b-it"
    
    # Unslothを使用してベースモデルを読み込み
    model, tokenizer = FastModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=1024,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    
    # LoRAアダプタを適用
    adapter_path = "shimokawatoko/gemma-3-grpo"
    model = PeftModel.from_pretrained(model, adapter_path)
    
    return model, tokenizer

def generate_response(model, tokenizer, question, max_new_tokens=512):
    """質問に対してモデルの回答を生成"""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]
    
    # チャットテンプレートを適用
    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    
    # 入力をトークン化
    inputs = tokenizer(text, return_tensors="pt").to(device)
    
    # 推論実行
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=1.0,
            top_p=0.95,
            top_k=64,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # 生成されたテキストから入力部分を除去
    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return response.strip()

def extract_answer_from_response(response):
    """レスポンスから最終回答を抽出"""
    import re
    
    # SOLUTIONタグ内の回答を抽出
    solution_pattern = re.compile(
        rf"{solution_start}(.+?){solution_end}",
        flags=re.MULTILINE | re.DOTALL
    )
    
    match = solution_pattern.search(response)
    if match:
        return match.group(1).strip()
    
    # SOLUTIONタグがない場合は数字を抽出
    number_pattern = re.compile(r"(\d+(?:\.\d+)?)")
    numbers = number_pattern.findall(response)
    if numbers:
        return numbers[-1]  # 最後の数字を返す
    
    return None

def solve_gsm8k_problem(model, tokenizer):
    """GSM8Kからランダムな問題を解く"""
    # GSM8Kテストセットから問題を取得
    dataset = load_dataset("openai/gsm8k", "main", split="test")
    problem = random.choice(dataset)
    
    question = problem["question"]
    correct_answer = problem["answer"].split("####")[1].strip()
    
    print(f"問題: {question}")
    print(f"正解: {correct_answer}")
    print("-" * 60)
    
    # モデルで回答生成
    response = generate_response(model, tokenizer, question)
    predicted_answer = extract_answer_from_response(response)
    
    print(f"モデル回答:\n{response}")
    print("-" * 60)
    print(f"抽出された答え: {predicted_answer}")
    print(f"正解判定: {'✅ 正解' if predicted_answer == correct_answer else '❌ 不正解'}")
    
    return {
        "question": question,
        "correct_answer": correct_answer,
        "response": response,
        "predicted_answer": predicted_answer,
        "is_correct": predicted_answer == correct_answer
    }

# モデルを読み込み
model, tokenizer = setup_model()

# ランダムな問題を解く
print("=== GSM8Kランダム問題を解いてみる ===")
result = solve_gsm8k_problem(model, tokenizer)

# 複数問題を連続で試す
print("\n=== 3つのランダム問題を解いてみる ===")
for i in range(3):
    print(f"\n--- 問題 {i+1} ---")
    solve_gsm8k_problem(model, tokenizer)

必要な環境

  • GPU: 8GB以上推奨(Google Colab T4で動作確認済み)
  • CPU: 16GB RAM以上(GPU使用時)
  • Python: 3.8以上

インストール

pip install unsloth torch transformers peft datasets

制限事項

  • 小学校レベルの数学問題に特化
  • 高度な数学概念には限界あり
  • 適切なプロンプト形式が必要

学習詳細

このモデルはUnslothとHugging FaceのTRLライブラリを使用して2倍高速で学習されました。

ライセンス

Apache-2.0ライセンスの下で公開されています。

Uploaded model

  • Developed by: shimokawatoko
  • License: apache-2.0
  • Finetuned from model : unsloth/gemma-3-1b-it

This gemma3_text model was trained 2x faster with Unsloth and Huggingface's TRL library.

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for shimokawatoko/gemma-3-grpo

Finetuned
(444)
this model