diff --git a/app/services/database.py b/app/services/database.py index 5d56827ce28fee34757d162ffec13244e0a17fe1..34575f7640b759505017bc51627afc1ed7236db9 100644 --- a/app/services/database.py +++ b/app/services/database.py @@ -25,29 +25,10 @@ async def ensure_custom_id_index_on_embedding(): pool = await PSQLDatabase.get_pool() async with pool.acquire() as conn: - # Check if the index exists - index_exists = await check_index_exists(conn, index_name) - - if not index_exists: - # If the index does not exist, create it - await conn.execute(f""" + await conn.execute(f""" CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_name}); """) - logger.debug(f"Created index '{index_name}' on '{table_name}({column_name})'") - else: - logger.debug(f"Index '{index_name}' already exists on '{table_name}({column_name})'") - - -async def check_index_exists(conn, index_name: str) -> bool: - # Adjust the SQL query if necessary - result = await conn.fetchval(""" - SELECT EXISTS ( - SELECT FROM pg_class c - JOIN pg_namespace n ON n.oid = c.relnamespace - WHERE c.relname = $1 AND n.nspname = 'public' -- Adjust schema if necessary - ); - """, index_name) - return result + logger.debug(f"Checking if index '{index_name}' on '{table_name}({column_name}) exists, if not found then the index is created.'") async def pg_health_check() -> bool: diff --git a/main.py b/main.py index 017ec79714df5284ef1153d35185e8c720907b81..0d9fb90c29c7d64c547ee5a0bfab94c0d9018be3 100644 --- a/main.py +++ b/main.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager from starlette.responses import JSONResponse -from app.config import debug_mode, RAG_HOST, RAG_PORT, CHUNK_SIZE, CHUNK_OVERLAP, PDF_EXTRACT_IMAGES, VECTOR_DB_TYPE, \ +from app.config import VectorDBType, debug_mode, RAG_HOST, RAG_PORT, CHUNK_SIZE, CHUNK_OVERLAP, PDF_EXTRACT_IMAGES, VECTOR_DB_TYPE, \ LogMiddleware, logger from app.middleware import security_middleware from app.routes import document_routes, pgvector_routes @@ -16,7 +16,7 @@ from app.services.database import PSQLDatabase, ensure_custom_id_index_on_embedd @asynccontextmanager async def lifespan(app: FastAPI): # Startup logic goes here - if VECTOR_DB_TYPE == "pgvector": + if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: await PSQLDatabase.get_pool() # Initialize the pool await ensure_custom_id_index_on_embedding() diff --git a/tests/services/test_database.py b/tests/services/test_database.py index 4240d4047edff3be2e756d476e0c72a3bbe2e630..33ea9bbfc4f5afe9e4432a9acff25a1b2c424764 100644 --- a/tests/services/test_database.py +++ b/tests/services/test_database.py @@ -34,9 +34,6 @@ def dummy_pool(monkeypatch): import asyncio @pytest.mark.asyncio async def test_ensure_custom_id_index_on_embedding(monkeypatch, dummy_pool): - async def dummy_check_index_exists(conn, index_name: str) -> bool: - return False - monkeypatch.setattr("app.services.database.check_index_exists", dummy_check_index_exists) result = await ensure_custom_id_index_on_embedding() # If no exceptions are raised, the function worked as expected. assert result is None \ No newline at end of file