RL Training with GRPO
Train a GUI agent using TRL's GRPOTrainer with CUA-Bench environments on Modal
Train a multimodal GUI agent using GRPO (Group Relative Policy Optimization) with TRL and CUA-Bench on Modal.
How It Works
The training follows TRL's OpenEnv pattern:
┌─────────────────────────────────────────────────────────────┐
│ GRPO Training Loop │
├─────────────────────────────────────────────────────────────┤
│ 1. rollout_func() called with prompts │
│ ├─ Reset CUA-Bench environment │
│ ├─ Generate actions via vLLM │
│ ├─ Step environment, collect rewards │
│ └─ Return {prompt_ids, completion_ids, logprobs, ...} │
│ │
│ 2. Reward function receives env_reward from rollout │
│ │
│ 3. GRPO updates policy using rewards │
└─────────────────────────────────────────────────────────────┘
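Concretely, the rollout function you will write in Step 2 returns a single dict of lists, one entry per prompt, with the fields shown in the diagram. A minimal sketch of that shape (field names match the full script below):

# Shape of the dict returned by rollout_func() in this guide; an optional
# "images" key carries screenshots for multimodal prompts.
def rollout_func(prompts: list[str], trainer) -> dict[str, list]:
    return {
        "prompt_ids": [...],      # token ids of each prompt
        "completion_ids": [...],  # token ids generated by vLLM
        "logprobs": [...],        # per-token logprobs of each completion
        "env_reward": [...],      # final environment reward per episode
    }

Step 1: Install Modal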
pip install modal
modal setup

Optionally, create a wandb secret for logging:
modal secret create wandb-secret WANDB_API_KEY=your_key

Step 2: Create the Training Script
Create modal_grpo_training.py:
#!/usr/bin/env python3
"""GRPO Training with CUA-Bench on Modal.
This script runs GRPO (Group Relative Policy Optimization) training using
TRL's GRPOTrainer with CUA-Bench environments on Modal's cloud infrastructure.
Usage:
# Setup Modal
pip install modal
modal setup
# Run training
modal run modal_grpo_training.py
# With custom settings
modal run modal_grpo_training.py --num-workers 4 --max-steps 1000
"""
from __future__ import annotations
from dataclasses import dataclass
import modal
# =============================================================================
# Training Configuration
# =============================================================================
@dataclass
class TrainerConfig:
"""Configuration for GRPO training hyperparameters."""
# Model
model_id: str = "Qwen/Qwen3-VL-2B-Instruct"
task_prompt: str = "Complete the computer task successfully."
# Environment
num_workers: int = 2
max_steps: int = 10
max_history: int = 2
# Generation
num_generations: int = 4
max_completion_length: int = 256
temperature: float = 0.7
# Training
learning_rate: float = 5e-6
gradient_accumulation_steps: int = 4
per_device_train_batch_size: int = 1
num_train_epochs: int = 1
warmup_steps: int = 10
# vLLM
use_vllm: bool = True
vllm_mode: str = "colocate"
vllm_gpu_memory_utilization: float = 0.4
vllm_max_model_length: int = 32768
# Checkpointing
output_dir: str = "/checkpoints/grpo-output"
save_strategy: str = "steps"
save_steps: int = 100
# Logging
logging_steps: int = 1
use_wandb: bool = True
wandb_project: str = "cua-bench-grpo"
# Dataset
dataset_size: int = 1000
# Misc
bf16: bool = True
debug: bool = False
DEFAULT_CONFIG = TrainerConfig()
# =============================================================================
# Modal App Configuration
# =============================================================================
app = modal.App("cua-bench-grpo-training")
image = (
modal.Image.debian_slim(python_version="3.11")
.apt_install("git", "chromium", "chromium-driver")
.uv_pip_install(
"trl[vllm]==0.27.1",
"datasets==3.5.1",
"verl",
"pillow",
"requests",
"fastapi",
"uvicorn",
"playwright",
"cua-bench",
)
.run_commands("playwright install chromium")
# .add_local_dir("cua_bench", remote_path="/root/cua_bench")
)
checkpoint_volume = modal.Volume.from_name("cua-bench-grpo-checkpoints", create_if_missing=True)
# Optional: wandb secret for logging
# Create with: modal secret create wandb-secret WANDB_API_KEY=your_key
# =============================================================================
# Constants
# =============================================================================
SYSTEM_PROMPT = """You are a GUI agent. Complete tasks by interacting with the screen.
Available actions:
- click(x, y): Click at coordinates (0-1000 range)
- type(text): Type text
- hotkey(key1, key2, ...): Press keyboard shortcut
- scroll(x, y, direction): Scroll up/down
- wait(): Wait 1 second
- done(): Mark task complete
Reply with thinking in <|think_start|>...<|think_end|> tags,
then action in <|action_start|>...<|action_end|> tags.
Once you complete the task, output a done() action.
Example:
<|think_start|>I see a blue button. I'll click it.<|think_end|>
<|action_start|>click(500, 300)<|action_end|>"""
CLICK_TARGET_HTML = """<!DOCTYPE html>
<html>
<head>
<style>
body {{ margin: 0; display: flex; justify-content: center; align-items: center; height: 100vh; background: #f0f0f0; }}
.target {{ position: absolute; left: {x}px; top: {y}px; padding: 12px 24px; background: #3b82f6; color: white; border: none; border-radius: 8px; cursor: pointer; }}
</style>
</head>
<body>
<h1>Click the Button</h1>
<button class="target" onclick="window.__clicked=true;">Click Me</button>
<script>window.__clicked = false;</script>
</body>
</html>"""
CLICK_TARGET_TASK_PY = '''
import cua_bench as cb
import random
HTML = """{html}"""
pid = None
@cb.tasks_config(split="train")
def get_tasks():
return [cb.Task(
description="Click the blue 'Click Me' button",
computer={{"provider": "simulated", "setup_config": {{"width": 1024, "height": 720}}}},
metadata={{"x": random.randint(100, 350), "y": random.randint(100, 250)}}
) for _ in range(50)]
@cb.setup_task(split="train")
async def setup(task, session):
global pid
html = HTML.format(x=task.metadata["x"], y=task.metadata["y"])
pid = await session.launch_window(html=html, title="Click Target", width=512, height=384)
@cb.evaluate_task(split="train")
async def evaluate(task, session):
return [1.0 if await session.execute_javascript(pid, "window.__clicked") else 0.0]
'''
# =============================================================================
# Training Function
# =============================================================================
@app.function(
image=image,
gpu="H100",
timeout=60 * 60 * 4, # 4 hours
volumes={"/checkpoints": checkpoint_volume},
secrets=[modal.Secret.from_name("wandb-secret")],
)
def train_grpo(
model_id: str = DEFAULT_CONFIG.model_id,
task_prompt: str = DEFAULT_CONFIG.task_prompt,
num_workers: int = DEFAULT_CONFIG.num_workers,
max_steps: int = DEFAULT_CONFIG.max_steps,
max_history: int = DEFAULT_CONFIG.max_history,
num_generations: int = DEFAULT_CONFIG.num_generations,
dataset_size: int = DEFAULT_CONFIG.dataset_size,
max_completion_length: int = DEFAULT_CONFIG.max_completion_length,
temperature: float = DEFAULT_CONFIG.temperature,
learning_rate: float = DEFAULT_CONFIG.learning_rate,
gradient_accumulation_steps: int = DEFAULT_CONFIG.gradient_accumulation_steps,
save_steps: int = DEFAULT_CONFIG.save_steps,
use_wandb: bool = DEFAULT_CONFIG.use_wandb,
wandb_project: str = DEFAULT_CONFIG.wandb_project,
debug: bool = DEFAULT_CONFIG.debug,
):
"""Run GRPO training with CUA-Bench environments."""
import asyncio
import base64
import io
import re
import tempfile
from pathlib import Path
from datasets import Dataset
from PIL import Image
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams
from cua_bench.workers import CBEnvWorkerClient, cleanup_workers, create_workers
# -------------------------------------------------------------------------
# Multimodal rollout generation (based on TRL's generate_rollout_completions)
# -------------------------------------------------------------------------
def generate_rollout_completions_multimodal(
trainer,
prompts: list[str],
images: list[list[Image.Image]] | None = None,
) -> list[dict]:
"""
Generate completions for multimodal prompts using vLLM in colocate mode.
Args:
trainer: GRPOTrainer instance with vLLM configured
prompts: List of text prompts
images: Optional list of image lists, one per prompt
Returns:
List of dicts with prompt_ids, completion_ids, logprobs, and text
"""
if not prompts:
return []
if not trainer.use_vllm or trainer.vllm_mode != "colocate":
raise RuntimeError("Multimodal rollouts require vLLM in colocate mode.")
# Build sampling params
sampling_params = SamplingParams(
n=1,
temperature=trainer.temperature,
top_k=trainer.top_k,
min_p=0.0 if trainer.min_p is None else trainer.min_p,
max_tokens=trainer.max_completion_length,
logprobs=0,
)
if trainer.repetition_penalty is not None:
sampling_params.repetition_penalty = trainer.repetition_penalty
if trainer.top_p is not None:
sampling_params.top_p = trainer.top_p
# Wake up vLLM if sleep mode is enabled
if trainer.args.vllm_enable_sleep_mode:
trainer.llm.wake_up(tags=["kv_cache"])
trainer.llm.collective_rpc("reload_weights")
# Build inputs with multimodal data if images provided
if images:
inputs = []
for i, prompt in enumerate(prompts):
prompt_images = images[i] if i < len(images) else []
if prompt_images:
inputs.append({
"prompt": prompt,
"multi_modal_data": {"image": prompt_images},
})
else:
inputs.append({"prompt": prompt})
vllm_outputs = trainer.llm.generate(inputs, sampling_params=sampling_params, use_tqdm=False)
else:
vllm_outputs = trainer.llm.generate(prompts, sampling_params=sampling_params, use_tqdm=False)
# Process outputs
results = []
for request in vllm_outputs:
if not request.outputs:
results.append({
"prompt_ids": request.prompt_token_ids,
"completion_ids": [],
"logprobs": [],
"text": "",
})
continue
sequence = request.outputs[0]
logprobs = [
next(iter(token_logprob.values())).logprob
for token_logprob in sequence.logprobs
] if sequence.logprobs else []
results.append({
"prompt_ids": request.prompt_token_ids,
"completion_ids": list(sequence.token_ids),
"logprobs": logprobs,
"text": sequence.text,
})
# Sleep vLLM if sleep mode is enabled
if trainer.args.vllm_enable_sleep_mode:
trainer.llm.sleep(level=2)
return results
# -------------------------------------------------------------------------
# Helper functions
# -------------------------------------------------------------------------
def decode_image(b64_str: str) -> Image.Image:
img = Image.open(io.BytesIO(base64.b64decode(b64_str)))
return img.convert("RGB") if img.mode != "RGB" else img
def extract_images(obs: str) -> tuple[str, list[Image.Image]]:
"""Extract base64 images from obs and replace with vLLM-compatible placeholder."""
images = []
pattern = r"<\|vision_start\|>(.*?)<\|vision_end\|>"
def repl(m):
images.append(decode_image(m.group(1)))
# Use Qwen2-VL's placeholder format that vLLM expects
return "<|vision_start|><|image_pad|><|vision_end|>"
cleaned = re.sub(pattern, repl, obs, flags=re.DOTALL)
return cleaned, images
def make_prompt(
tok,
instruction: str,
obs: str,
step: int,
history: list[tuple[str, str]] | None = None,
) -> tuple[str, list[Image.Image]]:
"""Build prompt with optional history of previous observations and actions.
Args:
tok: Tokenizer
instruction: Task instruction
obs: Current observation (may contain base64 images)
step: Current step number
history: List of (prev_obs, prev_action) tuples
"""
all_images = []
# Build history section
history_parts = []
if history:
for i, (prev_obs, prev_action) in enumerate(history):
cleaned_prev_obs, prev_images = extract_images(prev_obs)
all_images.extend(prev_images)
history_parts.append(f"Step {step - len(history) + i + 1}:\n{cleaned_prev_obs}\nAction: {prev_action}")
# Current observation
cleaned_obs, curr_images = extract_images(obs)
all_images.extend(curr_images)
# Build user content
if history_parts:
history_text = "\n\n".join(history_parts)
user_content = f"Task: {instruction}\n\nHistory:\n{history_text}\n\nStep {step + 1} (current):\n{cleaned_obs}\n\nWhat action?"
else:
user_content = f"Task: {instruction}\n\nStep {step + 1}:\n{cleaned_obs}\n\nWhat action?"
messages = [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
]
prompt = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
return prompt, all_images
def parse_action(response: str) -> str:
match = re.search(r"<\|action_start\|>(.*?)<\|action_end\|>", response, re.DOTALL)
return match.group(1).strip() if match else "wait()"
# -------------------------------------------------------------------------
# Episode rollout
# -------------------------------------------------------------------------
def rollout_episode(trainer, env, tok, prompt_text: str, max_hist: int = 2) -> dict:
env_ret, meta = env.reset()
# Get instruction from env.prompt (set during reset)
instruction = env.prompt["instruction"] if env.prompt else prompt_text
all_prompt_ids, all_completion_ids, all_logprobs = [], [], []
all_images = []
step_rewards = []
history = [] # List of (screenshot, action) tuples
for step in range(max_steps):
if env_ret.get("done", False):
break
# Get current screenshot from env.prompt["steps"][-1]
# steps contains: [screenshot0, action0, screenshot1, action1, ...]
current_screenshot = env.prompt["steps"][-1] if env.prompt else ""
# Get recent history (last max_hist steps)
recent_history = history[-max_hist:] if history else None
prompt, images = make_prompt(tok, instruction, current_screenshot, step, history=recent_history)
if images:
all_images.extend(images)
rollout_output = generate_rollout_completions_multimodal(
trainer, [prompt], images=[images] if images else None
)[0]
all_prompt_ids.extend(rollout_output["prompt_ids"])
all_completion_ids.extend(rollout_output["completion_ids"])
all_logprobs.extend(rollout_output["logprobs"])
completion = rollout_output.get("text") or tok.decode(
rollout_output["completion_ids"], skip_special_tokens=True
)
# Add current screenshot and action to history for next step
action = parse_action(completion)
history.append((current_screenshot, action))
if debug:
print(f" Step {step + 1}: {action}")
env_ret, _ = env.step(completion)
step_rewards.append(float(env_ret.get("reward", 0.0)))
result = {
"prompt_ids": all_prompt_ids,
"completion_ids": all_completion_ids,
"logprobs": all_logprobs,
"env_reward": step_rewards[-1] if step_rewards else 0.0,
}
if all_images:
result["images"] = all_images
return result
# -------------------------------------------------------------------------
# Reward function
# -------------------------------------------------------------------------
def reward_evaluator_func(completions: list[str], env_reward=None, **_) -> list[float]:
if env_reward is not None:
return [float(r) for r in env_reward]
return [0.0] * len(completions)
# -------------------------------------------------------------------------
# Create task
# -------------------------------------------------------------------------
def create_task(tmp_dir: Path) -> Path:
task_dir = tmp_dir / "click-target"
task_dir.mkdir(exist_ok=True)
task_code = CLICK_TARGET_TASK_PY.format(html=CLICK_TARGET_HTML)
(task_dir / "main.py").write_text(task_code)
return task_dir
# -------------------------------------------------------------------------
# Main training logic
# -------------------------------------------------------------------------
import os
# Initialize wandb if available
wandb_enabled = False
if use_wandb and os.environ.get("WANDB_API_KEY"):
try:
import wandb
wandb.init(
project=wandb_project,
config={
"model_id": model_id,
"num_workers": num_workers,
"max_steps": max_steps,
"num_generations": num_generations,
"learning_rate": learning_rate,
"temperature": temperature,
"dataset_size": dataset_size,
},
)
wandb_enabled = True
print("Wandb logging enabled")
except ImportError:
print("Wandb not installed, skipping logging")
print("=" * 60)
print("GRPO Training with CUA-Bench on Modal")
print("=" * 60)
print(f"Model: {model_id}")
print(f"Workers: {num_workers}")
print(f"Max steps: {max_steps}")
print(f"Generations: {num_generations}")
print(f"Wandb: {wandb_enabled}")
print("=" * 60)
# Setup tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Create temp task
tmp_dir = Path(tempfile.mkdtemp(prefix="grpo-modal-"))
task_path = str(create_task(tmp_dir))
print(f"Created task: {task_path}")
# Start workers
print(f"Starting {num_workers} worker(s)...")
workers = asyncio.run(create_workers(
n_workers=num_workers,
allowed_ips=["127.0.0.1"],
startup_timeout=120.0,
))
worker_urls = [w.api_url for w in workers]
print(f"Workers ready: {worker_urls}")
# Create environment clients
task_configs = [{"env_path": task_path, "task_index": i % 10, "split": "train"} for i in range(num_workers)]
envs = [
CBEnvWorkerClient({
"server_url": url,
"task_configs": task_configs,
"max_step": max_steps,
"max_hist": 2,
"timeout": 300,
})
for url in worker_urls
]
# Dataset
dataset = Dataset.from_dict({"prompt": [task_prompt] * dataset_size})
# GRPO config
config = GRPOConfig(
output_dir=DEFAULT_CONFIG.output_dir,
use_vllm=DEFAULT_CONFIG.use_vllm,
vllm_mode=DEFAULT_CONFIG.vllm_mode,
vllm_gpu_memory_utilization=DEFAULT_CONFIG.vllm_gpu_memory_utilization,
        vllm_max_model_len=DEFAULT_CONFIG.vllm_max_model_length,
num_train_epochs=DEFAULT_CONFIG.num_train_epochs,
per_device_train_batch_size=DEFAULT_CONFIG.per_device_train_batch_size,
warmup_steps=DEFAULT_CONFIG.warmup_steps,
logging_steps=DEFAULT_CONFIG.logging_steps,
save_strategy=DEFAULT_CONFIG.save_strategy,
bf16=DEFAULT_CONFIG.bf16,
# From function args (overridable)
learning_rate=learning_rate,
gradient_accumulation_steps=gradient_accumulation_steps,
num_generations=num_generations,
max_completion_length=max_completion_length,
temperature=temperature,
save_steps=save_steps,
report_to="wandb" if wandb_enabled else "none",
)
# Rollout function
env_idx = [0]
def rollout_func(prompts: list[str], trainer: GRPOTrainer) -> dict[str, list]:
all_prompt_ids, all_completion_ids, all_logprobs = [], [], []
all_rewards, all_images = [], []
if debug:
print(f"\n[Rollout] Processing {len(prompts)} prompts")
for i, prompt in enumerate(prompts):
env = envs[env_idx[0] % len(envs)]
env_idx[0] += 1
if debug:
print(f"[Rollout] Episode {i + 1}/{len(prompts)}")
episode = rollout_episode(trainer, env, tokenizer, prompt, max_hist=max_history)
all_prompt_ids.append(episode["prompt_ids"])
all_completion_ids.append(episode["completion_ids"])
all_logprobs.append(episode["logprobs"])
all_rewards.append(episode["env_reward"])
if "images" in episode:
all_images.append(episode["images"])
result = {
"prompt_ids": all_prompt_ids,
"completion_ids": all_completion_ids,
"logprobs": all_logprobs,
"env_reward": all_rewards,
}
if all_images:
result["images"] = all_images
return result
# Create trainer
trainer = GRPOTrainer(
model=model_id,
processing_class=tokenizer,
reward_funcs=[reward_evaluator_func],
train_dataset=dataset,
args=config,
rollout_func=rollout_func,
)
try:
print("\nStarting training...")
trainer.train()
# Save final checkpoint
trainer.save_model("/checkpoints/final")
checkpoint_volume.commit()
print("\nTraining complete!")
return {"status": "complete", "checkpoints": "/checkpoints/final"}
finally:
print("Cleaning up workers...")
asyncio.run(cleanup_workers(workers))
import shutil
shutil.rmtree(tmp_dir, ignore_errors=True)
# =============================================================================
# Entrypoint
# =============================================================================
@app.local_entrypoint()
def main(
model_id: str | None = None,
num_workers: int | None = None,
max_steps: int | None = None,
num_generations: int | None = None,
dataset_size: int | None = None,
learning_rate: float | None = None,
save_steps: int | None = None,
debug: bool = False,
):
"""Run GRPO training on Modal."""
kwargs = {
k: v for k, v in {
"model_id": model_id,
"num_workers": num_workers,
"max_steps": max_steps,
"num_generations": num_generations,
"dataset_size": dataset_size,
"learning_rate": learning_rate,
"save_steps": save_steps,
"debug": debug,
}.items() if v is not None
}
result = train_grpo.remote(**kwargs)
print(f"Training result: {result}")Step 3: Run Training
modal run modal_grpo_training.py

With custom settings:
modal run modal_grpo_training.py \
--model-id Qwen/Qwen3-VL-2B-Instruct \
--num-workers 4 \
--max-steps 10 \
--num-generations 4 \
--learning-rate 5e-6 \
  --debug

CLI Options
| Option | Default | Description |
|---|---|---|
| `--model-id` | Qwen/Qwen3-VL-2B-Instruct | Model to train |
| `--num-workers` | 2 | Number of parallel environment workers |
| `--max-steps` | 10 | Max steps per episode |
| `--num-generations` | 4 | Rollouts per prompt |
| `--dataset-size` | 1000 | Number of training samples |
| `--learning-rate` | 5e-6 | Learning rate |
| `--save-steps` | 100 | Checkpoint save interval |
| `--debug` | False | Enable verbose output |
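The CLI exposes only the most common knobs. Everything else (temperature, vLLM settings, checkpoint paths, and so on) comes from the TrainerConfig defaults at the top of the script; to change those, edit the dataclass instantiation directly. For example (the values shown are illustrative, not recommendations):

# In modal_grpo_training.py, override defaults that are not exposed on the CLI.
DEFAULT_CONFIG = TrainerConfig(
    temperature=0.5,                  # sampling temperature for rollouts
    vllm_gpu_memory_utilization=0.3,  # leave more headroom for training
    max_completion_length=128,        # shorter generations per step
)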
Creating Custom Tasks
To train on your own task, create a task directory with main.py:
import cua_bench as cb
@cb.tasks_config(split="train")
def get_tasks():
return [cb.Task(
description="Your task description",
computer={"provider": "simulated", "setup_config": {"width": 1024, "height": 720}},
)]
@cb.setup_task(split="train")
async def setup(task, session):
await session.launch_window(url="https://example.com")
@cb.evaluate_task(split="train")
async def evaluate(task, session):
success = await session.execute_javascript(None, "checkSuccess()")
    return [1.0 if success else 0.0]

Then modify the script to use your task path instead of the built-in click target task.
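For example, assuming your task lives in a local my_task/ directory (containing a main.py like the one above), one way is to bake it into the Modal image and point the workers at it instead of the generated click-target task (a sketch; directory names are illustrative):

# 1. At module level, add the task directory to the image:
image = image.add_local_dir("my_task", remote_path="/root/my_task")

# 2. Inside train_grpo(), replace `task_path = str(create_task(tmp_dir))` with:
task_path = "/root/my_task"
task_configs = [
    {"env_path": task_path, "task_index": i % 10, "split": "train"}
    for i in range(num_workers)
]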