Gemma 3 PT generate bug: missing spaces and repetition

#15
by dglasscortex

The bug can be reproduced by running the following code:

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

ckpt = "google/gemma-3-1b-pt"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = Gemma3ForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

prompt = "Eiffel tower is located in"
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    # Greedy decoding (do_sample=False) so the output is deterministic
    generation = model.generate(**model_inputs, max_new_tokens=50, do_sample=False)
    # Keep only the newly generated tokens, dropping the prompt
    generation = generation[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)

Expected: text without unusual spacing around periods and without repetitions.
Actual: " the heart of Paris, France.The Eiffel Tower is a symbol of Paris and France.The Eiffel Tower is a symbol of Paris and France.The Eiffel Tower is a symbol of Paris and France.The Eiffel Tower is a symbol"

Notice the missing space between "France." and "The Eiffel", which occurs multiple times within the 50 generated tokens. Also notice the verbatim repetition of "The Eiffel Tower is a symbol of Paris and France".
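To check whether the missing space comes from the model's tokens rather than from detokenization, one can decode the output token by token. A minimal sketch, reusing the generation variable from the script above:

# Print each generated token ID alongside its decoded text to see
# whether a space token is actually absent after "France."
for tok_id in generation.tolist():
    print(tok_id, repr(tokenizer.decode([tok_id])))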

Notes:

  1. This reproduces with both gemma-3-1b-pt and gemma-3-27b-pt.
  2. It appears to reproduce on both CPU and accelerators; the generated text differs slightly, but the problems are the same.
  3. gemma-2-9b (also a pretrained model) looks free of the above issues for the same prompt, also with greedy decoding: " Paris, France. It is the most visited monument in the world. It is 324 meters tall and was built in 1889. It is made of iron and has 1,665 steps. It is a". The snippet for gemma-2-9b is below.
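
Here is a minimal sketch of that gemma-2-9b snippet, assuming the same reproduction script with only the checkpoint and model class swapped:

import torch
from transformers import AutoTokenizer, Gemma2ForCausalLM

ckpt = "google/gemma-2-9b"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = Gemma2ForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

prompt = "Eiffel tower is located in"
model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    # Same greedy decoding settings as the gemma-3 run above
    generation = model.generate(**model_inputs, max_new_tokens=50, do_sample=False)
    generation = generation[0][input_len:]

print(tokenizer.decode(generation, skip_special_tokens=True))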
