diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 7b854fec34a39e..941ae6b330b156 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,6 +5,8 @@ from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union +from sqlalchemy.orm import sessionmaker + from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.exc import GenerateTaskStoppedError @@ -423,7 +425,9 @@ def _handle_multimodal_image_content( _logger.exception("Failed to save image file") return - # Create MessageFile record + # Create MessageFile record. + # Use an independent session so this side-effect write does not + # commit or close the caller's request-scoped session. message_file = MessageFile( message_id=message_id, type=FileType.IMAGE, @@ -437,9 +441,8 @@ def _handle_multimodal_image_content( created_by=user_id, ) - db.session.add(message_file) - db.session.commit() - db.session.refresh(message_file) + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + session.add(message_file) # Publish QueueMessageFileEvent queue_manager.publish( diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 5494769082e384..26dc1a12a2c5ee 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from sqlalchemy import select, update -from sqlalchemy.orm import scoped_session +from sqlalchemy.orm import Session, scoped_session, sessionmaker from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom @@ -10,6 +10,7 @@ from core.rag.entities import RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document +from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument from models.enums import CreatorUserRole, DatasetQuerySource @@ -46,47 +47,52 @@ def on_query(self, query: str, dataset_id: str, session: scoped_session): created_by=self._user_id, ) - session.add(dataset_query) - session.commit() + # Use an independent session so this audit-log side effect does + # not commit or close the caller's request-scoped session. + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as independent_session: + independent_session.add(dataset_query) def on_tool_end(self, documents: list[Document], session: scoped_session): """Handle tool end.""" - for document in documents: - if document.metadata is not None: - document_id = document.metadata["document_id"] - dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id) - dataset_document = session.scalar(dataset_document_stmt) - if not dataset_document: - _logger.warning( - "Expected DatasetDocument record to exist, but none was found, document_id=%s", - document_id, - ) - continue - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunk_stmt = select(ChildChunk).where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - child_chunk = session.scalar(child_chunk_stmt) - if child_chunk: - session.execute( - update(DocumentSegment) - .where(DocumentSegment.id == child_chunk.segment_id) - .values(hit_count=DocumentSegment.hit_count + 1) + # Use an independent session so hit-count updates do not + # interfere with the caller's request-scoped session. + with Session(db.engine, expire_on_commit=False) as independent_session: + for document in documents: + if document.metadata is not None: + document_id = document.metadata["document_id"] + dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id) + dataset_document = independent_session.scalar(dataset_document_stmt) + if not dataset_document: + _logger.warning( + "Expected DatasetDocument record to exist, but none was found, document_id=%s", + document_id, ) - else: - conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]] + continue + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + child_chunk = independent_session.scalar(child_chunk_stmt) + if child_chunk: + independent_session.execute( + update(DocumentSegment) + .where(DocumentSegment.id == child_chunk.segment_id) + .values(hit_count=DocumentSegment.hit_count + 1) + ) + else: + conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]] - if "dataset_id" in document.metadata: - conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + if "dataset_id" in document.metadata: + conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - session.execute( - update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) - ) + # add hit count to document segment + independent_session.execute( + update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) + ) - session.commit() + independent_session.commit() # TODO(-LAN-): Improve type check def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index b3ea1a464f8dfa..130264972a3423 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -5,7 +5,6 @@ import pytest -from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent @@ -81,59 +80,55 @@ def test_handle_multimodal_image_content_with_url( # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - # Verify tool file was created from URL - mock_mgr.create_file_by_url.assert_called_once_with( - user_id=mock_user_id, - tenant_id=mock_tenant_id, - file_url=image_url, - conversation_id=None, - ) - - # Verify message file was created with correct parameters - mock_msg_file_class.assert_called_once() - call_kwargs = mock_msg_file_class.call_args[1] - assert call_kwargs["message_id"] == mock_message_id - assert call_kwargs["type"] == FileType.IMAGE - assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE - assert call_kwargs["belongs_to"] == "assistant" - assert call_kwargs["created_by"] == mock_user_id - - # Verify database operations - mock_session.add.assert_called_once_with(mock_message_file) - mock_session.commit.assert_called_once() - mock_session.refresh.assert_called_once_with(mock_message_file) - - # Verify event was published - mock_queue_manager.publish.assert_called_once() - publish_call = mock_queue_manager.publish.call_args - assert isinstance(publish_call[0][0], QueueMessageFileEvent) - assert publish_call[0][0].message_file_id == mock_message_file.id - # publish_from might be passed as positional or keyword argument - assert ( - publish_call[0][1] == PublishFrom.APPLICATION_MANAGER - or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER - ) + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory) as mock_sm: + with patch("core.app.apps.base_app_runner.db") as mock_db: + # Act + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + mock_mgr.create_file_by_url.assert_called_once_with( + user_id=mock_user_id, + tenant_id=mock_tenant_id, + file_url=image_url, + conversation_id=None, + ) + + mock_msg_file_class.assert_called_once() + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["message_id"] == mock_message_id + assert call_kwargs["type"] == FileType.IMAGE + assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE + assert call_kwargs["belongs_to"] == "assistant" + assert call_kwargs["created_by"] == mock_user_id + + # Verify independent session was used (not db.session) + mock_sm.assert_called_once_with(bind=mock_db.engine, expire_on_commit=False) + file_session.add.assert_called_once_with(mock_message_file) + mock_db.session.commit.assert_not_called() + mock_db.session.close.assert_not_called() + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + publish_call = mock_queue_manager.publish.call_args + assert isinstance(publish_call[0][0], QueueMessageFileEvent) + assert publish_call[0][0].message_file_id == mock_message_file.id def test_handle_multimodal_image_content_with_base64( self, @@ -165,50 +160,44 @@ def test_handle_multimodal_image_content_with_base64( mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - # Verify tool file was created from base64 - mock_mgr.create_file_by_raw.assert_called_once() - call_kwargs = mock_mgr.create_file_by_raw.call_args[1] - assert call_kwargs["user_id"] == mock_user_id - assert call_kwargs["tenant_id"] == mock_tenant_id - assert call_kwargs["conversation_id"] is None - assert "file_binary" in call_kwargs - assert call_kwargs["mimetype"] == "image/png" - assert call_kwargs["filename"].startswith("generated_image") - assert call_kwargs["filename"].endswith(".png") - - # Verify message file was created - mock_msg_file_class.assert_called_once() - - # Verify database operations - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - mock_session.refresh.assert_called_once() - - # Verify event was published - mock_queue_manager.publish.assert_called_once() + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory): + with patch("core.app.apps.base_app_runner.db") as mock_db: + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + assert call_kwargs["user_id"] == mock_user_id + assert call_kwargs["tenant_id"] == mock_tenant_id + assert call_kwargs["conversation_id"] is None + assert "file_binary" in call_kwargs + assert call_kwargs["mimetype"] == "image/png" + assert call_kwargs["filename"].startswith("generated_image") + assert call_kwargs["filename"].endswith(".png") + + mock_msg_file_class.assert_called_once() + file_session.add.assert_called_once() + mock_db.session.commit.assert_not_called() + + mock_queue_manager.publish.assert_called_once() def test_handle_multimodal_image_content_with_base64_data_uri( self, @@ -238,33 +227,32 @@ def test_handle_multimodal_image_content_with_base64_data_uri( mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - verify that base64 data was extracted correctly (without prefix) - mock_mgr.create_file_by_raw.assert_called_once() - call_kwargs = mock_mgr.create_file_by_raw.call_args[1] - # The base64 data should be decoded, so we check the binary was passed - assert "file_binary" in call_kwargs + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory): + with patch("core.app.apps.base_app_runner.db"): + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + assert "file_binary" in call_kwargs def test_handle_multimodal_image_content_without_url_or_base64( self, @@ -284,9 +272,7 @@ def test_handle_multimodal_image_content_without_url_or_base64( with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - # Act - # Create a mock runner with the method bound + with patch("core.app.apps.base_app_runner.db"): runner = MagicMock() method = AppRunner._handle_multimodal_image_content runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) @@ -299,10 +285,8 @@ def test_handle_multimodal_image_content_without_url_or_base64( queue_manager=mock_queue_manager, ) - # Assert - should not create any files or publish events mock_mgr_class.assert_not_called() mock_msg_file_class.assert_not_called() - mock_session.add.assert_not_called() mock_queue_manager.publish.assert_not_called() def test_handle_multimodal_image_content_with_error( @@ -322,20 +306,16 @@ def test_handle_multimodal_image_content_with_error( ) with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: - # Setup mock to raise exception mock_mgr = MagicMock() mock_mgr.create_file_by_url.side_effect = Exception("Network error") mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - # Act - # Create a mock runner with the method bound + with patch("core.app.apps.base_app_runner.db"): runner = MagicMock() method = AppRunner._handle_multimodal_image_content runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - # Should not raise exception, just log it runner._handle_multimodal_image_content( content=content, message_id=mock_message_id, @@ -344,9 +324,7 @@ def test_handle_multimodal_image_content_with_error( queue_manager=mock_queue_manager, ) - # Assert - should not create message file or publish event on error mock_msg_file_class.assert_not_called() - mock_session.add.assert_not_called() mock_queue_manager.publish.assert_not_called() def test_handle_multimodal_image_content_debugger_mode( @@ -369,37 +347,36 @@ def test_handle_multimodal_image_content_debugger_mode( mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: - # Setup mock tool file manager mock_mgr = MagicMock() mock_mgr.create_file_by_url.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - verify created_by_role is ACCOUNT for debugger mode - call_kwargs = mock_msg_file_class.call_args[1] - assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory): + with patch("core.app.apps.base_app_runner.db"): + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT def test_handle_multimodal_image_content_service_api_mode( self, @@ -421,34 +398,33 @@ def test_handle_multimodal_image_content_service_api_mode( mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: - # Setup mock tool file manager mock_mgr = MagicMock() mock_mgr.create_file_by_url.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - verify created_by_role is END_USER for service API - call_kwargs = mock_msg_file_class.call_args[1] - assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory): + with patch("core.app.apps.base_app_runner.db"): + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py index 4912badfc553e9..62c4ae9d411de0 100644 --- a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -14,6 +14,9 @@ def mock_queue_manager(mocker: MockerFixture): @pytest.fixture def handler(mock_queue_manager, mocker: MockerFixture): + mocker.patch( + "core.callback_handler.index_tool_callback_handler.db", + ) return DatasetIndexToolCallbackHandler( queue_manager=mock_queue_manager, app_id="app-1", @@ -33,8 +36,18 @@ class TestOnQuery: ], ) def test_on_query_success_roles(self, mocker: MockerFixture, mock_queue_manager, invoke_from, expected_role): - # Arrange - mock_session = mocker.Mock() + # Arrange — the caller passes a session, but our fix uses an independent one + caller_session = mocker.Mock() + + independent_session = mocker.MagicMock() + mock_session_factory = mocker.MagicMock() + mock_session_factory.begin.return_value.__enter__ = mocker.MagicMock(return_value=independent_session) + mock_session_factory.begin.return_value.__exit__ = mocker.MagicMock(return_value=False) + mocker.patch( + "core.callback_handler.index_tool_callback_handler.sessionmaker", + return_value=mock_session_factory, + ) + mocker.patch("core.callback_handler.index_tool_callback_handler.db") handler = DatasetIndexToolCallbackHandler( queue_manager=mock_queue_manager, @@ -46,17 +59,28 @@ def test_on_query_success_roles(self, mocker: MockerFixture, mock_queue_manager, handler._invoke_from = invoke_from - # Act - handler.on_query("test query", "dataset-1", mock_session) + # Act — pass caller_session as required by signature + handler.on_query("test query", "dataset-1", caller_session) - # Assert - mock_session.add.assert_called_once() - dataset_query = mock_session.add.call_args.args[0] + # Assert — independent session used, not the caller's session + independent_session.add.assert_called_once() + dataset_query = independent_session.add.call_args.args[0] assert dataset_query.created_by_role == expected_role - mock_session.commit.assert_called_once() + caller_session.add.assert_not_called() + caller_session.commit.assert_not_called() def test_on_query_none_values(self, mocker: MockerFixture, mock_queue_manager): - mock_session = mocker.Mock() + caller_session = mocker.Mock() + + independent_session = mocker.MagicMock() + mock_session_factory = mocker.MagicMock() + mock_session_factory.begin.return_value.__enter__ = mocker.MagicMock(return_value=independent_session) + mock_session_factory.begin.return_value.__exit__ = mocker.MagicMock(return_value=False) + mocker.patch( + "core.callback_handler.index_tool_callback_handler.sessionmaker", + return_value=mock_session_factory, + ) + mocker.patch("core.callback_handler.index_tool_callback_handler.db") handler = DatasetIndexToolCallbackHandler( queue_manager=mock_queue_manager, @@ -66,40 +90,67 @@ def test_on_query_none_values(self, mocker: MockerFixture, mock_queue_manager): invoke_from=None, ) - handler.on_query(None, None, mock_session) + handler.on_query(None, None, caller_session) - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() + independent_session.add.assert_called_once() + caller_session.add.assert_not_called() class TestOnToolEnd: def test_on_tool_end_no_metadata(self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture): - mock_session = mocker.Mock() + caller_session = mocker.Mock() + + independent_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=independent_session, + ) + independent_session.__enter__ = mocker.MagicMock(return_value=independent_session) + independent_session.__exit__ = mocker.MagicMock(return_value=False) document = mocker.Mock() document.metadata = None - handler.on_tool_end([document], mock_session) + handler.on_tool_end([document], caller_session) - mock_session.commit.assert_not_called() + independent_session.commit.assert_called_once() + independent_session.execute.assert_not_called() + caller_session.commit.assert_not_called() def test_on_tool_end_dataset_document_not_found( self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture ): - mock_session = mocker.Mock() - mock_session.scalar.return_value = None + caller_session = mocker.Mock() + + independent_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=independent_session, + ) + independent_session.__enter__ = mocker.MagicMock(return_value=independent_session) + independent_session.__exit__ = mocker.MagicMock(return_value=False) + independent_session.scalar.return_value = None document = mocker.Mock() document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} - handler.on_tool_end([document], mock_session) + handler.on_tool_end([document], caller_session) - mock_session.scalar.assert_called_once() + independent_session.scalar.assert_called_once() + caller_session.scalar.assert_not_called() def test_on_tool_end_parent_child_index_with_child( self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture ): - mock_session = mocker.Mock() + caller_session = mocker.Mock() + + independent_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=independent_session, + ) + independent_session.__enter__ = mocker.MagicMock(return_value=independent_session) + independent_session.__exit__ = mocker.MagicMock(return_value=False) mock_dataset_doc = mocker.Mock() from core.callback_handler.index_tool_callback_handler import IndexStructureType @@ -111,23 +162,32 @@ def test_on_tool_end_parent_child_index_with_child( mock_child_chunk = mocker.Mock() mock_child_chunk.segment_id = "segment-1" - mock_session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk] + independent_session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk] document = mocker.Mock() document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} - handler.on_tool_end([document], mock_session) + handler.on_tool_end([document], caller_session) - mock_session.execute.assert_called_once() - mock_session.commit.assert_called_once() + independent_session.execute.assert_called_once() + independent_session.commit.assert_called_once() + caller_session.execute.assert_not_called() def test_on_tool_end_non_parent_child_index(self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture): - mock_session = mocker.Mock() + caller_session = mocker.Mock() + + independent_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=independent_session, + ) + independent_session.__enter__ = mocker.MagicMock(return_value=independent_session) + independent_session.__exit__ = mocker.MagicMock(return_value=False) mock_dataset_doc = mocker.Mock() mock_dataset_doc.doc_form = "OTHER" - mock_session.scalar.return_value = mock_dataset_doc + independent_session.scalar.return_value = mock_dataset_doc document = mocker.Mock() document.metadata = { @@ -136,14 +196,24 @@ def test_on_tool_end_non_parent_child_index(self, handler: DatasetIndexToolCallb "dataset_id": "dataset-1", } - handler.on_tool_end([document], mock_session) + handler.on_tool_end([document], caller_session) - mock_session.execute.assert_called_once() - mock_session.commit.assert_called_once() + independent_session.execute.assert_called_once() + independent_session.commit.assert_called_once() + caller_session.execute.assert_not_called() def test_on_tool_end_empty_documents(self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture): - mock_session = mocker.Mock() - handler.on_tool_end([], mock_session) + caller_session = mocker.Mock() + + independent_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=independent_session, + ) + independent_session.__enter__ = mocker.MagicMock(return_value=independent_session) + independent_session.__exit__ = mocker.MagicMock(return_value=False) + + handler.on_tool_end([], caller_session) class TestReturnRetrieverResourceInfo: