GRPO学習済み Gemma-3 1B 数学推論モデル
このモデルはunsloth/gemma-3-1b-itをGSM8Kデータセットで数学推論タスク用にGRPO学習したLoRAアダプタです。
モデル概要
- ベースモデル: unsloth/gemma-3-1b-it
- 学習手法: GRPO (Generalized Reward-based Policy Optimization)
- データセット: GSM8K (小学校レベルの数学問題)
- パラメータ数: 約6.5M (ベースモデルの0.65%)
- 開発者: shimokawatoko
- ライセンス: apache-2.0
特徴
数学問題に対して以下の形式で構造化された回答を生成します:
<start_working_out>
[段階的な思考プロセス]
<end_working_out>
<SOLUTION>[最終的な数値回答]</SOLUTION>
使用方法
Google Colabで実行(GPU設定をT4にしてください)
!pip install unsloth
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from datasets import load_dataset
import random
from unsloth import FastModel
# デバイス設定
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"使用デバイス: {device}")
# システムプロンプトと特殊トークンの定義
reasoning_start = "<start_working_out>"
reasoning_end = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"
system_prompt = f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""
def setup_model():
"""GRPO学習済みLoRAアダプタを読み込む"""
print("GRPO学習済みモデルを読み込み中...")
base_model_name = "unsloth/gemma-3-1b-it"
# Unslothを使用してベースモデルを読み込み
model, tokenizer = FastModel.from_pretrained(
model_name=base_model_name,
max_seq_length=1024,
load_in_4bit=False,
load_in_8bit=False,
full_finetuning=False,
)
# LoRAアダプタを適用
adapter_path = "shimokawatoko/gemma-3-grpo"
model = PeftModel.from_pretrained(model, adapter_path)
return model, tokenizer
def generate_response(model, tokenizer, question, max_new_tokens=512):
"""質問に対してモデルの回答を生成"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": question},
]
# チャットテンプレートを適用
text = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=False,
)
# 入力をトークン化
inputs = tokenizer(text, return_tensors="pt").to(device)
# 推論実行
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=1.0,
top_p=0.95,
top_k=64,
do_sample=True,
pad_token_id=tokenizer.eos_token_id
)
# 生成されたテキストから入力部分を除去
response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
return response.strip()
def extract_answer_from_response(response):
"""レスポンスから最終回答を抽出"""
import re
# SOLUTIONタグ内の回答を抽出
solution_pattern = re.compile(
rf"{solution_start}(.+?){solution_end}",
flags=re.MULTILINE | re.DOTALL
)
match = solution_pattern.search(response)
if match:
return match.group(1).strip()
# SOLUTIONタグがない場合は数字を抽出
number_pattern = re.compile(r"(\d+(?:\.\d+)?)")
numbers = number_pattern.findall(response)
if numbers:
return numbers[-1] # 最後の数字を返す
return None
def solve_gsm8k_problem(model, tokenizer):
"""GSM8Kからランダムな問題を解く"""
# GSM8Kテストセットから問題を取得
dataset = load_dataset("openai/gsm8k", "main", split="test")
problem = random.choice(dataset)
question = problem["question"]
correct_answer = problem["answer"].split("####")[1].strip()
print(f"問題: {question}")
print(f"正解: {correct_answer}")
print("-" * 60)
# モデルで回答生成
response = generate_response(model, tokenizer, question)
predicted_answer = extract_answer_from_response(response)
print(f"モデル回答:\n{response}")
print("-" * 60)
print(f"抽出された答え: {predicted_answer}")
print(f"正解判定: {'✅ 正解' if predicted_answer == correct_answer else '❌ 不正解'}")
return {
"question": question,
"correct_answer": correct_answer,
"response": response,
"predicted_answer": predicted_answer,
"is_correct": predicted_answer == correct_answer
}
# モデルを読み込み
model, tokenizer = setup_model()
# ランダムな問題を解く
print("=== GSM8Kランダム問題を解いてみる ===")
result = solve_gsm8k_problem(model, tokenizer)
# 複数問題を連続で試す
print("\n=== 3つのランダム問題を解いてみる ===")
for i in range(3):
print(f"\n--- 問題 {i+1} ---")
solve_gsm8k_problem(model, tokenizer)
必要な環境
- GPU: 8GB以上推奨(Google Colab T4で動作確認済み)
- CPU: 16GB RAM以上(GPU使用時)
- Python: 3.8以上
インストール
pip install unsloth torch transformers peft datasets
制限事項
- 小学校レベルの数学問題に特化
- 高度な数学概念には限界あり
- 適切なプロンプト形式が必要
学習詳細
このモデルはUnslothとHugging FaceのTRLライブラリを使用して2倍高速で学習されました。
ライセンス
Apache-2.0ライセンスの下で公開されています。
Uploaded model
- Developed by: shimokawatoko
- License: apache-2.0
- Finetuned from model : unsloth/gemma-3-1b-it
This gemma3_text model was trained 2x faster with Unsloth and Huggingface's TRL library.
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support
