312 changes: 260 additions & 52 deletions skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py
@@ -45,8 +45,12 @@
RefWorkerBase,
)
from skyrl.backends.skyrl_train.workers.worker_utils import (
BaseBatchIterator,
BatchIterator,
SampleBasedBatchIterator,
TokenBasedBatchIterator,
all_reduce_metrics,
get_microbatch_iterator,
reduce_metrics,
)
from skyrl.env_vars import SKYRL_WORKER_NCCL_TIMEOUT_IN_S
@@ -449,39 +453,80 @@ def lora_pre_wrap_hook(model):
def forward(self, data: TrainingInputBatch):
"""
Override `Worker.forward` to support passing the full mini batch to the MegatronModelWrapper.forward method.

Supports token-based micro-batching via max_tokens_per_microbatch config.
"""
from skyrl.backends.skyrl_train.utils.replay_utils import clear_router_replay

use_token_batching = self.cfg.max_tokens_per_microbatch > 0

if use_token_batching:
microbatch_iterator = get_microbatch_iterator(
data,
micro_batch_size=self.cfg.micro_forward_batch_size_per_gpu,
max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch,
)
else:
microbatch_iterator = None

# Build micro-batch dicts expected by policy.forward_mini_batch
micro_dicts = []
device = torch.cuda.current_device()

if microbatch_iterator is not None:
for microbatch in microbatch_iterator:
microbatch.to(device)
sequences = microbatch["sequences"]
attention_mask = microbatch["attention_mask"]
num_actions = microbatch.metadata["response_length"]
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
rollout_expert_indices = microbatch.get("rollout_expert_indices")
if rollout_expert_indices is not None:
rollout_expert_indices = rollout_expert_indices.to(torch.int32)
micro_dicts.append(
{
"sequences": sequences,
"attention_mask": attention_mask,
"position_ids": position_ids,
"num_actions": num_actions,
"rollout_expert_indices": (rollout_expert_indices if self.enable_router_replay else None),
}
)
else:
micro_bsz = self.cfg.micro_forward_batch_size_per_gpu
micro_batches = data.chunk(micro_bsz)
for micro in micro_batches:
micro.to(device)
sequences = micro["sequences"]
attention_mask = micro["attention_mask"]
num_actions = micro.metadata["response_length"]
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
rollout_expert_indices = micro.get("rollout_expert_indices")
if rollout_expert_indices is not None:
rollout_expert_indices = rollout_expert_indices.to(torch.int32)
micro_dicts.append(
{
"sequences": sequences,
"attention_mask": attention_mask,
"position_ids": position_ids,
"num_actions": num_actions,
"rollout_expert_indices": (rollout_expert_indices if self.enable_router_replay else None),
}
)

if use_token_batching:
# Pad microbatches to uniform batch size for Megatron compatibility
max_micro_bsz = max(m["sequences"].shape[0] for m in micro_dicts) if micro_dicts else 1
Contributor review comment (medium): The use of max(m["sequences"].shape[0] for m in micro_dicts) if micro_dicts else 1 is redundant if micro_dicts is guaranteed to be non-empty at this point. If it can be empty, consider handling it more explicitly to avoid potential issues with mbs being 1 when no data is present.

for i, m in enumerate(micro_dicts):
micro_dicts[i] = self._pad_forward_microbatch_to_size(m, max_micro_bsz)
mbs = max_micro_bsz
else:
mbs = micro_dicts[0]["sequences"].shape[0] if micro_dicts else 1

self.model.eval()
seq_len = micro_dicts[0]["sequences"].shape[1]
with torch.no_grad():
log_probs = self.model.forward(
micro_batches=micro_dicts,
@@ -491,11 +536,81 @@ def forward(self, data: TrainingInputBatch):
)

log_probs = log_probs.to("cpu")

if use_token_batching and microbatch_iterator is not None:
# Megatron's output is concatenated across the padded microbatches;
# strip the padded samples and reorder back to the original sample order.
output = TrainingOutputBatch({"output": log_probs})
output.metadata = data.metadata
output = self._reorder_megatron_forward_output(
output, microbatch_iterator, micro_dicts, mbs
)
else:
output = TrainingOutputBatch({"output": log_probs})
output.metadata = data.metadata

clear_router_replay()
return output
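
The grouping done by get_microbatch_iterator is not shown in this diff. As intuition only, here is a minimal sketch of greedy token-count packing of the kind a TokenBasedBatchIterator typically performs; pack_by_token_count is a hypothetical name, and the real iterator may sort by length or balance groups differently.

```python
def pack_by_token_count(seq_lengths: list[int], max_tokens: int) -> list[list[int]]:
    """Hypothetical sketch: greedily group sample indices so each group's
    total token count stays within max_tokens."""
    groups: list[list[int]] = []
    current: list[int] = []
    current_tokens = 0
    for idx, n_tokens in enumerate(seq_lengths):
        # Start a new microbatch when adding this sample would exceed the budget.
        if current and current_tokens + n_tokens > max_tokens:
            groups.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        groups.append(current)
    return groups

# Sequences of length 300/500/200/700 under a 1000-token budget pack into
# microbatches of unequal size -- which is why the padding helpers below exist.
print(pack_by_token_count([300, 500, 200, 700], max_tokens=1000))  # [[0, 1, 2], [3]]
```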

def _pad_forward_microbatch_to_size(self, micro_dict: dict, target_batch_size: int) -> dict:
"""Pad a forward micro-batch dict to target_batch_size."""
current_bsz = micro_dict["sequences"].shape[0]
if current_bsz >= target_batch_size:
return micro_dict

pad_count = target_batch_size - current_bsz
device = micro_dict["sequences"].device

padded = {}
for key, value in micro_dict.items():
if key in ("num_actions",):
padded[key] = value
continue
if value is None:
padded[key] = None
continue
if isinstance(value, torch.Tensor):
if key == "attention_mask":
pad_tensor = torch.ones((pad_count, *value.shape[1:]), dtype=value.dtype, device=device)
elif key == "position_ids":
seq_len = value.shape[1]
pad_tensor = torch.arange(seq_len, device=device).unsqueeze(0).expand(pad_count, -1)
else:
pad_tensor = torch.zeros((pad_count, *value.shape[1:]), dtype=value.dtype, device=device)
padded[key] = torch.cat([value, pad_tensor], dim=0)
else:
padded[key] = value

return padded

def _reorder_megatron_forward_output(
self, output: TrainingOutputBatch, microbatch_iterator, micro_dicts, padded_mbs
) -> TrainingOutputBatch:
"""Reorder forward output from token-based microbatching back to original sample order."""
if not isinstance(microbatch_iterator, TokenBasedBatchIterator):
return output

log_probs = output["output"] # shape: [total_padded_samples, num_actions]
num_microbatches = len(microbatch_iterator._microbatches) + microbatch_iterator._num_padding_microbatches

# Split by padded_mbs, take only real samples, reorder
all_log_probs = log_probs.split(padded_mbs, dim=0)

# Build original-order tensor
batch_size = microbatch_iterator.data.batch_size
num_actions = log_probs.shape[1]
reordered = torch.zeros((batch_size, num_actions), dtype=log_probs.dtype, device=log_probs.device)

for mb_idx, original_indices in enumerate(microbatch_iterator._microbatches):
mb_log_probs = all_log_probs[mb_idx]
for sample_idx, original_idx in enumerate(original_indices):
reordered[original_idx] = mb_log_probs[sample_idx]

result = TrainingOutputBatch({"output": reordered})
result.metadata = output.metadata
return result
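
A toy round-trip may help make the reorder contract concrete. The values below are illustrative only: two microbatches padded to a uniform size of 2, where microbatch 0 held original samples [2, 0] and microbatch 1 held sample [1] plus one padding row.

```python
import torch

padded_mbs = 2
microbatches = [[2, 0], [1]]  # original sample indices per microbatch
# Megatron returns one concatenated tensor with a row per (possibly padded) sample.
log_probs = torch.tensor([[2.0], [0.0],    # microbatch 0: samples 2 and 0
                          [1.0], [-9.0]])  # microbatch 1: sample 1 + padding
chunks = log_probs.split(padded_mbs, dim=0)

reordered = torch.zeros((3, 1))
for mb_idx, original_indices in enumerate(microbatches):
    for row, original_idx in enumerate(original_indices):
        reordered[original_idx] = chunks[mb_idx][row]

print(reordered.squeeze(1))  # tensor([0., 1., 2.]) -- original order, padding dropped
```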

def save_hf_model(self, export_dir: str, tokenizer):
# Save model in HuggingFace safetensors format
self.strategy.save_hf_model(
Expand Down Expand Up @@ -641,6 +756,50 @@ def init_model(self, model_path, num_training_steps: int = 1e9):

self.empty_cuda_cache = self.cfg.policy.megatron_config.empty_cuda_cache

def _pad_microbatch_to_size(self, micro_dict: dict, target_batch_size: int) -> dict:
"""Pad a micro-batch dict to target_batch_size with dummy samples.

Padded samples have loss_mask=0 so they don't contribute to the loss.
This is needed because Megatron's forward_backward_func requires uniform
micro_batch_size across all microbatches (especially with PP > 1).
"""
current_bsz = micro_dict["sequences"].shape[0]
if current_bsz >= target_batch_size:
return micro_dict

pad_count = target_batch_size - current_bsz
device = micro_dict["sequences"].device

padded = {}
for key, value in micro_dict.items():
if key in ("num_actions", "num_microbatches"):
padded[key] = value
continue
if value is None:
padded[key] = None
continue
if isinstance(value, torch.Tensor):
if key == "loss_mask":
# Pad with zeros so padded samples don't contribute to loss
pad_tensor = torch.zeros((pad_count, *value.shape[1:]), dtype=value.dtype, device=device)
elif key == "attention_mask":
# Pad attention with ones (minimal seq) to avoid NaN in position_ids
pad_tensor = torch.ones((pad_count, *value.shape[1:]), dtype=value.dtype, device=device)
elif key == "position_ids":
# position_ids for padded samples
seq_len = value.shape[1]
pad_tensor = torch.arange(seq_len, device=device).unsqueeze(0).expand(pad_count, -1)
elif key == "action_mask":
# action_mask should be zeros for padded samples
pad_tensor = torch.zeros((pad_count, *value.shape[1:]), dtype=value.dtype, device=device)
else:
pad_tensor = torch.zeros((pad_count, *value.shape[1:]), dtype=value.dtype, device=device)
padded[key] = torch.cat([value, pad_tensor], dim=0)
else:
padded[key] = value

return padded
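
A standalone illustration of the padding contract (not a call into the worker method itself, and the tensor contents are made up): one dummy row padded with loss_mask zeros leaves a masked loss unchanged.

```python
import torch

micro = {
    "sequences": torch.tensor([[5, 6, 7]]),
    "loss_mask": torch.tensor([[1.0, 1.0, 1.0]]),
}
pad_count = 1  # pad batch size 1 -> 2
padded = {
    k: torch.cat([v, torch.zeros((pad_count, *v.shape[1:]), dtype=v.dtype)], dim=0)
    for k, v in micro.items()
}
per_token_loss = torch.ones_like(padded["loss_mask"])  # stand-in loss values
masked = (per_token_loss * padded["loss_mask"]).sum() / padded["loss_mask"].sum()
print(masked)  # tensor(1.) -- the padded row contributes nothing to the loss
```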

def forward_backward(
self,
data: TrainingInputBatch,
@@ -650,7 +809,8 @@ def forward_backward(
"""
Perform forward and backward passes for a batch, handling micro-batching internally.

The batch is split into micro batches based on micro_train_batch_size_per_gpu,
or by token count if max_tokens_per_microbatch is configured.
Megatron Core's forward_backward_func handles gradient accumulation internally.

Args:
Expand All @@ -669,38 +829,77 @@ def forward_backward(
# if use distributed optimizer, zero grad buffer will be handled by optimizer
chunk.zero_grad_buffer()

all_metrics = defaultdict(list)

# Move data to GPU
data.to(torch.cuda.current_device())

use_token_batching = self.cfg.max_tokens_per_microbatch > 0

if use_token_batching:
microbatch_iterator = get_microbatch_iterator(
data,
micro_batch_size=self.cfg.micro_train_batch_size_per_gpu,
max_tokens_per_microbatch=self.cfg.max_tokens_per_microbatch,
)
else:
microbatch_iterator = None

# Build micro-batch dicts expected by forward_backward_mini_batch
micro_buffer = []

if microbatch_iterator is not None:
for microbatch in microbatch_iterator:
experience = BaseBatchIterator.batch_to_experience(microbatch)
sequences = experience.sequences
attention_mask = experience.attention_mask
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
rollout_expert_indices = experience.rollout_expert_indices
if rollout_expert_indices is not None:
rollout_expert_indices = rollout_expert_indices.to(torch.int32)

micro_buffer.append(
{
"sequences": sequences,
"attention_mask": attention_mask,
"position_ids": position_ids,
"num_actions": experience.num_actions,
"old_action_log_probs": experience.action_log_probs,
"base_action_log_probs": experience.base_action_log_probs,
"advantages": experience.advantages,
"loss_mask": experience.loss_mask,
"rollout_action_logprobs": experience.rollout_logprobs,
"action_mask": experience.action_mask,
"rollout_expert_indices": rollout_expert_indices if self.enable_router_replay else None,
}
)
else:
micro_batch_size = self.cfg.micro_train_batch_size_per_gpu
for experience in BatchIterator(data, micro_batch_size, drop_last=False):
sequences = experience.sequences
attention_mask = experience.attention_mask
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 0)
rollout_expert_indices = experience.rollout_expert_indices
if rollout_expert_indices is not None:
rollout_expert_indices = rollout_expert_indices.to(torch.int32)

micro_buffer.append(
{
"sequences": sequences,
"attention_mask": attention_mask,
"position_ids": position_ids,
"num_actions": experience.num_actions,
"old_action_log_probs": experience.action_log_probs,
"base_action_log_probs": experience.base_action_log_probs,
"advantages": experience.advantages,
"loss_mask": experience.loss_mask,
"rollout_action_logprobs": experience.rollout_logprobs,
"action_mask": experience.action_mask,
"rollout_expert_indices": rollout_expert_indices if self.enable_router_replay else None,
}
)

for m_batch in micro_buffer:
m_batch["num_microbatches"] = len(micro_buffer)
@@ -709,7 +908,16 @@ def forward_backward(
return {}

seq_len = micro_buffer[0]["sequences"].shape[1]

if use_token_batching:
# With token-based batching, microbatches may have different batch sizes.
# Megatron's forward_backward_func requires uniform micro_batch_size,
# so pad all microbatches to the max batch size across microbatches.
max_micro_bsz = max(m["sequences"].shape[0] for m in micro_buffer)
micro_buffer = [self._pad_microbatch_to_size(m, max_micro_bsz) for m in micro_buffer]
micro_bsz = max_micro_bsz
else:
micro_bsz = micro_buffer[0]["sequences"].shape[0]

metrics_list = self.model.forward_backward_mini_batch(
micro_batches=micro_buffer,