diff --git a/skyrl/tinker/engine.py b/skyrl/tinker/engine.py index 6c449eb9f7..da73de70ae 100644 --- a/skyrl/tinker/engine.py +++ b/skyrl/tinker/engine.py @@ -296,6 +296,23 @@ def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoi ) session.commit() + def _find_destructive_barriers(self, session: Session) -> dict[str, int]: + """Find the earliest pending destructive operation (optim_step/load_weights) per model. + + These act as scheduling barriers: model passes before them can be batched + safely, and single requests after a blocked pass must wait. + """ + query = ( + select(FutureDB.model_id, func.min(FutureDB.request_id).label("barrier_id")) + .where( + (FutureDB.request_type == types.RequestType.OPTIM_STEP) + | (FutureDB.request_type == types.RequestType.LOAD_WEIGHTS) + ) + .where(FutureDB.status == RequestStatus.PENDING) + .group_by(FutureDB.model_id) + ) + return dict(session.exec(query).all()) + def find_batchable_model_passes( self, session: Session, request_type: types.RequestType ) -> dict[str, tuple[str, types.ForwardBackwardInput]]: @@ -311,17 +328,7 @@ def find_batchable_model_passes( Returns: Dict mapping request_id to (model_id, request_data) tuples """ - # Find the earliest pending optim_step or load_weights per model (these act as barriers) - barriers_query = ( - select(FutureDB.model_id, func.min(FutureDB.request_id).label("barrier_id")) - .where( - (FutureDB.request_type == types.RequestType.OPTIM_STEP) - | (FutureDB.request_type == types.RequestType.LOAD_WEIGHTS) - ) - .where(FutureDB.status == RequestStatus.PENDING) - .group_by(FutureDB.model_id) - ) - barriers = dict(session.exec(barriers_query).all()) + barriers = self._find_destructive_barriers(session) # Get all pending operations of the requested type ordered by request_id query = ( @@ -389,6 +396,26 @@ def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.R Returns: Dict mapping request_id to (model_id, request_type, request_data) tuples """ + # Find the first blocked forward pass per model: a pending FORWARD/FORWARD_BACKWARD + # that sits behind a destructive barrier and won't be batched this iteration. + # Single requests must not jump ahead of these. + destructive_barriers = self._find_destructive_barriers(session) + blocked_pass_barriers: dict[str, int] = {} + if destructive_barriers: + pending_passes = session.exec( + select(FutureDB.model_id, FutureDB.request_id) + .where( + (FutureDB.request_type == types.RequestType.FORWARD_BACKWARD) + | (FutureDB.request_type == types.RequestType.FORWARD) + ) + .where(FutureDB.status == RequestStatus.PENDING) + .where(FutureDB.model_id.in_(destructive_barriers.keys())) + .order_by(FutureDB.request_id) + ).all() + for model_id, req_id in pending_passes: + if req_id >= destructive_barriers[model_id]: + blocked_pass_barriers.setdefault(model_id, req_id) + statement = ( select(FutureDB) .where(FutureDB.status == RequestStatus.PENDING) @@ -400,6 +427,13 @@ def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.R ) other_futures = session.exec(statement).all() + # Filter: only include ops that come before the first blocked pass for their model + other_futures = [ + op + for op in other_futures + if op.model_id not in blocked_pass_barriers or op.request_id < blocked_pass_barriers[op.model_id] + ] + return {str(f.request_id): (f.model_id, f.request_type, f.request_data) for f in other_futures} def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput: diff --git a/tests/tinker/test_engine.py b/tests/tinker/test_engine.py index 2d3db81946..4fc60f6b3c 100644 --- a/tests/tinker/test_engine.py +++ b/tests/tinker/test_engine.py @@ -6,7 +6,7 @@ from skyrl.tinker import types from skyrl.tinker.config import EngineConfig -from skyrl.tinker.db_models import ModelDB, SessionDB +from skyrl.tinker.db_models import FutureDB, ModelDB, RequestStatus, SessionDB from skyrl.tinker.engine import TinkerEngine, prepare_model_pass_batch BASE_MODEL = "trl-internal-testing/tiny-Qwen3ForCausalLM" @@ -134,3 +134,112 @@ def test_prepare_model_pass_batch_loss_fn_and_config( assert batch.all_loss_fns == [loss_fn] assert batch.all_loss_fn_configs == [loss_fn_config] assert batch.all_model_inputs == [datum.model_input] + + +@pytest.fixture() +def scheduling_engine(): + """Create a TinkerEngine with only the DB initialized (no backend) for scheduling tests.""" + from sqlalchemy import create_engine + + from skyrl.tinker.db_models import enable_sqlite_wal + + engine = object.__new__(TinkerEngine) + engine.db_engine = create_engine("sqlite:///:memory:", echo=False) + enable_sqlite_wal(engine.db_engine) + SQLModel.metadata.create_all(engine.db_engine) + return engine + + +def test_find_single_requests_respects_forward_backward_barriers(scheduling_engine): + """Regression: optim_step must not run before a preceding forward_backward for the same model. + + Given pending requests [fwdbwd1, optim1, fwdbwd2, optim2] for the same model, + find_single_requests should only return optim1 (not optim2), because fwdbwd2 + acts as a barrier — optim2 depends on fwdbwd2's gradients. + """ + engine = scheduling_engine + model_id = "test_model" + + with Session(engine.db_engine) as session: + # Insert requests in order: fwdbwd1, optim1, fwdbwd2, optim2 + for req_type in [ + types.RequestType.FORWARD_BACKWARD, + types.RequestType.OPTIM_STEP, + types.RequestType.FORWARD_BACKWARD, + types.RequestType.OPTIM_STEP, + ]: + session.add( + FutureDB( + request_type=req_type, + model_id=model_id, + request_data={}, + status=RequestStatus.PENDING, + ) + ) + session.commit() + + with Session(engine.db_engine) as session: + # find_single_requests should return only optim1 (request_id=2), NOT optim2 (request_id=4) + singles = engine.find_single_requests(session) + assert list(singles.keys()) == ["2"] + + +def test_find_single_requests_no_barrier_when_no_pending_passes(scheduling_engine): + """When there are no pending forward/forward_backward requests, all single requests are returned.""" + engine = scheduling_engine + + with Session(engine.db_engine) as session: + for model_id in ["model_a", "model_b"]: + session.add( + FutureDB( + request_type=types.RequestType.OPTIM_STEP, + model_id=model_id, + request_data={}, + status=RequestStatus.PENDING, + ) + ) + session.commit() + + with Session(engine.db_engine) as session: + singles = engine.find_single_requests(session) + assert len(singles) == 2 + + +def test_find_single_requests_barrier_is_per_model(scheduling_engine): + """A blocked forward_backward on model_a should not block an optim_step on model_b.""" + engine = scheduling_engine + + with Session(engine.db_engine) as session: + # model_a: fwdbwd(1), optim(2), fwdbwd(3), optim(4) + # model_b: optim(5) + for req_type in [ + types.RequestType.FORWARD_BACKWARD, + types.RequestType.OPTIM_STEP, + types.RequestType.FORWARD_BACKWARD, + types.RequestType.OPTIM_STEP, + ]: + session.add( + FutureDB( + request_type=req_type, + model_id="model_a", + request_data={}, + status=RequestStatus.PENDING, + ) + ) + session.add( + FutureDB( + request_type=types.RequestType.OPTIM_STEP, + model_id="model_b", + request_data={}, + status=RequestStatus.PENDING, + ) + ) + session.commit() + + with Session(engine.db_engine) as session: + singles = engine.find_single_requests(session) + # model_a: optim(2) returned, optim(4) blocked by fwdbwd(3) + # model_b: optim(5) returned (not affected by model_a's barrier) + assert list(singles.keys()) == ["2", "5"] + assert singles["2"][0] == "model_a" + assert singles["5"][0] == "model_b"