diff --git a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py index 6b1459647c..5adbe27988 100644 --- a/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py @@ -45,8 +45,12 @@ RefWorkerBase, ) from skyrl.backends.skyrl_train.workers.worker_utils import ( + BaseBatchIterator, BatchIterator, + SampleBasedBatchIterator, + TokenBasedBatchIterator, all_reduce_metrics, + get_microbatch_iterator, reduce_metrics, ) from skyrl.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S @@ -449,39 +453,80 @@ def lora_pre_wrap_hook(model): def forward(self, data: TrainingInputBatch): """ Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method. + + Supports token-based micro-batching via max_tokens_per_microbatch config. """ from skyrl.backends.skyrl_train.utils.replay_utils import clear_router_replay - # Run in micro batches grouped into a single mini-batch - micro_bsz = self.cfg.micro_forward_batch_size_per_gpu - micro_batches = data.chunk(micro_bsz) + use_token_batching = self.cfg.max_tokens_per_microbatch > 0 + + if use_token_batching: + microbatch_iterator = get_microbatch_iterator( + data, + micro_batch_size=self.cfg.micro_forward_batch_size_per_gpu, + max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch, + ) + else: + microbatch_iterator = None # Build micro-batch dicts expected by policy.forward_mini_batch micro_dicts = [] device = torch.cuda.current_device() - for micro in micro_batches: - micro.to(device) - sequences = micro["sequences"] - attention_mask = micro["attention_mask"] - num_actions = micro.metadata["response_length"] - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) - rollout_expert_indices = micro.get("rollout_expert_indices") - if rollout_expert_indices is not None: - rollout_expert_indices = rollout_expert_indices.to(torch.int32) - micro_dicts.append( - { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": num_actions, - "rollout_expert_indices": (rollout_expert_indices if self.enable_router_replay else None), - } - ) + + if microbatch_iterator is not None: + for microbatch in microbatch_iterator: + microbatch.to(device) + sequences = microbatch["sequences"] + attention_mask = microbatch["attention_mask"] + num_actions = microbatch.metadata["response_length"] + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + rollout_expert_indices = microbatch.get("rollout_expert_indices") + if rollout_expert_indices is not None: + rollout_expert_indices = rollout_expert_indices.to(torch.int32) + micro_dicts.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": num_actions, + "rollout_expert_indices": (rollout_expert_indices if self.enable_router_replay else None), + } + ) + else: + micro_bsz = self.cfg.micro_forward_batch_size_per_gpu + micro_batches = data.chunk(micro_bsz) + for micro in micro_batches: + micro.to(device) + sequences = micro["sequences"] + attention_mask = micro["attention_mask"] + num_actions = micro.metadata["response_length"] + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + rollout_expert_indices = micro.get("rollout_expert_indices") + if rollout_expert_indices is not None: + rollout_expert_indices = rollout_expert_indices.to(torch.int32) + micro_dicts.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": num_actions, + "rollout_expert_indices": (rollout_expert_indices if self.enable_router_replay else None), + } + ) + + if use_token_batching: + # Pad microbatches to uniform batch size for Megatron compatibility + max_micro_bsz = max(m["sequences"].shape[0] for m in micro_dicts) if micro_dicts else 1 + for i, m in enumerate(micro_dicts): + micro_dicts[i] = self._pad_forward_microbatch_to_size(m, max_micro_bsz) + mbs = max_micro_bsz + else: + mbs = micro_dicts[0]["sequences"].shape[0] if micro_dicts else 1 self.model.eval() seq_len = micro_dicts[0]["sequences"].shape[1] - mbs = micro_dicts[0]["sequences"].shape[0] with torch.no_grad(): log_probs = self.model.forward( micro_batches=micro_dicts, @@ -491,11 +536,81 @@ def forward(self, data: TrainingInputBatch): ) log_probs = log_probs.to("cpu") - output = TrainingOutputBatch({"output": log_probs}) - output.metadata = data.metadata + + if use_token_batching and microbatch_iterator is not None: + # Need to strip padded samples and reorder back to original order + output = TrainingOutputBatch({"output": log_probs}) + output.metadata = data.metadata + # The output from Megatron is concatenated across microbatches. + # We need to extract only the real (non-padded) samples and reorder. + output = self._reorder_megatron_forward_output( + output, microbatch_iterator, micro_dicts, mbs + ) + else: + output = TrainingOutputBatch({"output": log_probs}) + output.metadata = data.metadata + clear_router_replay() return output + def _pad_forward_microbatch_to_size(self, micro_dict: dict, target_batch_size: int) -> dict: + """Pad a forward micro-batch dict to target_batch_size.""" + current_bsz = micro_dict["sequences"].shape[0] + if current_bsz >= target_batch_size: + return micro_dict + + pad_count = target_batch_size - current_bsz + device = micro_dict["sequences"].device + + padded = {} + for key, value in micro_dict.items(): + if key in ("num_actions",): + padded[key] = value + continue + if value is None: + padded[key] = None + continue + if isinstance(value, torch.Tensor): + if key == "attention_mask": + pad_tensor = torch.ones((pad_count, *value.shape[1:]), dtype=value.dtype, device=device) + elif key == "position_ids": + seq_len = value.shape[1] + pad_tensor = torch.arange(seq_len, device=device).unsqueeze(0).expand(pad_count, -1) + else: + pad_tensor = torch.zeros((pad_count, *value.shape[1:]), dtype=value.dtype, device=device) + padded[key] = torch.cat([value, pad_tensor], dim=0) + else: + padded[key] = value + + return padded + + def _reorder_megatron_forward_output( + self, output: TrainingOutputBatch, microbatch_iterator, micro_dicts, padded_mbs + ) -> TrainingOutputBatch: + """Reorder forward output from token-based microbatching back to original sample order.""" + if not isinstance(microbatch_iterator, TokenBasedBatchIterator): + return output + + log_probs = output["output"] # shape: [total_padded_samples, num_actions] + num_microbatches = len(microbatch_iterator._microbatches) + microbatch_iterator._num_padding_microbatches + + # Split by padded_mbs, take only real samples, reorder + all_log_probs = log_probs.split(padded_mbs, dim=0) + + # Build original-order tensor + batch_size = microbatch_iterator.data.batch_size + num_actions = log_probs.shape[1] + reordered = torch.zeros((batch_size, num_actions), dtype=log_probs.dtype, device=log_probs.device) + + for mb_idx, original_indices in enumerate(microbatch_iterator._microbatches): + mb_log_probs = all_log_probs[mb_idx] + for sample_idx, original_idx in enumerate(original_indices): + reordered[original_idx] = mb_log_probs[sample_idx] + + result = TrainingOutputBatch({"output": reordered}) + result.metadata = output.metadata + return result + def save_hf_model(self, export_dir: str, tokenizer): # Save model in HuggingFace safetensors format self.strategy.save_hf_model( @@ -641,6 +756,50 @@ def init_model(self, model_path, num_training_steps: int = 1e9): self.empty_cuda_cache = self.cfg.policy.megatron_config.empty_cuda_cache + def _pad_microbatch_to_size(self, micro_dict: dict, target_batch_size: int) -> dict: + """Pad a micro-batch dict to target_batch_size with dummy samples. + + Padded samples have loss_mask=0 so they don't contribute to the loss. + This is needed because Megatron's forward_backward_func requires uniform + micro_batch_size across all microbatches (especially with PP > 1). + """ + current_bsz = micro_dict["sequences"].shape[0] + if current_bsz >= target_batch_size: + return micro_dict + + pad_count = target_batch_size - current_bsz + device = micro_dict["sequences"].device + + padded = {} + for key, value in micro_dict.items(): + if key in ("num_actions", "num_microbatches"): + padded[key] = value + continue + if value is None: + padded[key] = None + continue + if isinstance(value, torch.Tensor): + if key == "loss_mask": + # Pad with zeros so padded samples don't contribute to loss + pad_tensor = torch.zeros((pad_count, *value.shape[1:]), dtype=value.dtype, device=device) + elif key == "attention_mask": + # Pad attention with ones (minimal seq) to avoid NaN in position_ids + pad_tensor = torch.ones((pad_count, *value.shape[1:]), dtype=value.dtype, device=device) + elif key == "position_ids": + # position_ids for padded samples + seq_len = value.shape[1] + pad_tensor = torch.arange(seq_len, device=device).unsqueeze(0).expand(pad_count, -1) + elif key == "action_mask": + # action_mask should be zeros for padded samples + pad_tensor = torch.zeros((pad_count, *value.shape[1:]), dtype=value.dtype, device=device) + else: + pad_tensor = torch.zeros((pad_count, *value.shape[1:]), dtype=value.dtype, device=device) + padded[key] = torch.cat([value, pad_tensor], dim=0) + else: + padded[key] = value + + return padded + def forward_backward( self, data: TrainingInputBatch, @@ -650,7 +809,8 @@ def forward_backward( """ Perform forward and backward passes for a batch, handling micro-batching internally. - The batch is split into micro batches based on micro_train_batch_size_per_gpu. + The batch is split into micro batches based on micro_train_batch_size_per_gpu, + or by token count if max_tokens_per_microbatch is configured. Megatron Core's forward_backward_func handles gradient accumulation internally. Args: @@ -669,38 +829,77 @@ def forward_backward( # if use distributed optimizer, zero grad buffer will be handled by optimizer chunk.zero_grad_buffer() - micro_batch_size = self.cfg.micro_train_batch_size_per_gpu all_metrics = defaultdict(list) # Move data to GPU data.to(torch.cuda.current_device()) + use_token_batching = self.cfg.max_tokens_per_microbatch > 0 + + if use_token_batching: + microbatch_iterator = get_microbatch_iterator( + data, + micro_batch_size=self.cfg.micro_train_batch_size_per_gpu, + max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch, + ) + else: + microbatch_iterator = None + # Build micro-batch dicts expected by forward_backward_mini_batch micro_buffer = [] - for experience in BatchIterator(data, micro_batch_size, drop_last=False): - sequences = experience.sequences - attention_mask = experience.attention_mask - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 0) - rollout_expert_indices = experience.rollout_expert_indices - if rollout_expert_indices is not None: - rollout_expert_indices = rollout_expert_indices.to(torch.int32) - - micro_buffer.append( - { - "sequences": sequences, - "attention_mask": attention_mask, - "position_ids": position_ids, - "num_actions": experience.num_actions, - "old_action_log_probs": experience.action_log_probs, - "base_action_log_probs": experience.base_action_log_probs, - "advantages": experience.advantages, - "loss_mask": experience.loss_mask, - "rollout_action_logprobs": experience.rollout_logprobs, - "action_mask": experience.action_mask, - "rollout_expert_indices": rollout_expert_indices if self.enable_router_replay else None, - } - ) + + if microbatch_iterator is not None: + for microbatch in microbatch_iterator: + experience = BaseBatchIterator.batch_to_experience(microbatch) + sequences = experience.sequences + attention_mask = experience.attention_mask + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + rollout_expert_indices = experience.rollout_expert_indices + if rollout_expert_indices is not None: + rollout_expert_indices = rollout_expert_indices.to(torch.int32) + + micro_buffer.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": experience.num_actions, + "old_action_log_probs": experience.action_log_probs, + "base_action_log_probs": experience.base_action_log_probs, + "advantages": experience.advantages, + "loss_mask": experience.loss_mask, + "rollout_action_logprobs": experience.rollout_logprobs, + "action_mask": experience.action_mask, + "rollout_expert_indices": rollout_expert_indices if self.enable_router_replay else None, + } + ) + else: + micro_batch_size = self.cfg.micro_train_batch_size_per_gpu + for experience in BatchIterator(data, micro_batch_size, drop_last=False): + sequences = experience.sequences + attention_mask = experience.attention_mask + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 0) + rollout_expert_indices = experience.rollout_expert_indices + if rollout_expert_indices is not None: + rollout_expert_indices = rollout_expert_indices.to(torch.int32) + + micro_buffer.append( + { + "sequences": sequences, + "attention_mask": attention_mask, + "position_ids": position_ids, + "num_actions": experience.num_actions, + "old_action_log_probs": experience.action_log_probs, + "base_action_log_probs": experience.base_action_log_probs, + "advantages": experience.advantages, + "loss_mask": experience.loss_mask, + "rollout_action_logprobs": experience.rollout_logprobs, + "action_mask": experience.action_mask, + "rollout_expert_indices": rollout_expert_indices if self.enable_router_replay else None, + } + ) for m_batch in micro_buffer: m_batch["num_microbatches"] = len(micro_buffer) @@ -709,7 +908,16 @@ def forward_backward( return {} seq_len = micro_buffer[0]["sequences"].shape[1] - micro_bsz = micro_buffer[0]["sequences"].shape[0] + + if use_token_batching: + # With token-based batching, microbatches may have different batch sizes. + # Megatron's forward_backward_func requires uniform micro_batch_size, + # so pad all microbatches to the max batch size across microbatches. + max_micro_bsz = max(m["sequences"].shape[0] for m in micro_buffer) + micro_buffer = [self._pad_microbatch_to_size(m, max_micro_bsz) for m in micro_buffer] + micro_bsz = max_micro_bsz + else: + micro_bsz = micro_buffer[0]["sequences"].shape[0] metrics_list = self.model.forward_backward_mini_batch( micro_batches=micro_buffer, diff --git a/skyrl/backends/skyrl_train/workers/worker.py b/skyrl/backends/skyrl_train/workers/worker.py index 181e3f428a..2221c99f57 100644 --- a/skyrl/backends/skyrl_train/workers/worker.py +++ b/skyrl/backends/skyrl_train/workers/worker.py @@ -50,8 +50,12 @@ ) from skyrl.backends.skyrl_train.utils.torch_utils import masked_mean from skyrl.backends.skyrl_train.workers.worker_utils import ( + BaseBatchIterator, BatchIterator, + SampleBasedBatchIterator, + TokenBasedBatchIterator, all_reduce_metrics, + get_microbatch_iterator, reduce_metrics, ) from skyrl.env_vars import ( @@ -376,16 +380,22 @@ def forward( ) -> TrainingOutputBatch: """Run forward pass on the input batch in inference mode. - This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.micro_forward_batch_size_per_gpu`. + This is a wrapper around `_forward_micro_batch` that runs in micro batches. + Uses token-based chunking if `max_tokens_per_microbatch` is configured, otherwise + falls back to sample-based chunking with `micro_forward_batch_size_per_gpu`. """ - # run in micro batches of cfg.micro_forward_batch_size_per_gpu # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc. - micro_batches = data.chunk(self.cfg.micro_forward_batch_size_per_gpu) + microbatch_iterator = get_microbatch_iterator( + data, + micro_batch_size=self.cfg.micro_forward_batch_size_per_gpu, + max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch, + ) outputs = [] - for micro_batch in micro_batches: - outputs.append(self._forward_micro_batch(micro_batch)) - output = TrainingOutputBatch.cat(outputs) + for microbatch in microbatch_iterator: + outputs.append(self._forward_micro_batch(microbatch)) + output = microbatch_iterator.reorder_and_combine_batches(outputs) + if output.device is not None and output.device != torch.device("cpu"): output = output.to("cpu") return output @@ -698,14 +708,19 @@ def forward_backward( Returns: Aggregated metrics dict across all micro batches """ - micro_batch_size = self.cfg.micro_train_batch_size_per_gpu + microbatch_iterator = get_microbatch_iterator( + data, + micro_batch_size=self.cfg.micro_train_batch_size_per_gpu, + max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch, + ) all_metrics = defaultdict(list) all_loss_fn_outputs = [] # Handle separately from scalar metrics - for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False): - microbatch_weight = micro_batch_size / len(data) + for microbatch in microbatch_iterator: + experience = BaseBatchIterator.batch_to_experience(microbatch) + microbatch_weight = len(microbatch) / len(data) metrics = self._forward_backward_micro( - micro_batch, microbatch_weight, loss_fn=loss_fn, loss_fn_config=loss_fn_config + experience, microbatch_weight, loss_fn=loss_fn, loss_fn_config=loss_fn_config ) # Extract loss_fn_outputs before reduce_metrics (it's not a scalar metric) @@ -1040,7 +1055,8 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: """ Perform forward and backward passes for a batch, handling micro-batching internally. - The batch is split into micro batches based on micro_train_batch_size_per_gpu. + The batch is split into micro batches based on micro_train_batch_size_per_gpu, + or by token count if max_tokens_per_microbatch is configured. Gradients accumulate across micro batches. Gradient scaling happens at optim_step. Args: @@ -1049,12 +1065,26 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: Returns: Aggregated metrics dict across all micro batches """ - micro_batch_size = self.cfg.micro_train_batch_size_per_gpu + use_token_batching = self.cfg.max_tokens_per_microbatch > 0 + microbatch_iterator = get_microbatch_iterator( + data, + micro_batch_size=self.cfg.micro_train_batch_size_per_gpu, + max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch, + ) all_metrics = defaultdict(list) - for micro_batch in BatchIterator(data, micro_batch_size, drop_last=False): - metrics = self._forward_backward_micro(micro_batch) - self._micro_batches_accumulated += 1 + for microbatch in microbatch_iterator: + experience = BaseBatchIterator.batch_to_experience(microbatch) + + if use_token_batching: + # With token-based batching, microbatches may have different sizes. + # Scale loss by microbatch_weight so gradients are correctly weighted. + microbatch_weight = len(microbatch) / len(data) + metrics = self._forward_backward_micro(experience, microbatch_weight=microbatch_weight) + else: + metrics = self._forward_backward_micro(experience) + self._micro_batches_accumulated += 1 + for k, v in metrics.items(): all_metrics[k].append(v) @@ -1066,14 +1096,17 @@ def forward_backward(self, data: TrainingInputBatch) -> Dict[str, float]: return result - def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: + def _forward_backward_micro( + self, experience: Experience, microbatch_weight: Optional[float] = None + ) -> Dict[str, float]: """ Perform forward and backward pass for one micro batch. - Loss is NOT scaled here - gradient scaling happens at optim_step time. - Args: experience: Experience object for one micro batch + microbatch_weight: If provided, scale loss by this weight before backward. + Used with token-based batching where microbatches have variable sizes. + If None, loss is unscaled (gradient scaling happens at optim_step time). Returns: All-reduced metrics dict for this micro batch @@ -1105,7 +1138,11 @@ def _forward_backward_micro(self, experience: Experience) -> Dict[str, float]: config=self.cfg.algorithm, loss_mask=loss_mask, ) - # NO loss scaling here - gradient scaling happens at optim_step + + if microbatch_weight is not None: + # Token-based batching: scale loss by weight so gradients are properly weighted + loss = loss * microbatch_weight + # else: NO loss scaling here - gradient scaling happens at optim_step self.strategy.backward(loss, self.model, self.optimizer) status = { @@ -1125,6 +1162,8 @@ def optim_step(self) -> float: The gradient norm (before scaling, after clipping) """ # Scale accumulated gradients by 1/N to get correct average + # NOTE: When using token-based batching, loss is pre-scaled by microbatch_weight + # in forward_backward, so _micro_batches_accumulated stays 0 and no scaling needed. if self._micro_batches_accumulated > 0: scale = 1.0 / self._micro_batches_accumulated for param in self.model.parameters(): diff --git a/skyrl/backends/skyrl_train/workers/worker_utils.py b/skyrl/backends/skyrl_train/workers/worker_utils.py index cecaccc17e..60de5647db 100644 --- a/skyrl/backends/skyrl_train/workers/worker_utils.py +++ b/skyrl/backends/skyrl_train/workers/worker_utils.py @@ -1,8 +1,12 @@ +import heapq import math -from typing import Dict, List +from typing import Dict, Iterator, List + +import torch +import torch.distributed as dist from skyrl.backends.skyrl_train.distributed.strategy import DistributedStrategy -from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.backends.skyrl_train.training_batch import TensorBatch, TrainingInputBatch from skyrl.train.dataset.replay_buffer import Experience @@ -70,35 +74,21 @@ def all_reduce_metrics( return status_mean -class BatchIterator: - """A simple iterator to yield micro batches of data from the training batch.""" +class BaseBatchIterator: + """Base class for batch iterators that chunk a TrainingInputBatch into microbatches.""" - def __init__(self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False): + def __init__(self, data: TrainingInputBatch): self.data = data - self.sample_batch_size = sample_batch_size - self.total_batch_size = data.batch_size - self.drop_last = drop_last - assert not drop_last, "drop_last is not supported yet" - num_micro_batches = self.total_batch_size / self.sample_batch_size - self.num_micro_batches = int(num_micro_batches) if drop_last else math.ceil(num_micro_batches) - # TODO: switch to tensordict.map_iter if possible - self._chunks = self.data.chunk(self.sample_batch_size) - self._iter = iter(self._chunks) def __len__(self): - return self.num_micro_batches + raise NotImplementedError - def __iter__(self): - return self + def __iter__(self) -> Iterator[TrainingInputBatch]: + raise NotImplementedError - def __next__(self) -> Experience: - try: - batch = next(self._iter) - exp = self.batch_to_experience(batch) - return exp - except StopIteration: - self._iter = iter(self._chunks) - raise StopIteration + def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch: + """Reorder and combine output batches to form a single output.""" + raise NotImplementedError @staticmethod def batch_to_experience(batch: TrainingInputBatch): @@ -127,3 +117,289 @@ def batch_to_experience(batch: TrainingInputBatch): image_grid_thw=batch.get("image_grid_thw"), ) return exp + + +# Keep BatchIterator as an alias for backward compatibility +class BatchIterator(BaseBatchIterator): + """A simple iterator to yield micro batches of data from the training batch. + + This is the original sample-based iterator. Kept as an alias for SampleBasedBatchIterator. + """ + + def __init__(self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False): + super().__init__(data) + self.sample_batch_size = sample_batch_size + self.total_batch_size = data.batch_size + self.drop_last = drop_last + assert not drop_last, "drop_last is not supported yet" + num_micro_batches = self.total_batch_size / self.sample_batch_size + self.num_micro_batches = int(num_micro_batches) if drop_last else math.ceil(num_micro_batches) + # TODO: switch to tensordict.map_iter if possible + self._chunks = self.data.chunk(self.sample_batch_size) + self._iter = iter(self._chunks) + + def __len__(self): + return self.num_micro_batches + + def __iter__(self): + return self + + def __next__(self) -> Experience: + try: + batch = next(self._iter) + exp = self.batch_to_experience(batch) + return exp + except StopIteration: + self._iter = iter(self._chunks) + raise StopIteration + + def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch: + """Concatenate output batches. No reordering needed for sample-based splitting.""" + return TensorBatch.cat(batches) + + +class SampleBasedBatchIterator(BaseBatchIterator): + """Iterator that yields fixed-size sample-based microbatches from the training input. + + Yields TrainingInputBatch objects (not Experience), unlike the legacy BatchIterator. + """ + + def __init__(self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False): + super().__init__(data) + self.sample_batch_size = sample_batch_size + self.total_batch_size = data.batch_size + self.drop_last = drop_last + assert not drop_last, "drop_last is not supported yet" + num_micro_batches = self.total_batch_size / self.sample_batch_size + self.num_micro_batches = int(num_micro_batches) if drop_last else math.ceil(num_micro_batches) + self._chunks = self.data.chunk(self.sample_batch_size) + + def __len__(self): + return self.num_micro_batches + + def __iter__(self) -> Iterator[TrainingInputBatch]: + return iter(self._chunks) + + def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch: + """Concatenate output batches. No reordering needed for sample-based splitting.""" + return TensorBatch.cat(batches) + + +def balanced_binpacking(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]: + """Chunk a list of token counts into microbatches so that each + microbatch's total token count does not exceed `max_tokens_per_microbatch`, + and the microbatches are roughly balanced. + + Roughly balance by assigning sequences to the microbatch with + the least number of tokens so far. + + Args: + token_counts: List of token counts for each sample. + max_tokens_per_microbatch: Maximum total tokens allowed per microbatch. + + Returns: + A list of microbatches, where each microbatch is a list of indices (ints) + referring to entries in `token_counts`. + + >>> balanced_binpacking([10, 10, 5, 5], 15) + [[0, 2], [1, 3]] + >>> balanced_binpacking([10, 1, 1, 1, 1, 1], 10) + [[0], [1, 2, 3, 4, 5]] + >>> balanced_binpacking([8, 3, 5, 6, 2, 7], 11) + [[0, 4], [5, 1], [3, 2]] + """ + # Create list of (index, token_count) pairs and sort by token count descending + seq_lens = [(i, seq_len) for i, seq_len in enumerate(token_counts)] + seq_lens.sort(key=lambda x: x[1], reverse=True) + + # Track microbatch indices and their current token counts + microbatch_indices: List[List[int]] = [] + + # Heap to track the total number of tokens in each microbatch + microbatch_tokens_heap = [] # (current_total, bin_idx) + + for idx, seq_len in seq_lens: + placed = False + + # Look for an existing microbatch with the least number of tokens + # that can fit the sequence without exceeding the token limit. + if microbatch_tokens_heap: + microbatch_len, i = microbatch_tokens_heap[0] + new_microbatch_len = microbatch_len + seq_len + if new_microbatch_len <= max_tokens_per_microbatch: + microbatch_indices[i].append(idx) + heapq.heapreplace(microbatch_tokens_heap, (new_microbatch_len, i)) + placed = True + + # If no microbatch can fit the sequence, create a new microbatch. + if not placed: + microbatch_indices.append([idx]) + heapq.heappush(microbatch_tokens_heap, (seq_len, len(microbatch_indices) - 1)) + + return microbatch_indices + + +class TokenBasedBatchIterator(BaseBatchIterator): + """An iterator that chunks microbatches based on real token count. + + Packs samples into microbatches using bin-packing, ensuring each microbatch + doesn't exceed max_tokens_per_microbatch. All data parallel workers will have + the same number of microbatches (padding microbatches are added if needed). + """ + + def __init__( + self, + data: TrainingInputBatch, + max_tokens_per_microbatch: int, + ): + """ + Args: + data: The training input batch to chunk. + max_tokens_per_microbatch: Maximum number of tokens per microbatch. + """ + super().__init__(data) + self._max_tokens_per_microbatch = max_tokens_per_microbatch + + # Compute token counts per sample using attention_mask + attention_mask = data["attention_mask"] + self._token_counts = attention_mask.sum(dim=1).cpu().tolist() # [batch_size] + + # Create microbatches based on token count + self._microbatches = balanced_binpacking(self._token_counts, self._max_tokens_per_microbatch) + + # Synchronize the number of microbatches across all DP workers + max_num_microbatches = self._sync_num_microbatches() + self._num_padding_microbatches = max_num_microbatches - len(self._microbatches) + + def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBatch: + """Create a TrainingInputBatch from a list of sample indices.""" + indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu") + selected_data = {} + for key, value in self.data.items(): + if value is None: + selected_data[key] = None + else: + selected_data[key] = value[indices_tensor] + microbatch = TrainingInputBatch(selected_data) + microbatch.metadata = self.data.metadata + return microbatch + + def _create_padding_microbatch(self) -> TrainingInputBatch: + """Create a padding microbatch with loss_mask=0 so it doesn't affect the loss.""" + seq_len = 2 + num_actions = self.data.metadata["response_length"] + batch_size = 1 + + data = TrainingInputBatch( + { + "sequences": torch.randint(0, 100, (batch_size, seq_len), device="cpu"), + "attention_mask": torch.ones((batch_size, seq_len), dtype=int, device="cpu"), + "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"), + "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"), + "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"), + # Loss mask is all zeros so padding samples don't contribute to the loss. + "loss_mask": torch.zeros((batch_size, num_actions), dtype=int, device="cpu"), + "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + } + ) + data.metadata = self.data.metadata + return data + + def _sync_num_microbatches(self) -> int: + """Ensure all DP workers have the same number of micro batches.""" + local_num_microbatches = len(self._microbatches) + + if not dist.is_initialized(): + return local_num_microbatches + + # Get the maximum number of batches across all DP workers + if torch.cuda.is_available(): + device = torch.cuda.current_device() + else: + device = torch.device("cpu") + num_microbatches_tensor = torch.tensor(local_num_microbatches, dtype=torch.long, device=device) + dist.all_reduce(num_microbatches_tensor, op=dist.ReduceOp.MAX) + return num_microbatches_tensor.item() + + def __len__(self): + return len(self._microbatches) + self._num_padding_microbatches + + def __iter__(self) -> Iterator[TrainingInputBatch]: + for microbatch_indices in self._microbatches: + yield self._create_microbatch_from_indices(microbatch_indices) + + for _ in range(self._num_padding_microbatches): + yield self._create_padding_microbatch() + + def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch: + """Reorder and combine output batches into a single batch with + the same order as the original input data. + + Example: [[0, 2], [1, 3]] -> [0, 1, 2, 3] + + Args: + batches: List of microbatch outputs to reorder. + Returns: + A single reordered batch. + """ + non_padding_batches = batches[: len(batches) - self._num_padding_microbatches] + + if not non_padding_batches: + raise ValueError("Cannot reorder an empty list of microbatches.") + + # Create a reverse mapping of original idx -> (microbatch idx, sample idx) + original_idx_to_microbatch_idx = {} + for microbatch_idx, original_indices in enumerate(self._microbatches): + for sample_idx, original_idx in enumerate(original_indices): + original_idx_to_microbatch_idx[original_idx] = (microbatch_idx, sample_idx) + + # Get reference microbatch to know keys and tensor shapes + ref_microbatch = non_padding_batches[0] + reordered_data = {} + + for key, ref_value in ref_microbatch.items(): + if ref_value is None: + reordered_data[key] = None + continue + # Get shape of a single sample (remove batch dimension) + sample_shape = ref_value.shape[1:] + device = ref_value.device + dtype = ref_value.dtype + + # Pre-allocate output tensor: [batch_size, *sample_shape] + batch_size = len(self._token_counts) + output_tensor = torch.zeros((batch_size, *sample_shape), dtype=dtype, device=device) + + # Copy each sample directly into the correct position + for original_idx in range(batch_size): + microbatch_idx, sample_idx = original_idx_to_microbatch_idx[original_idx] + source_tensor = non_padding_batches[microbatch_idx][key] + output_tensor[original_idx] = source_tensor[sample_idx] + + reordered_data[key] = output_tensor + + # Create single TensorBatch with reordered data + reordered_batch = type(ref_microbatch)(reordered_data) + reordered_batch.metadata = ref_microbatch.metadata + return reordered_batch + + +def get_microbatch_iterator( + data: TrainingInputBatch, micro_batch_size: int, max_tokens_per_microbatch: int +) -> BaseBatchIterator: + """Factory function to get the appropriate microbatch iterator. + + Args: + data: The training input batch. + micro_batch_size: Number of samples per microbatch (used if max_tokens_per_microbatch <= 0). + max_tokens_per_microbatch: Maximum tokens per microbatch. If > 0, uses token-based batching. + + Returns: + A BaseBatchIterator instance. + """ + if max_tokens_per_microbatch > 0: + return TokenBasedBatchIterator(data, max_tokens_per_microbatch=max_tokens_per_microbatch) + else: + return SampleBasedBatchIterator(data, sample_batch_size=micro_batch_size, drop_last=False) diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 6f5f145418..b10e2bdf4f 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -594,6 +594,11 @@ class TrainerConfig(BaseConfig): critic_mini_batch_size: int = 256 micro_train_batch_size_per_gpu: int = 1 micro_forward_batch_size_per_gpu: int = 1 + max_tokens_per_microbatch: int = -1 + """Maximum number of tokens per microbatch. When > 0, microbatches are formed by bin-packing + samples based on their token counts (from attention_mask) instead of using a fixed sample count. + -1 means disabled (use sample-based micro_train_batch_size_per_gpu / micro_forward_batch_size_per_gpu). + Applies to both forward and training micro-batching.""" update_ref_every_epoch: bool = False use_sample_packing: bool = True eval_batch_size: int = 1024 diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index f2a52006e3..7af72a7ca3 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -237,6 +237,7 @@ trainer: critic_mini_batch_size: 256 micro_train_batch_size_per_gpu: 1 micro_forward_batch_size_per_gpu: 1 + max_tokens_per_microbatch: -1 # -1 means disabled; when > 0, uses token-based microbatching instead of fixed sample count update_ref_every_epoch: false use_sample_packing: true eval_batch_size: 1024 diff --git a/tests/backends/skyrl_train/gpu/gpu_ci/test_token_based_batching.py b/tests/backends/skyrl_train/gpu/gpu_ci/test_token_based_batching.py new file mode 100644 index 0000000000..a3bcbc0641 --- /dev/null +++ b/tests/backends/skyrl_train/gpu/gpu_ci/test_token_based_batching.py @@ -0,0 +1,546 @@ +""" +Tests for max_tokens_per_microbatch (token-based micro-batching). + +Tests verify: +1. Unit tests for balanced_binpacking and TokenBasedBatchIterator +2. FSDP forward_backward with token-based batching produces equivalent loss +3. Megatron forward with token-based batching produces equivalent results +4. Performance comparison (token-based vs sample-based) + +Run with: +uv run --isolated --extra dev --extra fsdp -- pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_token_based_batching.py +""" + +import time + +import pytest +import ray +import torch + +from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch +from skyrl.backends.skyrl_train.workers.worker_utils import ( + TokenBasedBatchIterator, + balanced_binpacking, + get_microbatch_iterator, +) +from skyrl.train.config import SkyRLTrainConfig +from skyrl.train.utils.utils import validate_cfg +from tests.backends.skyrl_train.gpu.utils import ( + init_worker_with_type, +) + + +# ─── Unit Tests (CPU only, no Ray/GPU needed) ─────────────────────────────── + + +class TestBalancedBinpacking: + def test_basic_packing(self): + result = balanced_binpacking([10, 10, 5, 5], 15) + assert len(result) == 2 + # Each microbatch should have total <= 15 + for mb in result: + total = sum([10, 10, 5, 5][i] for i in mb) + assert total <= 15 + + def test_single_large_item(self): + result = balanced_binpacking([10, 1, 1, 1, 1, 1], 10) + assert len(result) == 2 + # The large item should be alone + for mb in result: + total = sum([10, 1, 1, 1, 1, 1][i] for i in mb) + assert total <= 10 + + def test_all_items_equal(self): + result = balanced_binpacking([5, 5, 5, 5], 10) + assert len(result) == 2 + for mb in result: + total = sum(5 for _ in mb) + assert total <= 10 + + def test_single_item(self): + result = balanced_binpacking([10], 15) + assert len(result) == 1 + assert result[0] == [0] + + def test_all_indices_covered(self): + token_counts = [8, 3, 5, 6, 2, 7] + result = balanced_binpacking(token_counts, 11) + all_indices = sorted(idx for mb in result for idx in mb) + assert all_indices == list(range(len(token_counts))) + + def test_no_overflow(self): + token_counts = [8, 3, 5, 6, 2, 7] + max_tokens = 11 + result = balanced_binpacking(token_counts, max_tokens) + for mb in result: + total = sum(token_counts[i] for i in mb) + assert total <= max_tokens + + +class TestTokenBasedBatchIterator: + def _make_batch(self, seq_lens, num_actions=4): + """Create a dummy TrainingInputBatch with variable sequence lengths.""" + batch_size = len(seq_lens) + max_seq_len = max(seq_lens) + + sequences = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu") + attention_mask = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu") + for i, seq_len in enumerate(seq_lens): + sequences[i, :seq_len] = torch.randint(0, 100, (seq_len,), dtype=int, device="cpu") + attention_mask[i, :seq_len] = 1 + + data = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"), + "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"), + "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"), + "loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + } + ) + data.metadata = {"response_length": num_actions} + return data + + def test_iterator_yields_all_samples(self): + batch = self._make_batch([10, 10, 5, 5]) + iterator = TokenBasedBatchIterator(batch, max_tokens_per_microbatch=15) + + all_indices = [] + for mb_indices in iterator._microbatches: + all_indices.extend(mb_indices) + assert sorted(all_indices) == [0, 1, 2, 3] + + def test_iterator_respects_token_limit(self): + batch = self._make_batch([10, 10, 5, 5]) + iterator = TokenBasedBatchIterator(batch, max_tokens_per_microbatch=15) + + for microbatch in iterator: + token_count = microbatch["attention_mask"].sum().item() + # Allow some slack for padding microbatches + if microbatch["loss_mask"].sum() > 0: # not a padding batch + assert token_count <= 15 + + def test_len_matches_iteration(self): + batch = self._make_batch([10, 10, 5, 5]) + iterator = TokenBasedBatchIterator(batch, max_tokens_per_microbatch=15) + count = sum(1 for _ in iterator) + assert count == len(iterator) + + def test_reorder_and_combine(self): + """Verify that reorder_and_combine_batches restores original order.""" + batch = self._make_batch([10, 3, 8, 5]) + iterator = TokenBasedBatchIterator(batch, max_tokens_per_microbatch=12) + + # Simulate forward outputs (just use the microbatch itself as output) + outputs = [] + for microbatch in iterator: + outputs.append(microbatch) + + reordered = iterator.reorder_and_combine_batches(outputs) + # Check that the sequences match the original order + for i in range(batch.batch_size): + assert torch.equal(reordered["sequences"][i], batch["sequences"][i]) + + def test_get_microbatch_iterator_factory(self): + batch = self._make_batch([10, 10, 5, 5]) + + # Token-based + it = get_microbatch_iterator(batch, micro_batch_size=2, max_tokens_per_microbatch=15) + assert isinstance(it, TokenBasedBatchIterator) + + # Sample-based (disabled) + from skyrl.backends.skyrl_train.workers.worker_utils import SampleBasedBatchIterator + + it = get_microbatch_iterator(batch, micro_batch_size=2, max_tokens_per_microbatch=-1) + assert isinstance(it, SampleBasedBatchIterator) + + +# ─── GPU Tests: FSDP ───────────────────────────────────────────────────────── + +MODEL_NAME = "Qwen/Qwen3-0.6B" + + +def _make_variable_length_batch(seq_lens, num_actions=4): + """Create a TrainingInputBatch with variable-length sequences (right-padded to max).""" + torch.manual_seed(42) + batch_size = len(seq_lens) + max_seq_len = max(seq_lens) + + sequences = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu") + attention_mask = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu") + for i, seq_len in enumerate(seq_lens): + sequences[i, :seq_len] = torch.randint(1, 100, (seq_len,), device="cpu") + attention_mask[i, :seq_len] = 1 + + data = TrainingInputBatch( + { + "sequences": sequences, + "attention_mask": attention_mask, + "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"), + "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"), + "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"), + "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"), + "loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"), + "rollout_logprobs": 0.2 * torch.ones((batch_size, num_actions), device="cpu"), + } + ) + data.metadata = {"response_length": num_actions} + return data + + +def get_fsdp_test_config() -> SkyRLTrainConfig: + cfg = SkyRLTrainConfig() + cfg.trainer.placement.policy_num_gpus_per_node = 2 + cfg.trainer.logger = "console" + cfg.generator.inference_engine.tensor_parallel_size = 2 + return cfg + + +@pytest.mark.asyncio +@pytest.mark.parametrize("worker_type", ["policy", "critic"]) +async def test_fsdp_token_based_forward_backward(ray_init_fixture, worker_type): + """ + Test that forward_backward with max_tokens_per_microbatch works correctly for FSDP. + + Verifies: + 1. Token-based batching runs without errors for both policy and critic + 2. Returns valid metrics with expected keys + 3. For policy: loss is close to sample-based baseline (both use pre-scaled advantages) + """ + try: + # Create a batch with variable-length sequences + seq_lens = [30, 30, 15, 15] # 4 samples, 2 per DP rank + batch = _make_variable_length_batch(seq_lens, num_actions=4) + batch.metadata["global_step"] = 0 + + # Token-based batching + cfg_token = get_fsdp_test_config() + cfg_token.trainer.strategy = "fsdp2" + cfg_token.trainer.policy.model.path = MODEL_NAME + cfg_token.trainer.micro_train_batch_size_per_gpu = 1 + cfg_token.trainer.max_tokens_per_microbatch = 30 + validate_cfg(cfg_token) + + actor_group = init_worker_with_type( + worker_type, + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg_token.trainer.placement.policy_num_gpus_per_node, + cfg=cfg_token, + ) + results_token = ray.get( + actor_group.async_run_ray_method("mesh", "forward_backward", data=batch) + ) + + # Verify results have expected structure + loss_key = "policy_loss" if worker_type == "policy" else "critic_loss" + for i, r in enumerate(results_token): + assert isinstance(r, dict), f"Result should be a dict, got {type(r)}" + assert loss_key in r, f"Missing {loss_key} in result" + print(f" Rank {i}: token-based {loss_key}={r[loss_key]:.6f}") + + if worker_type == "policy": + assert "loss_metrics/clip_ratio" in r + assert "policy_entropy" in r + assert "loss_fn_outputs" in r + + # Also verify optim_step works + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + print(f" {worker_type}: forward_backward + optim_step completed successfully") + + finally: + ray.shutdown() + + +@pytest.mark.asyncio +async def test_fsdp_token_based_loss_equivalence(ray_init_fixture): + """ + Test that policy loss with token-based batching matches sample-based + when each microbatch contains exactly 1 sample (i.e., when max_tokens + is set high enough that no packing occurs but low enough that each + sample gets its own microbatch). + """ + try: + # Uniform-length sequences to ensure identical batching behavior + seq_lens = [20, 20, 20, 20] + batch = _make_variable_length_batch(seq_lens, num_actions=4) + batch.metadata["global_step"] = 0 + + # Run 1: sample-based baseline (mbs=1) + cfg_baseline = get_fsdp_test_config() + cfg_baseline.trainer.strategy = "fsdp2" + cfg_baseline.trainer.policy.model.path = MODEL_NAME + cfg_baseline.trainer.micro_train_batch_size_per_gpu = 1 + cfg_baseline.trainer.max_tokens_per_microbatch = -1 + validate_cfg(cfg_baseline) + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg_baseline.trainer.placement.policy_num_gpus_per_node, + cfg=cfg_baseline, + ) + results_baseline = ray.get( + actor_group.async_run_ray_method("mesh", "forward_backward", data=batch) + ) + + ray.shutdown() + from tests.backends.skyrl_train.gpu.utils import ray_init_for_tests + + ray_init_for_tests() + + # Run 2: token-based with limit that gives 1 sample per microbatch + cfg_token = get_fsdp_test_config() + cfg_token.trainer.strategy = "fsdp2" + cfg_token.trainer.policy.model.path = MODEL_NAME + cfg_token.trainer.micro_train_batch_size_per_gpu = 1 + # max_tokens=20 means each 20-token sample goes in its own microbatch + cfg_token.trainer.max_tokens_per_microbatch = 20 + validate_cfg(cfg_token) + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg_token.trainer.placement.policy_num_gpus_per_node, + cfg=cfg_token, + ) + results_token = ray.get( + actor_group.async_run_ray_method("mesh", "forward_backward", data=batch) + ) + + # With uniform sequences and 1 sample per microbatch, losses should match closely + for i, (r_baseline, r_token) in enumerate(zip(results_baseline, results_token)): + bl = r_baseline["policy_loss"] + tl = r_token["policy_loss"] + print(f" Rank {i}: baseline={bl:.6f}, token-based={tl:.6f}, diff={abs(bl-tl):.6f}") + assert abs(bl - tl) < 1e-4, f"Loss mismatch on rank {i}: {bl} vs {tl}" + + finally: + ray.shutdown() + + +@pytest.mark.asyncio +async def test_fsdp_token_based_batching_performance(ray_init_fixture): + """ + Test that token-based batching shows better throughput than sample-based + when sequences have highly variable lengths. + + Creates a batch with a mix of short (15 tokens) and long (100 tokens) sequences. + With sample-based batching (mbs=1), each microbatch processes one sample with + the full max_seq_len padding. With token-based batching, short sequences can be + packed together, reducing the number of forward passes. + """ + try: + # Create batch with high variance in sequence lengths + # 8 samples total (4 per DP rank with 2 GPUs) + # Mix of short and long sequences + seq_lens = [100, 100, 15, 15, 100, 100, 15, 15] + batch = _make_variable_length_batch(seq_lens, num_actions=4) + batch.metadata["global_step"] = 0 + + # Run 1: sample-based with mbs=1 (no packing, wastes time on padding) + cfg_sample = get_fsdp_test_config() + cfg_sample.trainer.strategy = "fsdp2" + cfg_sample.trainer.policy.model.path = MODEL_NAME + cfg_sample.trainer.micro_train_batch_size_per_gpu = 1 + cfg_sample.trainer.max_tokens_per_microbatch = -1 + validate_cfg(cfg_sample) + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg_sample.trainer.placement.policy_num_gpus_per_node, + cfg=cfg_sample, + ) + + # Warmup + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + + start = time.time() + NUM_ITERS = 3 + for _ in range(NUM_ITERS): + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + sample_time = (time.time() - start) / NUM_ITERS + + ray.shutdown() + from tests.backends.skyrl_train.gpu.utils import ray_init_for_tests + + ray_init_for_tests() + + # Run 2: token-based batching + cfg_token = get_fsdp_test_config() + cfg_token.trainer.strategy = "fsdp2" + cfg_token.trainer.policy.model.path = MODEL_NAME + cfg_token.trainer.micro_train_batch_size_per_gpu = 1 + # Set max_tokens to pack 2 short sequences together: 15+15=30 < 120 + cfg_token.trainer.max_tokens_per_microbatch = 120 + validate_cfg(cfg_token) + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg_token.trainer.placement.policy_num_gpus_per_node, + cfg=cfg_token, + ) + + # Warmup + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + + start = time.time() + for _ in range(NUM_ITERS): + ray.get(actor_group.async_run_ray_method("mesh", "forward_backward", data=batch)) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + token_time = (time.time() - start) / NUM_ITERS + + print(f"\nPerformance comparison (avg over {NUM_ITERS} iterations):") + print(f" Sample-based (mbs=1): {sample_time:.3f}s") + print(f" Token-based (max_tokens=120): {token_time:.3f}s") + print(f" Speedup: {sample_time / token_time:.2f}x") + + # Token-based should be at least as fast (may not always be faster with small batches) + # The main point is correctness; performance benefit is more visible with larger batches + assert token_time < sample_time * 1.5, ( + f"Token-based batching should not be significantly slower: " + f"{token_time:.3f}s vs {sample_time:.3f}s" + ) + + finally: + ray.shutdown() + + +def _get_megatron_test_config(tp=2, pp=1, gpus=2) -> SkyRLTrainConfig: + """Create a Megatron test config that passes validate_cfg.""" + cfg = SkyRLTrainConfig() + cfg.trainer.policy.model.path = MODEL_NAME + cfg.trainer.micro_forward_batch_size_per_gpu = 2 + cfg.trainer.micro_train_batch_size_per_gpu = 2 + cfg.trainer.use_sample_packing = False + cfg.trainer.logger = "console" + cfg.trainer.strategy = "megatron" + cfg.trainer.placement.policy_num_gpus_per_node = gpus + cfg.trainer.placement.colocate_all = False + cfg.trainer.placement.colocate_policy_ref = False + cfg.trainer.policy.megatron_config.tensor_model_parallel_size = tp + cfg.trainer.policy.megatron_config.pipeline_model_parallel_size = pp + return cfg + + +@pytest.mark.asyncio +@pytest.mark.megatron +async def test_megatron_token_based_forward(ray_init_fixture): + """ + Test that forward pass with token-based batching works correctly for Megatron. + Compares forward output between sample-based and token-based batching. + """ + from skyrl.backends.skyrl_train.distributed.dispatch import ( + concatenate_outputs_after_mesh_dispatch, + ) + from tests.backends.skyrl_train.gpu.utils import ray_init_for_tests + + try: + seq_lens = [30, 30, 15, 15] + batch = _make_variable_length_batch(seq_lens, num_actions=4) + + # Run 1: sample-based baseline + cfg = _get_megatron_test_config(tp=2, pp=1, gpus=2) + cfg.trainer.max_tokens_per_microbatch = -1 + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node, + cfg=cfg, + ) + + results_refs = actor_group.async_run_ray_method("mesh", "forward", data=batch) + results_baseline = ray.get(results_refs) + output_baseline = concatenate_outputs_after_mesh_dispatch( + actor_group.actor_infos, results_baseline + )["output"] + + ray.shutdown() + ray_init_for_tests() + + # Run 2: token-based + cfg2 = _get_megatron_test_config(tp=2, pp=1, gpus=2) + cfg2.trainer.max_tokens_per_microbatch = 35 # Can fit 1 long or 2 short seqs + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg2.trainer.placement.policy_num_gpus_per_node, + cfg=cfg2, + ) + + results_refs = actor_group.async_run_ray_method("mesh", "forward", data=batch) + results_token = ray.get(results_refs) + output_token = concatenate_outputs_after_mesh_dispatch( + actor_group.actor_infos, results_token + )["output"] + + # Compare log probs + max_diff = torch.max(torch.abs(output_baseline - output_token)).item() + avg_diff = torch.mean(torch.abs(output_baseline - output_token)).item() + print(f"\nMegatron forward comparison:") + print(f" Max diff: {max_diff:.6f}") + print(f" Avg diff: {avg_diff:.6f}") + + assert max_diff < 1e-3, f"Max diff {max_diff} too large between sample-based and token-based" + + finally: + ray.shutdown() + + +@pytest.mark.asyncio +@pytest.mark.megatron +async def test_megatron_token_based_train(ray_init_fixture): + """ + Test that training with token-based batching works correctly for Megatron (TP=2, PP=1). + """ + try: + seq_lens = [30, 30, 15, 15, 30, 30, 15, 15] + batch = _make_variable_length_batch(seq_lens, num_actions=4) + batch.metadata["global_step"] = 0 + + cfg = _get_megatron_test_config(tp=2, pp=1, gpus=2) + cfg.trainer.max_tokens_per_microbatch = 35 + cfg.trainer.train_batch_size = len(seq_lens) + cfg.trainer.policy_mini_batch_size = 4 + cfg.trainer.algorithm.use_kl_loss = False + + actor_group = init_worker_with_type( + "policy", + shared_pg=None, + colocate_all=False, + num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node, + cfg=cfg, + ) + + results = ray.get( + actor_group.async_run_ray_method("mesh", "forward_backward", data=batch) + ) + ray.get(actor_group.async_run_ray_method("pass_through", "optim_step")) + + for result in results: + assert isinstance(result, dict), "Result should be a dictionary" + assert "policy_loss" in result + assert "policy_entropy" in result + print(f" policy_loss={result['policy_loss']:.6f}") + + finally: + ray.shutdown()