Asynchronous GRPO


This trainer requires vllm>=0.17.1 and transformers>=5.2.0. For distributed training, only FSDP2 is supported (DeepSpeed ZeRO is not).

Currently, vllm and transformers have conflicting dependency constraints. To work around this, install vLLM first and then force-install transformers:

pip install 'vllm>=0.17.1'
pip install 'transformers>=5.2.0' --no-deps

Overview

AsyncGRPOTrainer implements the same GRPO algorithm as GRPOTrainer but decouples rollout generation from training. A background worker continuously streams completions from a vLLM server while the training loop consumes them, so generation and gradient updates overlap instead of alternating. The API mirrors that of GRPOTrainer. For full details on the GRPO method itself (advantage computation, KL estimation, loss formulation, reward functions, etc.), see the GRPO Trainer documentation. Not all GRPOTrainer features are available; refer to AsyncGRPOConfig for the supported parameters.

This trainer was contributed by Quentin Gallouédec and Amine Dirhoussi.

How it differs from GRPOTrainer

In the standard GRPOTrainer, generation and training are sequential: generate a batch, compute the loss, update weights, repeat. Even in vLLM colocate mode, where generation runs on the same GPUs, one phase must finish before the other begins.

AsyncGRPOTrainer separates these two concerns:

  • Rollout worker (background thread) — sends prompts to a vLLM server, scores completions with reward functions, computes advantages, and pushes ready-to-train samples into a queue.
  • Training loop (main process) — pulls samples from the queue, computes the clipped surrogate loss, and updates the model weights.
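The two components can be pictured as a standard producer/consumer pair. The sketch below is a simplified, hypothetical model and does not reflect the trainer's internals. The names fake_rollout, rollout_worker and training_loop are illustrative and are not TRL APIs. A background thread fills a bounded queue.Queue, and the main loop drains it in batches.

```python
import queue
import threading


def fake_rollout(prompt: str, policy_version: int) -> dict:
    # Hypothetical stand-in: a real worker would query the vLLM server,
    # score completions with reward functions and compute advantages.
    return {"prompt": prompt, "completion": prompt[::-1], "advantage": 0.0,
            "policy_version": policy_version}


def rollout_worker(prompts: list[str], sample_queue: queue.Queue, state: dict) -> None:
    """Background thread: generate, score and enqueue ready-to-train samples."""
    for prompt in prompts:
        sample = fake_rollout(prompt, state["policy_version"])
        sample_queue.put(sample)  # blocks while the queue is full (backpressure)
    sample_queue.put(None)  # sentinel: no more samples


def training_loop(sample_queue: queue.Queue, state: dict, batch_size: int) -> int:
    """Main process: pull samples from the queue and 'train' on full batches."""
    steps, batch = 0, []
    while True:
        sample = sample_queue.get()
        if sample is None:
            break
        batch.append(sample)
        if len(batch) == batch_size:
            # A real trainer would compute the clipped surrogate loss and step the optimizer here.
            steps += 1
            state["policy_version"] += 1  # pretend weights are synced after every step
            batch = []
    return steps


state = {"policy_version": 0}
sample_queue = queue.Queue(maxsize=16)
prompts = [f"prompt-{i}" for i in range(32)]

worker = threading.Thread(target=rollout_worker, args=(prompts, sample_queue, state), daemon=True)
worker.start()
num_steps = training_loop(sample_queue, state, batch_size=8)
worker.join()
print(num_steps)  # 4
```

Because the queue is bounded, a slow training loop throttles the worker rather than letting unconsumed samples pile up in memory.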

After every weight_sync_steps training steps, the updated weights are transferred to the vLLM server via NCCL so that subsequent generations reflect the latest policy.

Because generation and training run concurrently, the training samples may have been generated by a slightly older version of the model. The max_staleness parameter controls how many weight updates a sample can lag behind before being discarded.
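One way to read the weight-sync and staleness rules is to tag every sample with the version of the weights that produced it and drop samples that lag too far behind. The sketch below rests on two assumptions that the source does not confirm. It assumes versions count optimizer steps and that the vLLM server only moves to a new version on a sync. It also assumes a sample is discarded when its lag is strictly greater than max_staleness. All function names are hypothetical.

```python
def synced_version(train_step: int, weight_sync_steps: int) -> int:
    """Version of the weights on the vLLM server after `train_step` optimizer steps,
    assuming a sync happens every `weight_sync_steps` steps (hypothetical model)."""
    return (train_step // weight_sync_steps) * weight_sync_steps


def is_stale(sample_version: int, current_version: int, max_staleness: int) -> bool:
    """True if the sample lags more than `max_staleness` weight updates behind (assumed rule)."""
    return current_version - sample_version > max_staleness


def drop_stale(samples: list[dict], current_version: int, max_staleness: int) -> list[dict]:
    """Keep only samples that are still fresh enough to train on."""
    return [s for s in samples if not is_stale(s["policy_version"], current_version, max_staleness)]


# With weight_sync_steps=2, a sample generated after step 5 carries version 4.
print(synced_version(5, weight_sync_steps=2))  # 4

samples = [{"policy_version": v} for v in (0, 2, 4, 6)]
fresh = drop_stale(samples, current_version=7, max_staleness=4)
print([s["policy_version"] for s in fresh])  # [4, 6]
```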

The number of concurrent requests sent to the vLLM server is controlled by max_inflight_tasks. By default it is set automatically to max_staleness × per_device_train_batch_size × gradient_accumulation_steps × num_processes — the maximum number of samples the trainer can consume before they become stale. Generating more than this is wasteful since the excess samples will be discarded.
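The default for max_inflight_tasks is plain arithmetic. The helper below is a hypothetical illustration of the documented rule, not the trainer's code. It returns an explicit value unchanged, and when the value is -1 it multiplies out the four factors.

```python
def resolve_max_inflight_tasks(
    max_inflight_tasks: int,
    max_staleness: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_processes: int,
) -> int:
    """Return the explicit value, or the automatic value when max_inflight_tasks == -1."""
    if max_inflight_tasks != -1:
        return max_inflight_tasks
    # Maximum number of samples the trainer can consume before they become stale.
    return max_staleness * per_device_train_batch_size * gradient_accumulation_steps * num_processes


# Defaults on one training GPU: 4 * 8 * 1 * 1
print(resolve_max_inflight_tasks(-1, 4, 8, 1, 1))  # 32
# 4 training processes with gradient accumulation of 2: 4 * 8 * 2 * 4
print(resolve_max_inflight_tasks(-1, 4, 8, 2, 4))  # 256
```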

Quick start

# train_async_grpo.py
from datasets import load_dataset
from trl.experimental.async_grpo import AsyncGRPOTrainer
from trl.rewards import accuracy_reward

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen3-4B",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()

The vLLM server and the trainer must run on separate GPUs. Use CUDA_VISIBLE_DEVICES to partition your GPUs. For example, with 2 GPUs, you can run the vLLM server on GPU 0 and the trainer on GPU 1 as follows:

# Terminal 1: vLLM server on GPU 0 (dev mode + NCCL weight transfer are required)
CUDA_VISIBLE_DEVICES=0 VLLM_SERVER_DEV_MODE=1 vllm serve Qwen/Qwen3-4B \
    --max-model-len 4096 \
    --logprobs-mode processed_logprobs \
    --weight-transfer-config '{"backend":"nccl"}'

Set --max-model-len to the maximum total sequence length (prompt + completion) you expect. A lower value reduces GPU memory usage on the server, freeing more memory for the KV cache and increasing throughput. A good starting point is the token length of your longest prompt plus max_completion_length from your config.
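The sizing guideline can be written as a one-line helper. This is a hypothetical sketch, and the prompt lengths are hard-coded example values. In practice you would measure them in tokens with the model's tokenizer, applying the chat template first for conversational data.

```python
def suggested_max_model_len(prompt_token_lengths: list[int], max_completion_length: int) -> int:
    """Longest prompt (in tokens) plus the completion budget."""
    return max(prompt_token_lengths) + max_completion_length


print(suggested_max_model_len([120, 512, 300], max_completion_length=2048))  # 2560
```

For example, prompts of up to 2048 tokens combined with the default max_completion_length=2048 give the 4096 used in the command above.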

# Terminal 2: training on GPU 1
CUDA_VISIBLE_DEVICES=1 accelerate launch train_async_grpo.py

Design philosophy

This trainer is intentionally kept minimal and is not meant to grow into a general-purpose solution. If you need a feature that is not supported, we recommend cloning the repository and adapting the trainer to your needs directly. New features will only be considered when there is significant community demand.

AsyncGRPOConfig

class trl.experimental.async_grpo.AsyncGRPOConfig


( output_dir: str | None = None per_device_train_batch_size: int = 8 num_train_epochs: float = 3.0 max_steps: int = -1 learning_rate: float = 1e-06 lr_scheduler_type: transformers.trainer_utils.SchedulerType | str = 'linear' lr_scheduler_kwargs: dict | str | None = None warmup_steps: float = 0 optim: transformers.training_args.OptimizerNames | str = 'adamw_torch_fused' optim_args: str | None = None weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 optim_target_modules: None | str | list[str] = None gradient_accumulation_steps: int = 1 average_tokens_across_devices: bool = True max_grad_norm: float = 1.0 label_smoothing_factor: float = 0.0 bf16: bool | None = None fp16: bool = False bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: bool | None = None gradient_checkpointing: bool = True gradient_checkpointing_kwargs: dict[str, typing.Any] | str | None = None torch_compile: bool = False torch_compile_backend: str | None = None torch_compile_mode: str | None = None use_liger_kernel: bool = False liger_kernel_config: dict[str, bool] | None = None use_cache: bool = False neftune_noise_alpha: float | None = None torch_empty_cache_steps: int | None = None auto_find_batch_size: bool = False logging_strategy: transformers.trainer_utils.IntervalStrategy | str = 'steps' logging_steps: float = 1 logging_first_step: bool = False log_on_each_node: bool = True logging_nan_inf_filter: bool = True include_num_input_tokens_seen: str | bool = 'no' log_level: str = 'passive' log_level_replica: str = 'warning' disable_tqdm: bool | None = None report_to: None | str | list[str] = 'none' run_name: str | None = None project: str = 'huggingface' trackio_space_id: str | None = 'trackio' eval_strategy: transformers.trainer_utils.IntervalStrategy | str = 'no' eval_steps: float | None = None eval_delay: float = 0 per_device_eval_batch_size: int = 8 prediction_loss_only: bool = False eval_on_start: bool = False 
eval_do_concat_batches: bool = True eval_use_gather_object: bool = False eval_accumulation_steps: int | None = None include_for_metrics: list = <factory> batch_eval_metrics: bool = False save_only_model: bool = False save_strategy: transformers.trainer_utils.SaveStrategy | str = 'steps' save_steps: float = 500 save_on_each_node: bool = False save_total_limit: int | None = None enable_jit_checkpoint: bool = False push_to_hub: bool = False hub_token: str | None = None hub_private_repo: bool | None = None hub_model_id: str | None = None hub_strategy: transformers.trainer_utils.HubStrategy | str = 'every_save' hub_always_push: bool = False hub_revision: str | None = None load_best_model_at_end: bool = False metric_for_best_model: str | None = None greater_is_better: bool | None = None ignore_data_skip: bool = False restore_callback_states_from_checkpoint: bool = False full_determinism: bool = False seed: int = 42 data_seed: int | None = None use_cpu: bool = False accelerator_config: dict | str | None = None parallelism_config: accelerate.parallelism_config.ParallelismConfig | None = None dataloader_drop_last: bool = False dataloader_num_workers: int = 0 dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False dataloader_prefetch_factor: int | None = None remove_unused_columns: bool = True label_names: list[str] | None = None train_sampling_strategy: str = 'random' length_column_name: str = 'length' ddp_find_unused_parameters: bool | None = None ddp_bucket_cap_mb: int | None = None ddp_broadcast_buffers: bool | None = None ddp_backend: str | None = None ddp_timeout: int = 1800 fsdp: list[transformers.trainer_utils.FSDPOption] | str | None = None fsdp_config: dict[str, typing.Any] | str | None = None deepspeed: dict | str | None = None debug: str | list[transformers.debug_utils.DebugOption] = '' skip_memory_metrics: bool = True do_train: bool = False do_eval: bool = False do_predict: bool = False resume_from_checkpoint: str | None = None 
warmup_ratio: float | None = None logging_dir: str | None = None local_rank: int = -1 num_generations: int = 8 max_completion_length: int = 2048 temperature: float = 1.0 chat_template_kwargs: dict | None = None max_tool_calling_iterations: int | None = None vllm_server_base_url: str = 'http://localhost:8000' vllm_server_timeout: float = 240.0 request_timeout: int = 600 epsilon: float = 0.2 epsilon_high: float = 0.2 max_inflight_tasks: int = -1 max_staleness: int = 4 queue_maxsize: int = 1024 weight_sync_steps: int = 1 log_completions: bool = False num_completions_to_print: int = 3 )

Parameters that control generation

  • num_generations (int, optional, defaults to 8) — Number of generations per prompt to sample.
  • max_completion_length (int, optional, defaults to 2048) — Maximum number of tokens to generate per completion.
  • temperature (float, optional, defaults to 1.0) — Temperature for sampling. The higher the temperature, the more random the completions.
  • chat_template_kwargs (dict[str, Any], optional) — Additional keyword arguments to pass to the apply_chat_template function when generating completions.
  • max_tool_calling_iterations (int, optional) — Maximum number of tool-calling turns when training an agent. If None, there is no limit and generation stops when the model generates a response turn with no tool calls or when the total response length reaches max_completion_length.
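The stopping rules for agent rollouts can be sketched as a loop. This is a hypothetical illustration and not the trainer's implementation. The turn format, the generate_turn callable and the order of the checks are all assumptions. Counting only generated tokens toward max_completion_length, and not tool outputs, is also an assumption.

```python
from typing import Any, Callable, Optional

# Hypothetical turn format returned by `generate_turn`:
# {"content": str, "num_tokens": int, "tool_calls": [(tool_name, kwargs), ...]}
Turn = dict[str, Any]


def run_tool_loop(
    generate_turn: Callable[[list], Turn],
    tools: dict[str, Callable[..., Any]],
    messages: list[dict],
    max_completion_length: int,
    max_tool_calling_iterations: Optional[int] = None,
) -> list[dict]:
    total_tokens = 0
    iterations = 0
    while True:
        turn = generate_turn(messages)
        total_tokens += turn["num_tokens"]
        messages.append({"role": "assistant", "content": turn["content"]})
        if not turn["tool_calls"]:
            break  # a response turn with no tool calls ends the rollout
        if total_tokens >= max_completion_length:
            break  # the completion budget is used up
        if max_tool_calling_iterations is not None and iterations >= max_tool_calling_iterations:
            break  # tool-calling turn limit reached
        for name, kwargs in turn["tool_calls"]:
            result = tools[name](**kwargs)
            messages.append({"role": "tool", "name": name, "content": str(result)})
        iterations += 1
    return messages


def add(a: int, b: int) -> int:
    return a + b


def always_calls_add(messages: list[dict]) -> Turn:
    # Fake model that requests a tool call on every turn.
    return {"content": "calling add", "num_tokens": 10, "tool_calls": [("add", {"a": 1, "b": 2})]}


history = run_tool_loop(
    always_calls_add,
    tools={"add": add},
    messages=[{"role": "user", "content": "What is 1 + 2?"}],
    max_completion_length=1000,
    max_tool_calling_iterations=2,
)
print(sum(m["role"] == "tool" for m in history))  # 2
```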

Parameters that control the vLLM server

  • vllm_server_base_url (str, optional, defaults to "http://localhost:8000"): Base URL of the vLLM server used for generation.
  • vllm_server_timeout (float, optional, defaults to 240.0) — Total timeout duration in seconds to wait for the vLLM server to be ready.
  • request_timeout (int, optional, defaults to 600) — Timeout in seconds for individual HTTP requests to the vLLM server.
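vllm_server_timeout is a total budget for waiting until the server is ready, while request_timeout applies to each individual HTTP request. The sketch below illustrates that distinction with a polling loop. It assumes readiness is checked by polling the vLLM server's GET /health endpoint. The trainer's actual readiness check may differ, and both function names are hypothetical.

```python
import time
import urllib.request
from typing import Callable


def http_health_probe(base_url: str, request_timeout: float = 5.0) -> bool:
    """Return True if GET {base_url}/health answers with HTTP 200 (assumed endpoint)."""
    try:
        with urllib.request.urlopen(f"{base_url}/health", timeout=request_timeout) as response:
            return response.status == 200
    except OSError:  # connection refused, timeouts and HTTP errors (URLError subclasses OSError)
        return False


def wait_for_server(probe: Callable[[], bool], total_timeout: float, poll_interval: float = 2.0) -> bool:
    """Poll `probe` until it succeeds or `total_timeout` seconds have elapsed."""
    deadline = time.monotonic() + total_timeout
    while time.monotonic() < deadline:
        if probe():
            return True
        time.sleep(poll_interval)
    return False


# Usage against a running server (vllm_server_timeout defaults to 240.0):
# wait_for_server(lambda: http_health_probe("http://localhost:8000"), total_timeout=240.0)
```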

Parameters that control the training

  • epsilon (float, optional, defaults to 0.2) — Lower-bound epsilon value for clipping.
  • epsilon_high (float, optional, defaults to 0.2) — Upper-bound epsilon value for clipping.
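epsilon and epsilon_high bound the importance ratio from below and above in the clipped surrogate objective. The per-token sketch below follows the standard PPO-style formulation with a separate upper bound. It is an illustration, not the trainer's loss code. Token-level aggregation across a batch is described in the GRPO Trainer documentation and is not reproduced here.

```python
import math


def clipped_surrogate_loss(
    logp_new: float, logp_old: float, advantage: float, epsilon: float = 0.2, epsilon_high: float = 0.2
) -> float:
    """Per-token clipped objective, negated so that it can be minimized."""
    ratio = math.exp(logp_new - logp_old)  # importance ratio pi_theta / pi_old
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon_high)
    return -min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage with ratio 1.5: the ratio is capped at 1 + epsilon_high.
print(round(clipped_surrogate_loss(math.log(1.5), 0.0, advantage=1.0), 4))  # -1.2
print(round(clipped_surrogate_loss(math.log(1.5), 0.0, advantage=1.0, epsilon_high=0.28), 4))  # -1.28
```

Setting epsilon_high above epsilon widens the clip range only on the upside. This lets the ratio of tokens with positive advantage grow further before their gradient is cut off.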

Parameters that control the async rollout pipeline

  • max_inflight_tasks (int, optional, defaults to -1) — Maximum number of concurrent generation tasks sent to the vLLM server. Defaults to -1 (auto), which sets it to max_staleness * per_device_train_batch_size * gradient_accumulation_steps * num_processes. If using tool-use environments, you may want to set this manually based on how many parallel environments you can run.
  • max_staleness (int, optional, defaults to 4) — Maximum number of weight update steps a rollout sample can lag behind the current model version before being discarded.
  • queue_maxsize (int, optional, defaults to 1024) — Maximum number of rollout samples to buffer in the rollout queue.
  • weight_sync_steps (int, optional, defaults to 1) — Number of training steps between weight synchronizations to the vLLM server.
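A cap such as max_inflight_tasks is commonly enforced with a semaphore. The sketch below uses asyncio.Semaphore to show what the cap means: no matter how many prompts are pending, at most max_inflight_tasks requests run at once. Whether the trainer uses asyncio internally is an assumption. The sleep stands in for an HTTP request to the vLLM server or for a tool-use environment step.

```python
import asyncio


async def generate(prompt: str, semaphore: asyncio.Semaphore, stats: dict) -> str:
    async with semaphore:  # waits while max_inflight_tasks requests are already running
        stats["inflight"] += 1
        stats["peak"] = max(stats["peak"], stats["inflight"])
        await asyncio.sleep(0.01)  # placeholder for the request to the vLLM server
        stats["inflight"] -= 1
        return f"completion for {prompt}"


async def run_rollouts(prompts: list[str], max_inflight_tasks: int) -> tuple[list[str], int]:
    semaphore = asyncio.Semaphore(max_inflight_tasks)
    stats = {"inflight": 0, "peak": 0}
    results = await asyncio.gather(*(generate(p, semaphore, stats) for p in prompts))
    return results, stats["peak"]


results, peak = asyncio.run(run_rollouts([f"p{i}" for i in range(20)], max_inflight_tasks=4))
print(len(results))  # 20
```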

Parameters that control the logging

  • log_completions (bool, optional, defaults to False) — Whether to log a sample of (prompt, completion) pairs every logging_steps steps.
  • num_completions_to_print (int, optional, defaults to 3) — Number of completions to print when log_completions=True.

Configuration class for the AsyncGRPOTrainer.

This class includes only the parameters that are specific to asynchronous GRPO training. For a full list of training arguments, please refer to the TrainingArguments documentation. Note that default values in this class may differ from those in TrainingArguments.

These parameters have default values different from TrainingArguments:

  • logging_steps: Defaults to 1 instead of 500.
  • gradient_checkpointing: Defaults to True instead of False.
  • bf16: Defaults to True if fp16 is not set, instead of False.
  • learning_rate: Defaults to 1e-6 instead of 5e-5.

AsyncGRPOTrainer

class trl.experimental.async_grpo.AsyncGRPOTrainer


( model: str reward_funcs: collections.abc.Callable[..., list[float]] | list[collections.abc.Callable[..., list[float]]] args: trl.experimental.async_grpo.async_grpo_config.AsyncGRPOConfig | None = None train_dataset: datasets.arrow_dataset.Dataset | datasets.iterable_dataset.IterableDataset | None = None processing_class: transformers.tokenization_utils_base.PreTrainedTokenizerBase | None = None callbacks: list[transformers.trainer_callback.TrainerCallback] | None = None optimizers: tuple = (None, None) tools: list[collections.abc.Callable] | None = None environment_factory: collections.abc.Callable[[], trl.experimental.async_grpo.async_grpo_trainer._SupportsReset] | None = None rollout_worker: trl.experimental.async_grpo.async_grpo_trainer.RolloutWorkerProtocol | None = None )

Parameters

  • model (str) — Model to be trained. Must be a string: either the model id of a pretrained model hosted in a model repo on huggingface.co, or a path to a directory containing model weights saved using save_pretrained, e.g., './my_model_directory/'. The model is loaded using from_pretrained. The model name is also used to identify the model on the vLLM server used for generation.
  • reward_funcs (RewardFunc | list[RewardFunc]) — Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward functions with the prompts and completions and sum the rewards. Can be either:

    • A single reward function: The function is provided with the prompts and the generated completions, plus any additional columns in the dataset. It should return a list of rewards. Reward functions can be either synchronous or asynchronous and can also return None when the reward is not applicable to those samples. This is useful for multi-task training where different reward functions apply to different types of samples. When a reward function returns None for a sample, that reward function is excluded from the reward calculation for that sample. For more details, see Using a custom reward function.
    • A list of reward functions, where each item is a reward function as described above. Rewards from all functions are summed.
  • args (AsyncGRPOConfig, optional) — Configuration for this trainer. If None, a default configuration is used.
  • train_dataset (Dataset or IterableDataset) — Dataset to use for training. It must include a "prompt" column. Additional columns are not fed to the model, but they are passed to the reward functions (see reward_funcs above). The format of the samples can be either:

    • Standard: Each sample contains plain text.
    • Conversational: Each sample contains structured messages (e.g., role and content).
  • processing_class (PreTrainedTokenizerBase, optional) — Processing class used to process the data. The padding side must be set to "left". If None, the processing class is loaded from the model’s name with from_pretrained. A padding token, tokenizer.pad_token, must be set. If the processing class has not set a padding token, tokenizer.eos_token will be used as the default.
  • callbacks (list of TrainerCallback, optional) — List of callbacks to customize the training loop. These are added to the list of default callbacks.

    If you want to remove one of the default callbacks used, use the remove_callback method.

  • optimizers (tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None], optional, defaults to (None, None)) — A tuple containing the optimizer and the scheduler to use. Will default to an instance of AdamW on your model and a scheduler given by get_linear_schedule_with_warmup controlled by args.
  • tools (list of Callable, optional) — A list of callable tool functions (sync or async) that the model can invoke during generation. Each tool should be a standard Python function with properly type-hinted arguments and return values, and a Google-style docstring describing its purpose, arguments, and return value. For more details, see: https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function’s name, type hints, and docstring to determine how to call it. Ensure that the model’s chat template supports tool use and that it has been fine-tuned for tool calling.
  • environment_factory (EnvironmentFactory, optional) — A callable that creates and returns an environment instance. The environment class should define methods that can be invoked as tools during generation. Each method should comply with the same requirements as the tools described above. If environment_factory is provided, an instance of the environment is created for each generation in the batch, allowing for parallel and independent interactions. The environment must also implement a callable reset method that can be used to reset state between generations. The reset method should return either None or a string: when it returns a string, that string is appended to the last user message before generation. This feature is experimental and may change or be removed at any time without prior notice.
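The reward_funcs description implies a few concrete rules. Each function receives the completions plus extra dataset columns, a function may return None for samples it does not apply to, and rewards are summed per sample across functions. The sketch below assumes the standard format, where completions are plain strings. The task_type column and both reward functions are hypothetical. combine_rewards reflects our reading of the documented summing behavior and is not the trainer's code, and treating an all-None sample as 0.0 is an assumption.

```python
import re
from typing import Optional


def boxed_math_reward(completions: list[str], solution: list[str], task_type: list[str], **kwargs) -> list[Optional[float]]:
    """Hypothetical reward: 1.0 if the \\boxed{...} answer matches, None for non-math samples."""
    rewards: list[Optional[float]] = []
    for completion, answer, task in zip(completions, solution, task_type):
        if task != "math":
            rewards.append(None)  # not applicable: excluded for this sample
            continue
        match = re.search(r"\\boxed\{([^}]*)\}", completion)
        rewards.append(1.0 if match and match.group(1).strip() == answer.strip() else 0.0)
    return rewards


def length_penalty_reward(completions: list[str], **kwargs) -> list[float]:
    """Hypothetical reward that applies to every sample."""
    return [-0.001 * len(completion) for completion in completions]


def combine_rewards(per_function_rewards: list[list[Optional[float]]]) -> list[float]:
    """Sum rewards across functions, skipping None entries for each sample."""
    totals = []
    for sample_rewards in zip(*per_function_rewards):
        valid = [r for r in sample_rewards if r is not None]
        totals.append(sum(valid) if valid else 0.0)  # all-None -> 0.0 is an assumption
    return totals


batch = {
    "completions": ["The answer is \\boxed{4}", "Hello there!"],
    "solution": ["4", ""],
    "task_type": ["math", "chat"],
}
rewards = combine_rewards([boxed_math_reward(**batch), length_penalty_reward(**batch)])
```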

Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models. This trainer is the asynchronous version of GRPO, where generation is offloaded to an external vLLM server that runs asynchronously alongside training, decoupling rollout from the gradient update loop.

Example:

from trl.experimental.async_grpo import AsyncGRPOTrainer
from trl.rewards import accuracy_reward
from datasets import load_dataset

dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

trainer = AsyncGRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    train_dataset=dataset,
)
trainer.train()
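The tools and environment_factory parameters expect plain Python callables with type hints and Google-style docstrings. The sketch below shows one standalone tool and one environment class with a reset method. CounterEnv, environment_factory and multiply are hypothetical. Which environment methods the trainer exposes as tools (for example, whether reset is excluded) is not specified here, and the trainer calls are shown only as comments.

```python
class CounterEnv:
    """Hypothetical environment whose public methods are meant to be exposed as tools."""

    def __init__(self) -> None:
        self.value = 0

    def reset(self) -> str:
        """Reset the state; the returned string is appended to the last user message."""
        self.value = 0
        return "The counter starts at 0."

    def increment(self, amount: int) -> int:
        """Add an amount to the counter.

        Args:
            amount: How much to add to the counter.

        Returns:
            The new counter value.
        """
        self.value += amount
        return self.value


def environment_factory() -> CounterEnv:
    # One instance per generation, so every rollout gets its own independent state.
    return CounterEnv()


def multiply(a: int, b: int) -> int:
    """Multiply two integers.

    Args:
        a: The first factor.
        b: The second factor.

    Returns:
        The product of a and b.
    """
    return a * b


# Hypothetical wiring (not executed here):
# trainer = AsyncGRPOTrainer(model=..., reward_funcs=..., train_dataset=..., tools=[multiply])
# trainer = AsyncGRPOTrainer(model=..., reward_funcs=..., train_dataset=..., environment_factory=environment_factory)
```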