Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 53 additions & 28 deletions hindsight-api-slim/hindsight_api/_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import logging
import os
import re
from collections.abc import Collection

from sqlalchemy import text
from sqlalchemy.engine import Connection
Expand All @@ -22,6 +24,7 @@
# The set of valid extension names is the same object as the configurable set.
VALID_EXTENSIONS = CONFIGURABLE_EXTENSIONS

# NOTE(review): presumably the row count below which ScaNN index creation is deferred;
# the code that reads this constant is not visible here — confirm against should_defer_index_creation.
SCANN_MIN_ROWS_FOR_AUTO_INDEX = 10_000
# Largest embedding dimension accepted for the "pgvector" backend;
# validate_vector_index_dimension raises RuntimeError for anything greater.
PGVECTOR_HNSW_MAX_DIMENSIONS = 2_000


_EXTENSION_NAMES = {
Expand All @@ -31,12 +34,12 @@
"scann": "alloydb_scann",
}

_INDEX_USING_CLAUSES = {
"pgvector": "USING hnsw (embedding vector_cosine_ops)",
"pgvectorscale": "USING diskann (embedding vector_cosine_ops) WITH (num_neighbors = 50)",
"pg_diskann": "USING diskann (embedding vector_cosine_ops) WITH (max_neighbors = 50)",
"vchord": "USING vchordrq (embedding vector_cosine_ops)",
"scann": "USING scann (embedding cosine) WITH (mode = 'AUTO')",
_INDEX_USING_TEMPLATES = {
"pgvector": "USING hnsw ({column} vector_cosine_ops)",
"pgvectorscale": "USING diskann ({column} vector_cosine_ops) WITH (num_neighbors = 50)",
"pg_diskann": "USING diskann ({column} vector_cosine_ops) WITH (max_neighbors = 50)",
"vchord": "USING vchordrq ({column} vector_cosine_ops)",
"scann": "USING scann ({column} cosine) WITH (mode = 'AUTO')",
}

_INDEX_TYPE_KEYWORDS = {
Expand Down Expand Up @@ -130,9 +133,11 @@ def pg_extension_name(ext: str) -> str:
return _EXTENSION_NAMES[validate_extension(ext)]


def index_using_clause(ext: str, *, column: str = "embedding") -> str:
    """Return the CREATE INDEX USING clause for the vector backend.

    Args:
        ext: Vector backend name; passed through ``_normalize_resolved`` before lookup.
        column: Name of the vector column to index. It is spliced directly into the
            SQL clause, so only a plain ASCII identifier is accepted (letters, digits
            and underscores, not starting with a digit).

    Returns:
        The ``USING ...`` clause for the backend with ``column`` substituted.

    Raises:
        ValueError: If ``column`` is not a plain identifier.
    """
    # Validate before any lookup: the name ends up in SQL text, not in a bound parameter.
    if not re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", column):
        raise ValueError(f"Invalid vector index column name: {column!r}")
    return _INDEX_USING_TEMPLATES[_normalize_resolved(ext)].format(column=column)


def index_type_keyword(ext: str) -> str:
Expand All @@ -151,6 +156,18 @@ def should_defer_index_creation(ext: str, row_count: int) -> bool:
return minimum_rows > 0 and row_count < minimum_rows


def validate_vector_index_dimension(ext: str, dimension: int, *, table_name: str | None = None) -> None:
    """Raise if the backend cannot index vectors of the requested dimension.

    Only the "pgvector" backend is checked; every other backend passes unconditionally.

    Args:
        ext: Vector backend name; passed through ``_normalize_resolved`` before comparison.
        dimension: Embedding dimension that is about to be indexed.
        table_name: Optional table name, included in the error message when truthy.

    Raises:
        RuntimeError: If the backend is pgvector and ``dimension`` is greater than
            ``PGVECTOR_HNSW_MAX_DIMENSIONS``.
    """
    if _normalize_resolved(ext) != "pgvector":
        return
    if not dimension > PGVECTOR_HNSW_MAX_DIMENSIONS:
        return

    # Name the offending table only when the caller supplied one.
    if table_name:
        location = f" on {table_name}"
    else:
        location = ""
    raise RuntimeError(
        f"Embedding dimension {dimension}{location} exceeds pgvector HNSW index limit of "
        f"{PGVECTOR_HNSW_MAX_DIMENSIONS}. Use an embedding model with <= "
        f"{PGVECTOR_HNSW_MAX_DIMENSIONS} dimensions, or switch to a vector extension that supports higher "
        "dimensions (e.g., pgvectorscale/DiskANN or AlloyDB ScaNN)."
    )


def ann_search_tuning_settings(ext: str, *, kind: str) -> tuple[tuple[str, str], ...]:
"""Return per-backend (guc_name, value) pairs for ANN search-time tuning.

Expand Down Expand Up @@ -182,44 +199,52 @@ def bootstrap_extension(conn: Connection, ext: str) -> None:
conn.execute(text(statement))


def resolve_vector_extension_from_installed(vector_extension: str, installed_extensions: Collection[str]) -> str:
    """Return the resolved backend from the configured value and installed PG extensions.

    This function performs no database access: the caller supplies the names of the
    PostgreSQL extensions that are installed.

    Args:
        vector_extension: Configured backend name; checked with ``validate_extension``.
        installed_extensions: PostgreSQL extension names (``pg_extension.extname`` values)
            known to be installed.

    Returns:
        The backend to use. For a configured value of "pgvectorscale" this is either
        "pgvectorscale" or "pg_diskann", depending on which extension is installed;
        otherwise it is the validated configured value.

    Raises:
        RuntimeError: If a required extension is missing from ``installed_extensions``.
    """
    configured_ext = validate_extension(vector_extension)
    installed = set(installed_extensions)

    if configured_ext == "pgvectorscale":
        # Both DiskANN providers need the base pgvector extension.
        if "vector" not in installed:
            raise RuntimeError(
                "DiskANN (pgvectorscale/pg_diskann) requires pgvector to be installed. "
                f"Install it with: {_INSTALL_HINTS['pgvectorscale']}"
            )

        # vectorscale wins when both providers are present.
        if "vectorscale" in installed:
            return "pgvectorscale"
        if "pg_diskann" in installed:
            return "pg_diskann"

        raise RuntimeError(
            "Configured vector extension 'pgvectorscale' not found. Install either:\n"
            " - pgvectorscale (open source): CREATE EXTENSION vectorscale CASCADE;\n"
            " - pg_diskann (Azure): CREATE EXTENSION pg_diskann CASCADE;"
        )

    extension_name = pg_extension_name(configured_ext)
    if extension_name not in installed:
        raise RuntimeError(
            f"Configured vector extension '{configured_ext}' not found. "
            f"Install it with: {_INSTALL_HINTS[configured_ext]}"
        )

    return configured_ext


def detect_vector_extension(conn: Connection, vector_extension: str = "pgvector") -> str:
    """Validate the configured vector extension exists and return the index backend.

    Queries ``pg_extension`` once per candidate extension name, then hands the set of
    names that were found to ``resolve_vector_extension_from_installed``.

    Args:
        conn: Connection used to run the ``pg_extension`` lookups.
        vector_extension: Configured backend name; checked with ``validate_extension``.

    Returns:
        The resolved backend name, which is also logged at debug level.
    """
    configured_ext = validate_extension(vector_extension)

    # "pgvectorscale" can resolve to either DiskANN provider, so probe both plus pgvector.
    candidates = (
        ("vector", "vectorscale", "pg_diskann")
        if configured_ext == "pgvectorscale"
        else (pg_extension_name(configured_ext),)
    )

    installed_extensions = set()
    for candidate in candidates:
        found = conn.execute(
            text("SELECT 1 FROM pg_extension WHERE extname = :extension_name"),
            {"extension_name": candidate},
        ).scalar()
        if found:
            installed_extensions.add(candidate)

    resolved = resolve_vector_extension_from_installed(configured_ext, installed_extensions)
    logger.debug("Using vector extension: %s", resolved)
    return resolved
Loading