RL Dataloader

API reference for ReplayBuffer and MultiTurnDataloader

The MultiTurnDataloader handles parallel environment rollout, tokenization, and batching for RL training.

ReplayBuffer

Stores experience tuples with episode-aware reward discounting.

from cua_bench.workers import ReplayBuffer

buffer = ReplayBuffer(
    capacity=10000,           # Max experiences to store
    gamma=0.9,                # Discount factor
    only_keep_outcome=False,  # If True, store only the final (outcome) step of each episode
    balance_thres=0.5,        # Threshold for balance stats
)

# Add an experience as a (worker_id, env_ret, meta_info) tuple
buffer.add((
    0,  # worker_id
    {"obs": "...", "reward": 0.0, "done": False},  # env_ret
    {"uid": "episode-123"}  # meta_info
))

# Sample for training
samples = buffer.sample(batch_size=32)

# Get balance statistics (sample counts below/above balance_thres)
below, above = buffer.get_balance_stats()

Reward Discounting

When an episode completes, rewards propagate backwards:

Episode: [step0, step1, step2 (done, reward=1.0)]
With gamma=0.9:
  step2.reward = 1.0
  step1.reward = 0.9
  step0.reward = 0.81
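
This backward pass can be written as a short loop. The sketch below is illustrative only, assuming each step is a dict with a reward field; the buffer's internal representation may differ.

# Illustrative sketch of episode-aware discounting, not the buffer's internals.
# The final step carries the outcome reward; earlier steps receive gamma^k of it.
def discount_episode(steps, gamma=0.9):
    running = 0.0
    for step in reversed(steps):
        running = step["reward"] + gamma * running
        step["reward"] = running
    return steps

episode = [{"reward": 0.0}, {"reward": 0.0}, {"reward": 1.0}]
discount_episode(episode, gamma=0.9)
# rewards become [0.81, 0.9, 1.0], matching the example above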

MultiTurnDataloader

Manages parallel environments and provides tokenized batches.

from cua_bench.workers import (
    MultiTurnDataloader,
    CBEnvWorkerClient,
    create_workers,
)

workers = await create_workers(n_workers=4, allowed_ips=["127.0.0.1"])
task_configs = [{"env_path": "./task", "task_index": 0, "split": "train"}]
env_configs = [{
    "server_url": w.api_url,
    "task_configs": task_configs,
    "max_step": 50,
    "max_hist": 10,
    "timeout": 300,
} for w in workers]

dataloader = MultiTurnDataloader(
    env_class=CBEnvWorkerClient,
    env_configs=env_configs,
    tokenizer=tokenizer,
    batch_size=4,
    replay_capacity=10000,
    replay_reward_discount=0.9,
    max_prompt_length=1024,
    max_response_length=256,
)

for batch in dataloader:
    # batch: input_ids, attention_mask, position_ids, worker_id, meta_info
    responses = model.generate(batch['input_ids'])

    dataloader.async_step({
        'prompts': batch['input_ids'],
        'responses': responses,
        'attention_mask': combined_mask,  # mask over prompt + response (see sketch below)
        'worker_id': batch['worker_id'],
        'meta_info': batch['meta_info'],
    })
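
In the loop above, combined_mask must cover both the prompt and the generated response (see the batch_return format below). A minimal way to build it, assuming right-padded responses; build_combined_mask is a hypothetical helper, not part of the library:

import torch

# Hypothetical helper: concatenate the prompt mask with a mask over the
# generated tokens (1 for real tokens, 0 for padding).
def build_combined_mask(prompt_mask, responses, pad_token_id):
    response_mask = (responses != pad_token_id).long()
    return torch.cat([prompt_mask, response_mask], dim=-1)

combined_mask = build_combined_mask(
    batch['attention_mask'], responses, tokenizer.pad_token_id
)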

Constructor Parameters

Parameter | Type | Default | Description
env_class | Type | required | Environment client class
env_configs | List[Dict] | required | Worker configs
task_configs | List[Dict] | required | Task configurations
tokenizer | Any | required | HuggingFace tokenizer
processor | Any | None | HuggingFace processor for multimodal
is_multi_modal | bool | False | Enable image processing
batch_size | int | 8 | Must be less than or equal to num_envs
replay_capacity | int | 10000 | Replay buffer size
replay_reward_discount | float | 0.9 | Gamma for discounting
max_prompt_length | int | 1024 | Max prompt tokens
max_response_length | int | 1024 | Max response tokens
only_keep_outcome_in_replay | bool | False | Only keep final steps

Batch Format

From next(dataloader):

{
    'input_ids': torch.Tensor,       # (batch, seq_len)
    'attention_mask': torch.Tensor,  # (batch, seq_len)
    'position_ids': torch.Tensor,    # (batch, seq_len)
    'worker_id': np.ndarray,
    'meta_info': np.ndarray,
}
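
A quick sanity check on these shapes (illustrative only):

batch = next(dataloader)
# The three tensors share one (batch, seq_len) shape; worker_id and meta_info
# stay as numpy arrays with one entry per row.
assert batch['input_ids'].shape == batch['attention_mask'].shape == batch['position_ids'].shape
assert batch['input_ids'].shape[0] == len(batch['worker_id'])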

batch_return Format

For async_step():

{
    'prompts': torch.Tensor,         # (batch, prompt_len)
    'responses': torch.Tensor,       # (batch, response_len)
    'attention_mask': torch.Tensor,  # (batch, total_len)
    'worker_id': np.ndarray,
    'meta_info': np.ndarray,
}

Action Parsing

The dataloader parses action strings from responses:

<|action_start|>click(0.5, 0.5)<|action_end|>
<|action_start|>type("hello")<|action_end|>
<|action_start|>done()<|action_end|>
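
A rough sketch of extracting these strings with a regular expression; the dataloader's actual parser may differ:

import re

# Capture everything between the action delimiters.
ACTION_RE = re.compile(r"<\|action_start\|>(.*?)<\|action_end\|>", re.DOTALL)

def parse_actions(response_text):
    return ACTION_RE.findall(response_text)

parse_actions('<|action_start|>click(0.5, 0.5)<|action_end|>')
# ['click(0.5, 0.5)']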

Methods

# Get running reward (EMA)
reward = dataloader.running_outcome_reward()

# Get balance stats
below, above = dataloader.get_balance_stats()

# Sample from replay buffer
batch = dataloader.sample_from_buffer(batch_size=32)

# Print stats
dataloader.print_stats_in_replay_buffer()

# Clear replay buffer
dataloader.clear_replay_buffer()

# Cleanup workers
dataloader.close()
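
These can be combined in a training loop, for instance to mix replayed experience back in once both outcome classes are represented. A sketch under assumed thresholds, not a prescribed recipe:

# Illustrative: only sample from the replay buffer once both sides of
# balance_thres have enough experiences.
below, above = dataloader.get_balance_stats()
if min(below, above) >= 16:
    replay_batch = dataloader.sample_from_buffer(batch_size=32)
    dataloader.print_stats_in_replay_buffer()

print("running outcome reward (EMA):", dataloader.running_outcome_reward())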

Example

See the Train an Agent with GRPO tutorial for a complete working example.
