Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 38 additions & 55 deletions ddcDatabases/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime
from typing import Any, AsyncGenerator, Generator, Sequence, TypeVar
Expand All @@ -13,16 +14,16 @@
DBInsertSingleException,
)
from sqlalchemy import RowMapping
from sqlalchemy.engine import create_engine, Engine, URL
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.engine import Engine, URL
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncEngine, AsyncSession
from sqlalchemy.orm import Session, sessionmaker


# Type variable for generic model types
T = TypeVar('T')


class BaseConnection:
class BaseConnection(ABC):
__slots__ = (
'connection_url',
'engine_args',
Expand Down Expand Up @@ -94,33 +95,15 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
await self._temp_engine.dispose()
self.is_connected = False

@abstractmethod
@contextmanager
def _get_engine(self) -> Generator[Engine, None, None]:
_connection_url = URL.create(
drivername=self.sync_driver,
**self.connection_url,
)
_engine_args = {
"url": _connection_url,
**self.engine_args,
}
_engine = create_engine(**_engine_args)
yield _engine
_engine.dispose()
pass

@abstractmethod
@asynccontextmanager
async def _get_async_engine(self) -> AsyncGenerator[AsyncEngine, None]:
_connection_url = URL.create(
drivername=self.async_driver,
**self.connection_url,
)
_engine_args = {
"url": _connection_url,
**self.engine_args,
}
_engine = create_async_engine(**_engine_args)
yield _engine
await _engine.dispose()
pass

def _test_connection_sync(self, session: Session) -> None:
_connection_url_copy = self.connection_url.copy()
Expand Down Expand Up @@ -197,14 +180,14 @@ def __init__(self, session: Session) -> None:
def fetchall(self, stmt: Any, as_dict: bool = False) -> list[RowMapping] | list[dict]:
"""
Execute a SELECT statement and fetch all results.

Args:
stmt: SQLAlchemy statement or raw SQL string to execute
as_dict: If True, returns list of dicts; if False, returns list of RowMapping objects

Returns:
List of query results as either RowMapping objects or dictionaries

Raises:
DBFetchAllException: If query execution fails
"""
Expand All @@ -225,13 +208,13 @@ def fetchall(self, stmt: Any, as_dict: bool = False) -> list[RowMapping] | list[
def fetchvalue(self, stmt: Any) -> str | None:
"""
Execute a SELECT statement and fetch a single scalar value.

Args:
stmt: SQLAlchemy statement or raw SQL string to execute

Returns:
String representation of the first column of the first row, or None if no results

Raises:
DBFetchValueException: If query execution fails
"""
Expand All @@ -247,13 +230,13 @@ def fetchvalue(self, stmt: Any) -> str | None:
def insert(self, stmt: Any) -> Any:
"""
Insert a single record and return the inserted instance with updated fields.

Args:
stmt: SQLAlchemy model instance to insert

Returns:
The inserted model instance with refreshed data (including auto-generated IDs)

Raises:
DBInsertSingleException: If insert operation fails
"""
Expand All @@ -277,7 +260,7 @@ def insertbulk(self, model: type[T], list_data: Sequence[dict[str, Any]], batch_
model: The SQLAlchemy model class
list_data: List of dictionaries containing the data to insert
batch_size: Number of records to insert per batch (default: 1000)

Raises:
DBInsertBulkException: If bulk insert operation fails
"""
Expand All @@ -297,12 +280,12 @@ def insertbulk(self, model: type[T], list_data: Sequence[dict[str, Any]], batch_
def deleteall(self, model: type[T]) -> None:
"""
Delete all records from a table.

WARNING: This operation removes ALL data from the specified table.

Args:
model: The SQLAlchemy model class representing the table to clear

Raises:
DBDeleteAllDataException: If delete operation fails
"""
Expand All @@ -316,10 +299,10 @@ def deleteall(self, model: type[T]) -> None:
def execute(self, stmt: Any) -> None:
"""
Execute a statement that doesn't return results (INSERT, UPDATE, DELETE).

Args:
stmt: SQLAlchemy statement or raw SQL string to execute

Raises:
DBExecuteException: If statement execution fails
"""
Expand All @@ -340,14 +323,14 @@ def __init__(self, session: AsyncSession):
async def fetchall(self, stmt: Any, as_dict: bool = False) -> list[RowMapping] | list[dict]:
"""
Execute a SELECT statement asynchronously and fetch all results.

Args:
stmt: SQLAlchemy statement or raw SQL string to execute
as_dict: If True, returns list of dicts; if False, returns list of RowMapping objects

Returns:
List of query results as either RowMapping objects or dictionaries

Raises:
DBFetchAllException: If query execution fails
"""
Expand All @@ -368,13 +351,13 @@ async def fetchall(self, stmt: Any, as_dict: bool = False) -> list[RowMapping] |
async def fetchvalue(self, stmt) -> str | None:
"""
Execute a SELECT statement asynchronously and fetch a single scalar value.

Args:
stmt: SQLAlchemy statement or raw SQL string to execute

Returns:
String representation of the first column of the first row, or None if no results

Raises:
DBFetchValueException: If query execution fails
"""
Expand All @@ -390,13 +373,13 @@ async def fetchvalue(self, stmt) -> str | None:
async def insert(self, stmt: Any) -> Any:
"""
Insert a single record asynchronously and return the inserted instance with updated fields.

Args:
stmt: SQLAlchemy model instance to insert

Returns:
The inserted model instance with refreshed data (including auto-generated IDs)

Raises:
DBInsertSingleException: If insert operation fails
"""
Expand All @@ -420,7 +403,7 @@ async def insertbulk(self, model: type[T], list_data: Sequence[dict[str, Any]],
model: The SQLAlchemy model class
list_data: List of dictionaries containing the data to insert
batch_size: Number of records to insert per batch (default: 1000)

Raises:
DBInsertBulkException: If bulk insert operation fails
"""
Expand All @@ -442,12 +425,12 @@ async def insertbulk(self, model: type[T], list_data: Sequence[dict[str, Any]],
async def deleteall(self, model: type[T]) -> None:
"""
Delete all records from a table asynchronously.

WARNING: This operation removes ALL data from the specified table.

Args:
model: The SQLAlchemy model class representing the table to clear

Raises:
DBDeleteAllDataException: If delete operation fails
"""
Expand All @@ -462,10 +445,10 @@ async def deleteall(self, model: type[T]) -> None:
async def execute(self, stmt: Any) -> None:
"""
Execute a statement asynchronously that doesn't return results (INSERT, UPDATE, DELETE).

Args:
stmt: SQLAlchemy statement or raw SQL string to execute

Raises:
DBExecuteException: If statement execution fails
"""
Expand Down
104 changes: 92 additions & 12 deletions ddcDatabases/mongodb.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,62 @@
import logging
import sys
from dataclasses import dataclass
from typing import Optional, Type
from pymongo import ASCENDING, DESCENDING, MongoClient
from pymongo.cursor import Cursor
from pymongo.errors import PyMongoError
from .settings import get_mongodb_settings


@dataclass(slots=True, frozen=True)
class MongoConnectionConfig:
host: str | None = None
port: int | None = None
user: str | None = None
password: str | None = None
database: str | None = None
collection: str | None = None


@dataclass(slots=True, frozen=True)
class MongoQueryConfig:
query: dict | None = None
sort_column: str | None = None
sort_order: str | None = None
batch_size: int | None = None
limit: int | None = None


logger = logging.getLogger(__name__)
# Add NullHandler to prevent "No handlers found" warnings in libraries
logger.addHandler(logging.NullHandler())


class MongoDB:
"""
Class to handle MongoDB connections
Class to handle MongoDB connections.
"""

__slots__ = (
'host',
'port',
'user',
'password',
'database',
'collection',
'query',
'sort_column',
'sort_order',
'batch_size',
'limit',
'sync_driver',
'is_connected',
'client',
'cursor_ref',
'_connection_config',
'_query_config',
)

def __init__(
self,
host: str | None = None,
Expand All @@ -33,17 +73,36 @@ def __init__(
):
_settings = get_mongodb_settings()

self.host = host or _settings.host
self.port = port or _settings.port
self.user = user or _settings.user
self.password = password or _settings.password
self.database = database or _settings.database
self.collection = collection
self.query = query or {}
self.sort_column = sort_column
self.sort_order = sort_order
self.batch_size = batch_size or _settings.batch_size
self.limit = limit or _settings.limit
# Create configuration objects using dataclasses
self._connection_config = MongoConnectionConfig(
host=host or _settings.host,
port=port or _settings.port,
user=user or _settings.user,
password=password or _settings.password,
database=database or _settings.database,
collection=collection,
)

self._query_config = MongoQueryConfig(
query=query or {},
sort_column=sort_column,
sort_order=sort_order,
batch_size=batch_size or _settings.batch_size,
limit=limit or _settings.limit,
)

# Set instance attributes for backward compatibility
self.host = self._connection_config.host
self.port = self._connection_config.port
self.user = self._connection_config.user
self.password = self._connection_config.password
self.database = self._connection_config.database
self.collection = self._connection_config.collection
self.query = self._query_config.query
self.sort_column = self._query_config.sort_column
self.sort_order = self._query_config.sort_order
self.batch_size = self._query_config.batch_size
self.limit = self._query_config.limit
self.sync_driver = _settings.sync_driver
self.is_connected = False
self.client = None
Expand All @@ -52,6 +111,27 @@ def __init__(
if not self.collection:
raise ValueError("MongoDB collection name is required")

def __repr__(self) -> str:
"""String representation using configuration objects."""
return (
"MongoDB("
f"host={self._connection_config.host!r}, "
f"port={self._connection_config.port}, "
f"database={self._connection_config.database!r}, "
f"collection={self._connection_config.collection!r}, "
f"batch_size={self._query_config.batch_size}, "
f"limit={self._query_config.limit}"
")"
)

def get_connection_info(self) -> MongoConnectionConfig:
"""Get immutable connection configuration."""
return self._connection_config

def get_query_info(self) -> MongoQueryConfig:
"""Get immutable query configuration."""
return self._query_config

def __enter__(self) -> Cursor:
try:
_connection_url = f"{self.sync_driver}://{self.user}:{self.password}@{self.host}/{self.database}"
Expand Down
Loading
Loading