Skip to content

Context Parallelism#446

Open
rrutmann wants to merge 6 commits into
mainfrom
cp
Open

Context Parallelism#446
rrutmann wants to merge 6 commits into
mainfrom
cp

Conversation

@rrutmann
Copy link
Copy Markdown
Collaborator

@rrutmann rrutmann commented May 19, 2026

What does this PR do?

This PR adds end-to-end context parallel support for the GPT2 path. The current implementation requires the usage of pytorch's flash attention when enabling context parallelism

General Changes

  • Applies CP-aware SDPA dispatch in attention.
  • Adds trainer-side sequence sharding for inputs and targets on the CP mesh.
  • Propagates CP load-balancer configuration from model setup into trainer sharding.
  • Registers and validates the new GPT2 CP model variant.
  • Adds focused unit coverage and an e2e CP-vs-non-CP parity test.

Breaking Changes

  • ..

Checklist before submitting final PR

  • My PR is minimal and addresses one issue in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have checked that all tests run through (python tests/tests.py)
  • I have updated the internal changelog (CHANGELOG_DEV.md)

rrutmann and others added 3 commits May 18, 2026 13:43
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
@rrutmann rrutmann self-assigned this May 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant