Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions skyrl/tinker/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,23 @@ def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoi
)
session.commit()

def _find_destructive_barriers(self, session: Session) -> dict[str, int]:
    """Locate, per model, the earliest pending destructive operation.

    Destructive requests (optim_step / load_weights) act as scheduling
    barriers: model passes queued before them can be batched safely, and
    single requests sitting behind a blocked pass must wait their turn.

    Args:
        session: Active DB session used to run the aggregation query.

    Returns:
        Mapping of model_id to the smallest pending destructive request_id.
    """
    # Either destructive request type qualifies as a barrier.
    is_destructive = FutureDB.request_type.in_(
        [types.RequestType.OPTIM_STEP, types.RequestType.LOAD_WEIGHTS]
    )
    stmt = (
        select(FutureDB.model_id, func.min(FutureDB.request_id).label("barrier_id"))
        .where(is_destructive)
        .where(FutureDB.status == RequestStatus.PENDING)
        .group_by(FutureDB.model_id)
    )
    return {model_id: barrier_id for model_id, barrier_id in session.exec(stmt).all()}

def find_batchable_model_passes(
self, session: Session, request_type: types.RequestType
) -> dict[str, tuple[str, types.ForwardBackwardInput]]:
Expand All @@ -311,17 +328,7 @@ def find_batchable_model_passes(
Returns:
Dict mapping request_id to (model_id, request_data) tuples
"""
# Find the earliest pending optim_step or load_weights per model (these act as barriers)
barriers_query = (
select(FutureDB.model_id, func.min(FutureDB.request_id).label("barrier_id"))
.where(
(FutureDB.request_type == types.RequestType.OPTIM_STEP)
| (FutureDB.request_type == types.RequestType.LOAD_WEIGHTS)
)
.where(FutureDB.status == RequestStatus.PENDING)
.group_by(FutureDB.model_id)
)
barriers = dict(session.exec(barriers_query).all())
barriers = self._find_destructive_barriers(session)

# Get all pending operations of the requested type ordered by request_id
query = (
Expand Down Expand Up @@ -389,6 +396,26 @@ def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.R
Returns:
Dict mapping request_id to (model_id, request_type, request_data) tuples
"""
# Find the first blocked forward pass per model: a pending FORWARD/FORWARD_BACKWARD
# that sits behind a destructive barrier and won't be batched this iteration.
# Single requests must not jump ahead of these.
destructive_barriers = self._find_destructive_barriers(session)
blocked_pass_barriers: dict[str, int] = {}
if destructive_barriers:
pending_passes = session.exec(
select(FutureDB.model_id, FutureDB.request_id)
.where(
(FutureDB.request_type == types.RequestType.FORWARD_BACKWARD)
| (FutureDB.request_type == types.RequestType.FORWARD)
)
.where(FutureDB.status == RequestStatus.PENDING)
.where(FutureDB.model_id.in_(destructive_barriers.keys()))
.order_by(FutureDB.request_id)
).all()
for model_id, req_id in pending_passes:
if req_id >= destructive_barriers[model_id]:
blocked_pass_barriers.setdefault(model_id, req_id)
Comment on lines +404 to +417
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The blocked-pass identification runs an extra query plus a Python-level linear scan over all pending forward/forward_backward rows on every call to find_single_requests. The scan itself is O(N) with O(1) dict lookups (not nested), but it could be pushed into a single SQL query (e.g. a per-model min over passes at or beyond the barrier) to avoid materializing all pending passes in Python on each call.


statement = (
select(FutureDB)
.where(FutureDB.status == RequestStatus.PENDING)
Expand All @@ -400,6 +427,13 @@ def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.R
)
other_futures = session.exec(statement).all()

# Filter: only include ops that come before the first blocked pass for their model
other_futures = [
op
for op in other_futures
if op.model_id not in blocked_pass_barriers or op.request_id < blocked_pass_barriers[op.model_id]
]

return {str(f.request_id): (f.model_id, f.request_type, f.request_data) for f in other_futures}

def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput:
Expand Down
111 changes: 110 additions & 1 deletion tests/tinker/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from skyrl.tinker import types
from skyrl.tinker.config import EngineConfig
from skyrl.tinker.db_models import ModelDB, SessionDB
from skyrl.tinker.db_models import FutureDB, ModelDB, RequestStatus, SessionDB
from skyrl.tinker.engine import TinkerEngine, prepare_model_pass_batch

BASE_MODEL = "trl-internal-testing/tiny-Qwen3ForCausalLM"
Expand Down Expand Up @@ -134,3 +134,112 @@ def test_prepare_model_pass_batch_loss_fn_and_config(
assert batch.all_loss_fns == [loss_fn]
assert batch.all_loss_fn_configs == [loss_fn_config]
assert batch.all_model_inputs == [datum.model_input]


@pytest.fixture()
def scheduling_engine():
    """Build a TinkerEngine with only its DB wired up (no model backend).

    Scheduling tests exercise the SQL-based queueing logic alone, so the
    expensive backend initialization is skipped by bypassing __init__.
    """
    from sqlalchemy import create_engine

    from skyrl.tinker.db_models import enable_sqlite_wal

    # object.__new__ skips TinkerEngine.__init__ and thus any backend setup.
    tinker_engine = object.__new__(TinkerEngine)
    tinker_engine.db_engine = create_engine("sqlite:///:memory:", echo=False)
    enable_sqlite_wal(tinker_engine.db_engine)
    SQLModel.metadata.create_all(tinker_engine.db_engine)
    return tinker_engine


def test_find_single_requests_respects_forward_backward_barriers(scheduling_engine):
    """Regression: optim_step must not run before a preceding forward_backward for the same model.

    With pending requests [fwdbwd1, optim1, fwdbwd2, optim2] on one model,
    find_single_requests must surface only optim1 — fwdbwd2 is a barrier,
    and optim2 depends on the gradients fwdbwd2 produces.
    """
    engine = scheduling_engine
    model_id = "test_model"

    # Insertion order fixes request_ids: fwdbwd=1, optim=2, fwdbwd=3, optim=4.
    request_sequence = [
        types.RequestType.FORWARD_BACKWARD,
        types.RequestType.OPTIM_STEP,
        types.RequestType.FORWARD_BACKWARD,
        types.RequestType.OPTIM_STEP,
    ]
    with Session(engine.db_engine) as session:
        session.add_all(
            [
                FutureDB(
                    request_type=request_type,
                    model_id=model_id,
                    request_data={},
                    status=RequestStatus.PENDING,
                )
                for request_type in request_sequence
            ]
        )
        session.commit()

    with Session(engine.db_engine) as session:
        # Only optim1 (request_id=2) is runnable; optim2 (request_id=4) must wait.
        singles = engine.find_single_requests(session)
        assert list(singles.keys()) == ["2"]


def test_find_single_requests_no_barrier_when_no_pending_passes(scheduling_engine):
    """When there are no pending forward/forward_backward requests, all single requests are returned."""
    engine = scheduling_engine

    with Session(engine.db_engine) as session:
        # One standalone optim_step per model; no passes exist to block them.
        session.add_all(
            [
                FutureDB(
                    request_type=types.RequestType.OPTIM_STEP,
                    model_id=name,
                    request_data={},
                    status=RequestStatus.PENDING,
                )
                for name in ("model_a", "model_b")
            ]
        )
        session.commit()

    with Session(engine.db_engine) as session:
        assert len(engine.find_single_requests(session)) == 2


def test_find_single_requests_barrier_is_per_model(scheduling_engine):
    """A blocked forward_backward on model_a should not block an optim_step on model_b."""
    engine = scheduling_engine

    # Insertion order fixes request_ids:
    #   model_a: fwdbwd(1), optim(2), fwdbwd(3), optim(4)
    #   model_b: optim(5)
    pending = [
        ("model_a", types.RequestType.FORWARD_BACKWARD),
        ("model_a", types.RequestType.OPTIM_STEP),
        ("model_a", types.RequestType.FORWARD_BACKWARD),
        ("model_a", types.RequestType.OPTIM_STEP),
        ("model_b", types.RequestType.OPTIM_STEP),
    ]
    with Session(engine.db_engine) as session:
        for owner, request_type in pending:
            session.add(
                FutureDB(
                    request_type=request_type,
                    model_id=owner,
                    request_data={},
                    status=RequestStatus.PENDING,
                )
            )
        session.commit()

    with Session(engine.db_engine) as session:
        singles = engine.find_single_requests(session)
        # model_a: optim(2) is runnable, optim(4) is blocked behind fwdbwd(3).
        # model_b: optim(5) is untouched by model_a's barrier.
        assert list(singles.keys()) == ["2", "5"]
        assert singles["2"][0] == "model_a"
        assert singles["5"][0] == "model_b"
Loading