Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def convert_list_to_set(cls, v: Iterable[int] | None) -> Set[int] | None:
class GPT2ModelTPConfig(BaseModel):
model: PydanticPytorchModuleOrListType # TODO set proper type
device_mesh: PydanticDeviceMeshIFType
context_parallel_load_balancer: Literal["headtail", "ptrr"] | None = "headtail"

@model_validator(mode="after")
def validate_tp_mesh_existence(self) -> "GPT2ModelTPConfig":
Expand All @@ -335,12 +336,36 @@ def validate_tp_mesh_existence(self) -> "GPT2ModelTPConfig":
raise ValueError(f"Device mesh {self.device_mesh=} has no defined mesh_dim_names.")
if ParallelismDegrees.TP.value not in mesh_dim_names:
raise ValueError(f"Tensor parallelism key '{ParallelismDegrees.TP.value}' not in {self.device_mesh=}")
if (
"context_parallel_load_balancer" in self.model_fields_set
and self.context_parallel_load_balancer is not None
and ParallelismDegrees.CP.value not in mesh_dim_names
):
raise ValueError(
"context_parallel_load_balancer can only be set when context parallelism is configured in the mesh. "
f"Expected key '{ParallelismDegrees.CP.value}' in {self.device_mesh=}."
)
if ParallelismDegrees.DP_REPLICATE.value in mesh_dim_names:
# TorchTitan uses replicate (i.e, plain DP) to combine DP with TP.
raise ValueError("data_parallel_replicate_degree > 1 cannot be used with Tensor Parallelism.")
return self


class GPT2ModelCPConfig(BaseModel):
    # Configuration for wrapping a GPT2 model with context parallelism (CP).
    # Documented with comments instead of a class docstring so that nothing
    # pydantic derives from the class itself is altered.
    model: PydanticPytorchModuleOrListType
    device_mesh: PydanticDeviceMeshIFType
    # Accepted values are "headtail", "ptrr" or None; "headtail" is the default.
    context_parallel_load_balancer: Literal["headtail", "ptrr"] | None = "headtail"

    @model_validator(mode="after")
    def validate_cp_mesh_existence(self) -> "GPT2ModelCPConfig":
        """
        Check that the configured device mesh names its dimensions and contains the CP dimension.

        Returns:
            GPT2ModelCPConfig: The validated config instance itself.

        Raises:
            ValueError: If the mesh has no mesh_dim_names, or if the context parallelism key
                is not among them.
        """
        dim_names = self.device_mesh.mesh_dim_names
        if dim_names is None:
            raise ValueError(f"Device mesh {self.device_mesh=} has no defined mesh_dim_names.")

        cp_key = ParallelismDegrees.CP.value
        if cp_key not in dim_names:
            raise ValueError(f"Context parallelism key '{cp_key}' not in {self.device_mesh=}")
        return self


class CompiledModelConfig(BaseModel):
model: PydanticPytorchModuleOrListType
block_names: list[str]
Expand Down
73 changes: 61 additions & 12 deletions src/modalities/models/gpt2/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,26 @@ def apply_rotary_pos_emb(self, x, cos, sin):
# the rotation below work
return (x * cos) + (self.rotate_half(x) * sin)

def _compute_cos_sin_from_positions(
    self, position_ids: torch.Tensor, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build the rotary cos/sin tables from explicit token positions.

    Args:
        position_ids (torch.Tensor): Token positions of shape (B, T) or (T,). A 1-D input
            is treated as a single batch row.
        x (torch.Tensor): Reference tensor; only its dtype is used (for the inverse
            frequencies and for the returned tables).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: cos and sin, each of shape
            (B or 1, 1, T, 2 * len(self.inv_freq)) and of dtype x.dtype.
    """
    positions = position_ids.float()
    if positions.dim() == 1:
        # Promote (T,) to (1, T) so the einsum below always sees a batch axis.
        positions = positions.unsqueeze(0)

    inv_freq = self.inv_freq.to(x.dtype)
    # Outer product of positions and inverse frequencies: (B or 1, T, dim/2).
    half_angles = torch.einsum("bt,d->btd", positions, inv_freq)
    # Duplicate along the last axis to cover the full rotary dimension: (B or 1, T, dim).
    angles = torch.cat((half_angles, half_angles), dim=-1)

    # Insert a singleton axis at dim 1 so the tables broadcast over the head axis.
    cos_table = angles.cos().to(x.dtype).unsqueeze(1)
    sin_table = angles.sin().to(x.dtype).unsqueeze(1)
    return cos_table, sin_table

def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
position_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass of the RotaryTransform module.
Expand All @@ -217,12 +235,20 @@ def forward(
q (torch.Tensor): Query tensor.
k (torch.Tensor): Key tensor.
v (torch.Tensor): Value tensor.
position_ids (torch.Tensor | None): Optional explicit global position indices of shape
(B, T) or (T,). When provided (e.g. for context-parallel ranks that hold
non-contiguous token ranges), the correct global RoPE frequencies are computed
from these positions instead of assuming a local 0-based range.

Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Tuple containing the modified query tensor, key tensor, and value tensor.
"""
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k)
if position_ids is not None:
cos, sin = self._compute_cos_sin_from_positions(position_ids, k)
self._cos_cached, self._sin_cached = cos, sin
else:
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k)
q = self.apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached)
k = self.apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached)

Expand Down Expand Up @@ -514,7 +540,12 @@ def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch

@staticmethod
def execute_qkv_transforms(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qkv_transforms: nn.ModuleList, n_head_q: int
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
qkv_transforms: nn.ModuleList,
n_head_q: int,
position_ids: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Applies a series of transformations to the query, key, and value tensors.
Expand All @@ -525,6 +556,8 @@ def execute_qkv_transforms(
v (torch.Tensor): The value tensors.
qkv_transforms (nn.ModuleList): A list of transformation modules to be applied to q, k, and v.
n_head_q (int): The number of heads for the query tensors.
position_ids (torch.Tensor | None): Optional explicit global position indices forwarded
to RotaryTransform so CP ranks use correct global RoPE frequencies.

Returns:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand All @@ -541,7 +574,10 @@ def execute_qkv_transforms(
v = v.view(batch_size, sequence_length, -1, n_head_dim).transpose(1, 2).contiguous() # (B, nh_kv, T, hd)

for transform in qkv_transforms:
q, k, v = transform(q, k, v)
if isinstance(transform, RotaryTransform) and position_ids is not None:
q, k, v = transform(q, k, v, position_ids=position_ids)
else:
q, k, v = transform(q, k, v)

return q, k, v

Expand Down Expand Up @@ -655,12 +691,14 @@ def execute_attention(
raise NotImplementedError(f"Attention implementation {attention_impl} not supported")
return y # (B, T, nh_q, hd)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, position_ids: torch.Tensor | None = None) -> torch.Tensor:
"""
Forward pass of the CausalSelfAttention module.

Args:
x (torch.Tensor): Input tensor of shape (B, T, n_embd)
position_ids (torch.Tensor | None): Optional global position indices forwarded to
RotaryTransform for correct CP-rank-aware RoPE.

Returns:
torch.Tensor: Output tensor of shape (B, T, n_embd), representing the output projection.
Expand All @@ -669,7 +707,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
q, k, v = self.projection(x) # q: (B, T, n_embd), k: (B, T, n_embd // n_rep), v: (B, T, n_embd // n_rep)

# q: (B, nh_q, T, hd), k: (B, nh_kv, T, hd), v: (B, nh_kv, T, hd)
q, k, v = CausalSelfAttention.execute_qkv_transforms(q, k, v, self.qkv_transforms, self.n_head_q)
q, k, v = CausalSelfAttention.execute_qkv_transforms(
q, k, v, self.qkv_transforms, self.n_head_q, position_ids=position_ids
)
if self.q_norm is not None and self.k_norm is not None:
q = self.q_norm(q)
k = self.k_norm(k)
Expand Down Expand Up @@ -796,17 +836,19 @@ def _check_ffn_hidden_dim(self, n_embd: int, ffn_hidden: int) -> None:
f"but got `n_embd = {n_embd}` and `ffn_hidden = {ffn_hidden}`."
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, position_ids: torch.Tensor | None = None) -> torch.Tensor:
"""
Forward pass of the GPT2Block.

Args:
x (torch.Tensor): Input tensor.
position_ids (torch.Tensor | None): Optional global position indices forwarded to
the attention layer for CP-aware RoPE.

Returns:
torch.Tensor: Output tensor.
"""
x = x + self.attn(self.attention_norm(x))
x = x + self.attn(self.attention_norm(x), position_ids=position_ids)
x = x + self.mlp(self.ffn_norm(x))
return x

Expand Down Expand Up @@ -971,22 +1013,29 @@ def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, t
Forward pass of the GPT2LLM module.

Args:
inputs (dict[str, torch.Tensor] | torch.Tensor): Input data.
inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. When a dict, an optional
``"position_ids"`` key (shape ``(B, T)`` or ``(1, T)``) may be present to supply
explicit global token positions for CP-aware RoPE.

Returns:
dict[str, torch.Tensor] | torch.Tensor: Model output.
"""
if isinstance(inputs, dict):
return {self.prediction_key: self.forward_impl(inputs[self.sample_key])}
position_ids = inputs.get("position_ids", None)
return {self.prediction_key: self.forward_impl(inputs[self.sample_key], position_ids=position_ids)}
else:
return self.forward_impl(inputs)

def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
def forward_impl(self, inputs: torch.Tensor, position_ids: torch.Tensor | None = None) -> torch.Tensor:
"""
Forward pass implementation of the GPT2LLM module.

Args:
inputs (torch.Tensor): A tensor containing input token ids.
position_ids (torch.Tensor | None): Optional explicit global position indices
of shape ``(B, T)`` or ``(1, T)``. When provided, RoPE uses these positions
instead of a local 0-based arange, enabling correct behaviour for CP ranks
that hold non-contiguous token ranges.

Returns:
torch.Tensor: A tensor containing output logits.
Expand All @@ -1010,7 +1059,7 @@ def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor:
h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h

for layer_idx in self.transformer.h:
h = self.transformer.h[layer_idx](h)
h = self.transformer.h[layer_idx](h, position_ids=position_ids)
h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h
h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h
return h
Expand Down
110 changes: 108 additions & 2 deletions src/modalities/models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,21 @@
GPT2LLM,
AttentionConfig,
AttentionImplementation,
CausalSelfAttention,
LayerNormWrapperConfig,
PositionTypes,
SwiGLU,
TransformerMLP,
)
from modalities.models.model import ActivationType
from modalities.models.parallelism.context_parallel import apply_cp_to_sdpa_attention_forward
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.running_env.env_utils import FSDP2MixedPrecisionSettings, MixedPrecisionSettings
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees
from modalities.running_env.fsdp.device_mesh import (
ParallelismDegrees,
get_mesh_for_parallelism_method,
has_parallelism_method,
)
from modalities.running_env.fsdp.fsdp_auto_wrapper import FSDPTransformerAutoWrapPolicyFactory
from modalities.training.activation_checkpointing.activation_checkpointing import (
ActivationCheckpointing,
Expand Down Expand Up @@ -593,6 +599,73 @@ def register_hooks_recursively(module: nn.Module, prefix: str = ""):


class GPT2ModelFactory:
@staticmethod
def _get_cp_mesh_if_enabled(device_mesh: DeviceMesh) -> DeviceMesh | None:
    """
    Return the context-parallel sub-mesh when context parallelism is actually in use.

    Args:
        device_mesh (DeviceMesh): The device mesh to inspect.

    Returns:
        DeviceMesh | None: The CP sub-mesh if the mesh has the CP parallelism method and the
            sub-mesh size is greater than 1; otherwise None.
    """
    if has_parallelism_method(device_mesh, ParallelismDegrees.CP):
        candidate = get_mesh_for_parallelism_method(
            device_mesh=device_mesh, parallelism_method=ParallelismDegrees.CP
        )
        # A CP sub-mesh of size 1 is treated the same as CP being absent.
        if candidate.size() > 1:
            return candidate
    return None

@staticmethod
def _validate_context_parallel_seq_len(
model: GPT2LLM,
cp_degree: int,
tp_degree: int = 1,
load_balancer_type: str | None = "headtail",
) -> None:
# The "headtail" balancer splits each rank's chunk into a head and tail piece,
# requiring an extra factor of 2. Other load balancers don't impose this constraint.
headtail_factor = 2 if load_balancer_type == "headtail" else 1
seq_len_divisor = tp_degree * cp_degree * headtail_factor
if model.sequence_length % seq_len_divisor != 0:
headtail_note = " * 2 (headtail)" if load_balancer_type == "headtail" else ""
raise ValueError(
f"For GPT2 CP runs, sequence_length must be divisible by tp_degree * cp_degree{headtail_note}. "
f"Got sequence_length={model.sequence_length}, tp_degree={tp_degree}, cp_degree={cp_degree}, "
f"load_balancer_type={load_balancer_type!r}."
)

@staticmethod
def _apply_context_parallel_to_gpt2_attention(
model: GPT2LLM,
cp_mesh: DeviceMesh | None,
context_parallel_load_balancer: str | None,
) -> None:
if cp_mesh is None:
return

if context_parallel_load_balancer not in ("headtail", "ptrr", None):
raise ValueError(
"context_parallel_load_balancer must be one of: 'headtail', 'ptrr', or None. "
f"Got {context_parallel_load_balancer}."
)

GPT2ModelFactory._validate_context_parallel_seq_len(
model=model, cp_degree=cp_mesh.size(), load_balancer_type=context_parallel_load_balancer
)

attention_modules: list[nn.Module] = []
transformer_layers = getattr(model.transformer, "h", None)
if not isinstance(transformer_layers, nn.ModuleDict):
raise TypeError(
"Context parallelism requires model.transformer.h to be an nn.ModuleDict of GPT2 blocks. "
f"Got type {type(transformer_layers).__name__}."
)

for _, transformer_block in transformer_layers.named_children():
attn_module = getattr(transformer_block, "attn", None)
if not isinstance(attn_module, CausalSelfAttention):
continue
if attn_module.attention_impl != AttentionImplementation.PYTORCH_FLASH:
raise NotImplementedError(
"Context parallelism currently supports only attention_implementation='pytorch_flash' "
"for GPT2 in this codebase."
)
attention_modules.append(attn_module)

apply_cp_to_sdpa_attention_forward(attention_modules=attention_modules, cp_mesh=cp_mesh)
setattr(model, "_context_parallel_load_balancer", context_parallel_load_balancer)

@staticmethod
def get_gpt2_model(
sample_key: str,
Expand Down Expand Up @@ -653,8 +726,41 @@ def get_gpt2_model(
return model

@staticmethod
def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) -> nn.Module:
def get_gpt2_context_parallelized_model(
    model: GPT2LLM,
    device_mesh: DeviceMesh,
    context_parallel_load_balancer: str | None = "headtail",
) -> nn.Module:
    """
    Apply context parallelism to the attention modules of a GPT2 model.

    Args:
        model (GPT2LLM): The model to modify in place.
        device_mesh (DeviceMesh): Device mesh from which the CP sub-mesh is looked up.
        context_parallel_load_balancer (str | None): Load balancer name handed to the
            attention patching step. Defaults to "headtail".

    Returns:
        nn.Module: The same model object that was passed in.
    """
    # The lookup yields None when CP is not enabled; the patching step then does nothing.
    GPT2ModelFactory._apply_context_parallel_to_gpt2_attention(
        model=model,
        cp_mesh=GPT2ModelFactory._get_cp_mesh_if_enabled(device_mesh=device_mesh),
        context_parallel_load_balancer=context_parallel_load_balancer,
    )
    return model

@staticmethod
def get_gpt2_tensor_parallelized_model(
model: GPT2LLM,
device_mesh: DeviceMesh,
context_parallel_load_balancer: str | None = "headtail",
) -> nn.Module:
tp_mesh = device_mesh[ParallelismDegrees.TP.value]
cp_mesh = GPT2ModelFactory._get_cp_mesh_if_enabled(device_mesh=device_mesh)

if cp_mesh is not None:
GPT2ModelFactory._validate_context_parallel_seq_len(
model=model,
cp_degree=cp_mesh.size(),
tp_degree=tp_mesh.size(),
load_balancer_type=context_parallel_load_balancer,
)
GPT2ModelFactory._apply_context_parallel_to_gpt2_attention(
model=model,
cp_mesh=cp_mesh,
context_parallel_load_balancer=context_parallel_load_balancer,
)

model_tp_plan = {
# Row-wise parallelism might seem counterintuitive here,
# but the embedding layer has weight shape (vocab_size, n_embd).
Expand Down
Loading
Loading