From 35bb94936e731660390175e70badb4a50ccdf6f5 Mon Sep 17 00:00:00 2001 From: JACOB STANLEY Date: Sat, 27 Jun 2026 02:34:26 +0100 Subject: [PATCH 1/3] https://github.com/jaynomyaro/astroml.git --- astroml/db/schema.py | 80 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/astroml/db/schema.py b/astroml/db/schema.py index 1de0971..f81ced8 100644 --- a/astroml/db/schema.py +++ b/astroml/db/schema.py @@ -535,3 +535,83 @@ class NormalizedTransaction(Base): postgresql_where=(receiver.isnot(None)), ), ) + + +# --------------------------------------------------------------------------- +# Model Registry +# --------------------------------------------------------------------------- + +class Model(Base): + """Machine learning model metadata for the model registry.""" + + __tablename__ = "models" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(128), nullable=False, unique=True) + description: Mapped[Optional[str]] = mapped_column(Text) + framework: Mapped[str] = mapped_column(String(32), nullable=False) + task_type: Mapped[str] = mapped_column(String(32), nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, server_default="true") + created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + nullable=False, server_default=func.now(), onupdate=func.now() + ) + + # Relationships + versions: Mapped[list[ModelVersion]] = relationship( + back_populates="model", + cascade="all, delete-orphan", + ) + + __table_args__ = ( + Index("ix_models_framework", "framework"), + Index("ix_models_task_type", "task_type"), + Index("ix_models_is_active", "is_active"), + CheckConstraint( + "framework IN ('pytorch', 'tensorflow', 'sklearn', 'xgboost', 'lightgbm', 'custom')", + name="ck_models_framework", + ), + CheckConstraint( + "task_type IN ('classification', 'regression', 
'anomaly_detection', 'clustering', 'custom')", + name="ck_models_task_type", + ), + ) + + +class ModelVersion(Base): + """Specific version of a machine learning model with artifacts and metrics.""" + + __tablename__ = "model_versions" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + model_id: Mapped[int] = mapped_column( + Integer, ForeignKey("models.id"), nullable=False + ) + version: Mapped[str] = mapped_column(String(32), nullable=False) + artifact_path: Mapped[str] = mapped_column(String(512), nullable=False) + hyperparameters: Mapped[Optional[dict]] = mapped_column( + JSON().with_variant(JSONB(), "postgresql") + ) + metrics: Mapped[Optional[dict]] = mapped_column( + JSON().with_variant(JSONB(), "postgresql") + ) + status: Mapped[str] = mapped_column(String(32), nullable=False, server_default="training") + created_at: Mapped[datetime] = mapped_column(nullable=False, server_default=func.now()) + updated_at: Mapped[datetime] = mapped_column( + nullable=False, server_default=func.now(), onupdate=func.now() + ) + deployed_at: Mapped[Optional[datetime]] = mapped_column() + + # Relationships + model: Mapped[Model] = relationship(back_populates="versions") + + __table_args__ = ( + UniqueConstraint("model_id", "version", name="uq_model_versions_model_version"), + Index("ix_model_versions_model_id", "model_id"), + Index("ix_model_versions_status", "status"), + Index("ix_model_versions_created_at", "created_at"), + CheckConstraint( + "status IN ('training', 'trained', 'deployed', 'archived', 'failed')", + name="ck_model_versions_status", + ), + ) From 5cab9d262fb7b7dbe83d776aeca7985a35a4c355 Mon Sep 17 00:00:00 2001 From: JACOB STANLEY Date: Sat, 27 Jun 2026 02:51:18 +0100 Subject: [PATCH 2/3] https://github.com/jaynomyaro/astroml.git --- astroml/tracking/__init__.py | 3 +- astroml/tracking/model_registry.py | 396 +++++++++++++++++++++++++++++ 2 files changed, 398 insertions(+), 1 deletion(-) create mode 100644 
astroml/tracking/model_registry.py diff --git a/astroml/tracking/__init__.py b/astroml/tracking/__init__.py index e2b03d5..cc8a2f8 100644 --- a/astroml/tracking/__init__.py +++ b/astroml/tracking/__init__.py @@ -1,3 +1,4 @@ from .mlflow_tracker import MLflowTracker +from .model_registry import ModelRegistry -__all__ = ["MLflowTracker"] +__all__ = ["MLflowTracker", "ModelRegistry"] diff --git a/astroml/tracking/model_registry.py b/astroml/tracking/model_registry.py new file mode 100644 index 0000000..020a99f --- /dev/null +++ b/astroml/tracking/model_registry.py @@ -0,0 +1,396 @@ +"""Model registry for managing ML models and their versions.""" +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from astroml.db.schema import Model, ModelVersion +from astroml.db.session import get_session + +logger = logging.getLogger(__name__) + + +class ModelRegistry: + """Core class for managing ML models and their versions in the database. + + Provides CRUD operations for Model and ModelVersion entities, + with helper methods for common registry operations. + """ + + def __init__(self, session: Optional[Session] = None): + """Initialize the registry. + + Args: + session: Optional SQLAlchemy session. If not provided, creates a new session. 
+ """ + self._session = session + self._owns_session = session is None + + @property + def session(self) -> Session: + """Get the SQLAlchemy session, creating one if needed.""" + if self._session is None: + self._session = get_session() + return self._session + + def close(self) -> None: + """Close the session if we own it.""" + if self._owns_session and self._session is not None: + self._session.close() + self._session = None + + def __enter__(self) -> "ModelRegistry": + return self + + def __exit__(self, *_: Any) -> None: + self.close() + + # ------------------------------------------------------------------ + # Model CRUD operations + # ------------------------------------------------------------------ + + def create_model( + self, + name: str, + framework: str, + task_type: str, + description: Optional[str] = None, + is_active: bool = True, + ) -> Model: + """Create a new model. + + Args: + name: Unique model name + framework: ML framework (pytorch, tensorflow, sklearn, etc.) + task_type: Task type (classification, regression, etc.) + description: Optional model description + is_active: Whether the model is active + + Returns: + Created Model instance + + Raises: + ValueError: If a model with the same name already exists + """ + existing = self.get_model_by_name(name) + if existing: + raise ValueError(f"Model with name '{name}' already exists") + + model = Model( + name=name, + description=description, + framework=framework, + task_type=task_type, + is_active=is_active, + ) + self.session.add(model) + self.session.commit() + self.session.refresh(model) + logger.info("Created model: %s (id=%d)", name, model.id) + return model + + def get_model(self, model_id: int) -> Optional[Model]: + """Get a model by ID. + + Args: + model_id: Model ID + + Returns: + Model instance or None if not found + """ + return self.session.get(Model, model_id) + + def get_model_by_name(self, name: str) -> Optional[Model]: + """Get a model by name. 
+ + Args: + name: Model name + + Returns: + Model instance or None if not found + """ + stmt = select(Model).where(Model.name == name) + return self.session.execute(stmt).scalar_one_or_none() + + def list_models( + self, + framework: Optional[str] = None, + task_type: Optional[str] = None, + is_active: Optional[bool] = None, + ) -> List[Model]: + """List models with optional filters. + + Args: + framework: Filter by framework + task_type: Filter by task type + is_active: Filter by active status + + Returns: + List of Model instances + """ + stmt = select(Model) + if framework: + stmt = stmt.where(Model.framework == framework) + if task_type: + stmt = stmt.where(Model.task_type == task_type) + if is_active is not None: + stmt = stmt.where(Model.is_active == is_active) + stmt = stmt.order_by(Model.created_at.desc()) + return list(self.session.execute(stmt).scalars().all()) + + def update_model( + self, + model_id: int, + description: Optional[str] = None, + is_active: Optional[bool] = None, + ) -> Optional[Model]: + """Update a model. + + Args: + model_id: Model ID + description: New description + is_active: New active status + + Returns: + Updated Model instance or None if not found + """ + model = self.get_model(model_id) + if not model: + return None + + if description is not None: + model.description = description + if is_active is not None: + model.is_active = is_active + + self.session.commit() + self.session.refresh(model) + logger.info("Updated model: %s (id=%d)", model.name, model.id) + return model + + def delete_model(self, model_id: int) -> bool: + """Delete a model and all its versions. 
+ + Args: + model_id: Model ID + + Returns: + True if deleted, False if not found + """ + model = self.get_model(model_id) + if not model: + return False + + self.session.delete(model) + self.session.commit() + logger.info("Deleted model: %s (id=%d)", model.name, model_id) + return True + + # ------------------------------------------------------------------ + # ModelVersion CRUD operations + # ------------------------------------------------------------------ + + def create_model_version( + self, + model_id: int, + version: str, + artifact_path: str, + hyperparameters: Optional[Dict[str, Any]] = None, + metrics: Optional[Dict[str, Any]] = None, + status: str = "training", + ) -> ModelVersion: + """Create a new model version. + + Args: + model_id: Parent model ID + version: Version string (e.g., "1.0.0") + artifact_path: Path to model artifacts + hyperparameters: Optional hyperparameters dict + metrics: Optional metrics dict + status: Version status (training, trained, deployed, etc.) + + Returns: + Created ModelVersion instance + + Raises: + ValueError: If model not found or version already exists for this model + """ + model = self.get_model(model_id) + if not model: + raise ValueError(f"Model with id {model_id} not found") + + existing = self.get_model_version(model_id, version) + if existing: + raise ValueError(f"Version '{version}' already exists for model {model_id}") + + model_version = ModelVersion( + model_id=model_id, + version=version, + artifact_path=artifact_path, + hyperparameters=hyperparameters, + metrics=metrics, + status=status, + ) + self.session.add(model_version) + self.session.commit() + self.session.refresh(model_version) + logger.info( + "Created model version: %s (id=%d, model_id=%d)", + version, + model_version.id, + model_id, + ) + return model_version + + def get_model_version(self, model_id: int, version: str) -> Optional[ModelVersion]: + """Get a specific model version. 
+ + Args: + model_id: Model ID + version: Version string + + Returns: + ModelVersion instance or None if not found + """ + stmt = select(ModelVersion).where( + ModelVersion.model_id == model_id, ModelVersion.version == version + ) + return self.session.execute(stmt).scalar_one_or_none() + + def get_model_version_by_id(self, version_id: int) -> Optional[ModelVersion]: + """Get a model version by ID. + + Args: + version_id: ModelVersion ID + + Returns: + ModelVersion instance or None if not found + """ + return self.session.get(ModelVersion, version_id) + + def list_model_versions( + self, + model_id: Optional[int] = None, + status: Optional[str] = None, + ) -> List[ModelVersion]: + """List model versions with optional filters. + + Args: + model_id: Filter by model ID + status: Filter by status + + Returns: + List of ModelVersion instances + """ + stmt = select(ModelVersion) + if model_id: + stmt = stmt.where(ModelVersion.model_id == model_id) + if status: + stmt = stmt.where(ModelVersion.status == status) + stmt = stmt.order_by(ModelVersion.created_at.desc()) + return list(self.session.execute(stmt).scalars().all()) + + def update_model_version( + self, + version_id: int, + status: Optional[str] = None, + metrics: Optional[Dict[str, Any]] = None, + deployed_at: Optional[datetime] = None, + ) -> Optional[ModelVersion]: + """Update a model version. 
+ + Args: + version_id: ModelVersion ID + status: New status + metrics: New or updated metrics + deployed_at: Deployment timestamp + + Returns: + Updated ModelVersion instance or None if not found + """ + version = self.get_model_version_by_id(version_id) + if not version: + return None + + if status is not None: + version.status = status + if metrics is not None: + version.metrics = metrics + if deployed_at is not None: + version.deployed_at = deployed_at + + self.session.commit() + self.session.refresh(version) + logger.info("Updated model version: %s (id=%d)", version.version, version_id) + return version + + def delete_model_version(self, version_id: int) -> bool: + """Delete a model version. + + Args: + version_id: ModelVersion ID + + Returns: + True if deleted, False if not found + """ + version = self.get_model_version_by_id(version_id) + if not version: + return False + + self.session.delete(version) + self.session.commit() + logger.info("Deleted model version: %s (id=%d)", version.version, version_id) + return True + + # ------------------------------------------------------------------ + # Helper methods + # ------------------------------------------------------------------ + + def get_latest_version(self, model_id: int) -> Optional[ModelVersion]: + """Get the latest version of a model by creation time. + + Args: + model_id: Model ID + + Returns: + Latest ModelVersion or None if no versions exist + """ + stmt = ( + select(ModelVersion) + .where(ModelVersion.model_id == model_id) + .order_by(ModelVersion.created_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() + + def get_deployed_version(self, model_id: int) -> Optional[ModelVersion]: + """Get the deployed version of a model. 
+ + Args: + model_id: Model ID + + Returns: + Deployed ModelVersion or None if no deployed version exists + """ + stmt = ( + select(ModelVersion) + .where(ModelVersion.model_id == model_id, ModelVersion.status == "deployed") + .order_by(ModelVersion.deployed_at.desc()) + .limit(1) + ) + return self.session.execute(stmt).scalar_one_or_none() + + def mark_deployed(self, version_id: int) -> Optional[ModelVersion]: + """Mark a model version as deployed. + + Args: + version_id: ModelVersion ID + + Returns: + Updated ModelVersion or None if not found + """ + return self.update_model_version(version_id, status="deployed", deployed_at=datetime.now(timezone.utc)) From 05ebd6141965b8cb073ce24450e429f73ac83347 Mon Sep 17 00:00:00 2001 From: JACOB STANLEY Date: Sat, 27 Jun 2026 03:02:41 +0100 Subject: [PATCH 3/3] https://github.com/jaynomyaro/astroml.git --- astroml/tracking/model_registry.py | 154 +++++++++++++++++++++++++++++ tests/test_schema.py | 89 ++++++++++++++++- 2 files changed, 242 insertions(+), 1 deletion(-) diff --git a/astroml/tracking/model_registry.py b/astroml/tracking/model_registry.py index 020a99f..5486470 100644 --- a/astroml/tracking/model_registry.py +++ b/astroml/tracking/model_registry.py @@ -14,6 +14,30 @@ logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# ModelVersion State Machine +# --------------------------------------------------------------------------- + +# Valid status transitions for ModelVersion +# Format: {from_status: [to_status, ...]} +VALID_STATUS_TRANSITIONS = { + "training": ["trained", "failed"], + "trained": ["deployed", "archived"], + "deployed": ["archived"], + "archived": [], # Terminal state + "failed": ["training"], # Can retry training +} + +# All valid statuses +VALID_STATUSES = set(VALID_STATUS_TRANSITIONS.keys()) + + +class InvalidStatusTransitionError(ValueError): + """Raised when an invalid status transition is attempted.""" + + pass + + class 
ModelRegistry: """Core class for managing ML models and their versions in the database. @@ -300,6 +324,7 @@ def update_model_version( status: Optional[str] = None, metrics: Optional[Dict[str, Any]] = None, deployed_at: Optional[datetime] = None, + validate_transition: bool = True, ) -> Optional[ModelVersion]: """Update a model version. @@ -308,15 +333,21 @@ def update_model_version( status: New status metrics: New or updated metrics deployed_at: Deployment timestamp + validate_transition: Whether to validate status transitions (default: True) Returns: Updated ModelVersion instance or None if not found + + Raises: + InvalidStatusTransitionError: If status transition is invalid """ version = self.get_model_version_by_id(version_id) if not version: return None if status is not None: + if validate_transition: + self._validate_status_transition(version.status, status) version.status = status if metrics is not None: version.metrics = metrics @@ -392,5 +423,128 @@ def mark_deployed(self, version_id: int) -> Optional[ModelVersion]: Returns: Updated ModelVersion or None if not found + + Raises: + InvalidStatusTransitionError: If version cannot be deployed """ return self.update_model_version(version_id, status="deployed", deployed_at=datetime.now(timezone.utc)) + + # ------------------------------------------------------------------ + # State machine methods + # ------------------------------------------------------------------ + + @staticmethod + def _validate_status_transition(from_status: str, to_status: str) -> None: + """Validate that a status transition is allowed. 
+ + Args: + from_status: Current status + to_status: Target status + + Raises: + InvalidStatusTransitionError: If transition is not allowed + """ + if to_status not in VALID_STATUSES: + raise InvalidStatusTransitionError(f"Invalid target status: '{to_status}'") + + if from_status == to_status: + return # No-op transition is allowed + + allowed_transitions = VALID_STATUS_TRANSITIONS.get(from_status, []) + if to_status not in allowed_transitions: + raise InvalidStatusTransitionError( + f"Cannot transition from '{from_status}' to '{to_status}'. " + f"Allowed transitions from '{from_status}': {allowed_transitions}" + ) + + def transition_status(self, version_id: int, to_status: str) -> Optional[ModelVersion]: + """Transition a model version to a new status with validation. + + Args: + version_id: ModelVersion ID + to_status: Target status + + Returns: + Updated ModelVersion or None if not found + + Raises: + InvalidStatusTransitionError: If transition is not allowed + """ + version = self.get_model_version_by_id(version_id) + if not version: + return None + + self._validate_status_transition(version.status, to_status) + return self.update_model_version(version_id, status=to_status) + + def mark_trained(self, version_id: int, metrics: Optional[Dict[str, Any]] = None) -> Optional[ModelVersion]: + """Mark a model version as trained. + + Args: + version_id: ModelVersion ID + metrics: Optional training metrics + + Returns: + Updated ModelVersion or None if not found + + Raises: + InvalidStatusTransitionError: If version cannot be marked as trained + """ + return self.update_model_version(version_id, status="trained", metrics=metrics) + + def mark_failed(self, version_id: int) -> Optional[ModelVersion]: + """Mark a model version as failed. 
+ + Args: + version_id: ModelVersion ID + + Returns: + Updated ModelVersion or None if not found + + Raises: + InvalidStatusTransitionError: If version cannot be marked as failed + """ + return self.update_model_version(version_id, status="failed") + + def mark_archived(self, version_id: int) -> Optional[ModelVersion]: + """Mark a model version as archived. + + Args: + version_id: ModelVersion ID + + Returns: + Updated ModelVersion or None if not found + + Raises: + InvalidStatusTransitionError: If version cannot be archived + """ + return self.update_model_version(version_id, status="archived") + + def retry_training(self, version_id: int) -> Optional[ModelVersion]: + """Retry training for a failed model version. + + Args: + version_id: ModelVersion ID + + Returns: + Updated ModelVersion or None if not found + + Raises: + InvalidStatusTransitionError: If version cannot be retried + """ + return self.update_model_version(version_id, status="training") + + def get_valid_transitions(self, version_id: int) -> List[str]: + """Get valid status transitions for a model version. 
+ + Args: + version_id: ModelVersion ID + + Returns: + List of valid target statuses, or empty list if version not found + """ + version = self.get_model_version_by_id(version_id) + if not version: + return [] + + return VALID_STATUS_TRANSITIONS.get(version.status, []).copy() diff --git a/tests/test_schema.py b/tests/test_schema.py index c51d33e..9578351 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -21,6 +21,8 @@ GraphPaymentDetail, GraphTransactionDetail, Ledger, + Model, + ModelVersion, Operation, Transaction, ) @@ -50,7 +52,7 @@ def session(engine): # --------------------------------------------------------------------------- def test_models_importable(): - """All five model classes import cleanly.""" + """All model classes import cleanly.""" for cls in ( Ledger, Transaction, @@ -62,6 +64,8 @@ def test_models_importable(): GraphTransactionDetail, GraphClaimDetail, GraphPaymentDetail, + Model, + ModelVersion, ): assert hasattr(cls, "__tablename__") @@ -80,6 +84,8 @@ def test_create_all_tables(engine): "graph_payment_details", "graph_transaction_details", "ledgers", + "model_versions", + "models", "normalized_transactions", "operations", "transactions", @@ -95,6 +101,8 @@ def test_table_names(): assert Asset.__tablename__ == "assets" assert GraphAccount.__tablename__ == "graph_accounts" assert GraphEdge.__tablename__ == "graph_edges" + assert Model.__tablename__ == "models" + assert ModelVersion.__tablename__ == "model_versions" # --------------------------------------------------------------------------- @@ -226,6 +234,48 @@ def test_graph_detail_columns(engine): assert {"edge_id", "edge_type", "payment_reference", "payment_status", "fee_amount", "settled_at", "details"} <= payment_cols +def test_model_columns(engine): + inspector = inspect(engine) + cols = {c["name"] for c in inspector.get_columns("models")} + expected = { + "id", + "name", + "description", + "framework", + "task_type", + "is_active", + "created_at", + "updated_at", + } + 
assert expected <= cols + + +def test_model_version_columns(engine): + inspector = inspect(engine) + cols = {c["name"] for c in inspector.get_columns("model_versions")} + expected = { + "id", + "model_id", + "version", + "artifact_path", + "hyperparameters", + "metrics", + "status", + "created_at", + "updated_at", + "deployed_at", + } + assert expected <= cols + + # FK to models + fks = inspector.get_foreign_keys("model_versions") + assert any( + fk["referred_table"] == "models" + and fk["referred_columns"] == ["id"] + for fk in fks + ) + + # --------------------------------------------------------------------------- # Relationships # --------------------------------------------------------------------------- @@ -319,6 +369,43 @@ def test_graph_relationships(session): assert detail.edge is edge +def test_model_registry_relationships(session): + """Model.versions and ModelVersion.model back-populate each other.""" + now = datetime.now(timezone.utc) + + model = Model( + name="test-model", + framework="pytorch", + task_type="classification", + description="Test model", + ) + session.add(model) + session.flush() + + version1 = ModelVersion( + model_id=model.id, + version="1.0.0", + artifact_path="/models/v1", + status="trained", + ) + version2 = ModelVersion( + model_id=model.id, + version="2.0.0", + artifact_path="/models/v2", + status="training", + ) + session.add_all([version1, version2]) + session.flush() + + session.refresh(model) + + assert len(model.versions) == 2 + assert version1 in model.versions + assert version2 in model.versions + assert version1.model is model + assert version2.model is model + + # --------------------------------------------------------------------------- # Round-trip insert & query # ---------------------------------------------------------------------------