From 85b58d9c70b77d0c2e2717474cefb2b9ec6945a8 Mon Sep 17 00:00:00 2001 From: Pranav Agarwal Date: Tue, 30 Jun 2026 14:57:52 +0530 Subject: [PATCH 1/2] fix(api): isolate side-effect session writes in multimodal and RAG handlers Extends the session isolation fix from #37895 to two files that were missed: base_app_runner.py and index_tool_callback_handler.py. Both files used db.session.commit() on the shared Flask request-scoped session to persist side-effect writes (a MessageFile during multimodal LLM streaming, and DatasetQuery audit logs / hit_count updates during RAG retrieval). This flushed all pending ORM changes in the request transaction, not just the intended record, risking premature commits and DetachedInstanceError under concurrency. Both now use independent sessionmaker sessions matching the established pattern from #37895. Fixes #38176 --- api/core/app/apps/base_app_runner.py | 11 +- .../index_tool_callback_handler.py | 74 ++-- .../chat/test_base_app_runner_multimodal.py | 338 ++++++++---------- .../test_index_tool_callback_handler.py | 92 +++-- 4 files changed, 275 insertions(+), 240 deletions(-) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 7b854fec34a39e..941ae6b330b156 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,6 +5,8 @@ from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union +from sqlalchemy.orm import sessionmaker + from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.exc import GenerateTaskStoppedError @@ -423,7 +425,9 @@ def _handle_multimodal_image_content( _logger.exception("Failed to save image file") return - # Create MessageFile record + # Create MessageFile record. + # Use an independent session so this side-effect write does not + # commit or close the caller's request-scoped session. message_file = MessageFile( message_id=message_id, type=FileType.IMAGE, @@ -437,9 +441,8 @@ def _handle_multimodal_image_content( created_by=user_id, ) - db.session.add(message_file) - db.session.commit() - db.session.refresh(message_file) + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + session.add(message_file) # Publish QueueMessageFileEvent queue_manager.publish( diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 205e004290138e..c8d241976355a2 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from sqlalchemy import select, update +from sqlalchemy.orm import Session, sessionmaker from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom @@ -46,47 +47,52 @@ def on_query(self, query: str, dataset_id: str): created_by=self._user_id, ) - db.session.add(dataset_query) - db.session.commit() + # Use an independent session so this audit-log side effect does + # not commit or close the caller's request-scoped session. + with sessionmaker(bind=db.engine, expire_on_commit=False).begin() as session: + session.add(dataset_query) def on_tool_end(self, documents: list[Document]): """Handle tool end.""" - for document in documents: - if document.metadata is not None: - document_id = document.metadata["document_id"] - dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id) - dataset_document = db.session.scalar(dataset_document_stmt) - if not dataset_document: - _logger.warning( - "Expected DatasetDocument record to exist, but none was found, document_id=%s", - document_id, - ) - continue - if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: - child_chunk_stmt = select(ChildChunk).where( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ) - child_chunk = db.session.scalar(child_chunk_stmt) - if child_chunk: - db.session.execute( - update(DocumentSegment) - .where(DocumentSegment.id == child_chunk.segment_id) - .values(hit_count=DocumentSegment.hit_count + 1) + # Use an independent session so hit-count updates do not + # interfere with the caller's request-scoped session. + with Session(db.engine, expire_on_commit=False) as session: + for document in documents: + if document.metadata is not None: + document_id = document.metadata["document_id"] + dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id) + dataset_document = session.scalar(dataset_document_stmt) + if not dataset_document: + _logger.warning( + "Expected DatasetDocument record to exist, but none was found, document_id=%s", + document_id, ) - else: - conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]] + continue + if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: + child_chunk_stmt = select(ChildChunk).where( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + child_chunk = session.scalar(child_chunk_stmt) + if child_chunk: + session.execute( + update(DocumentSegment) + .where(DocumentSegment.id == child_chunk.segment_id) + .values(hit_count=DocumentSegment.hit_count + 1) + ) + else: + conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]] - if "dataset_id" in document.metadata: - conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + if "dataset_id" in document.metadata: + conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"]) - # add hit count to document segment - db.session.execute( - update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) - ) + # add hit count to document segment + session.execute( + update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) + ) - db.session.commit() + session.commit() # TODO(-LAN-): Improve type check def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index b3ea1a464f8dfa..130264972a3423 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -5,7 +5,6 @@ import pytest -from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent @@ -81,59 +80,55 @@ def test_handle_multimodal_image_content_with_url( # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - # Verify tool file was created from URL - mock_mgr.create_file_by_url.assert_called_once_with( - user_id=mock_user_id, - tenant_id=mock_tenant_id, - file_url=image_url, - conversation_id=None, - ) - - # Verify message file was created with correct parameters - mock_msg_file_class.assert_called_once() - call_kwargs = mock_msg_file_class.call_args[1] - assert call_kwargs["message_id"] == mock_message_id - assert call_kwargs["type"] == FileType.IMAGE - assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE - assert call_kwargs["belongs_to"] == "assistant" - assert call_kwargs["created_by"] == mock_user_id - - # Verify database operations - mock_session.add.assert_called_once_with(mock_message_file) - mock_session.commit.assert_called_once() - mock_session.refresh.assert_called_once_with(mock_message_file) - - # Verify event was published - mock_queue_manager.publish.assert_called_once() - publish_call = mock_queue_manager.publish.call_args - assert isinstance(publish_call[0][0], QueueMessageFileEvent) - assert publish_call[0][0].message_file_id == mock_message_file.id - # publish_from might be passed as positional or keyword argument - assert ( - publish_call[0][1] == PublishFrom.APPLICATION_MANAGER - or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER - ) + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory) as mock_sm: + with patch("core.app.apps.base_app_runner.db") as mock_db: + # Act + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + # Assert + mock_mgr.create_file_by_url.assert_called_once_with( + user_id=mock_user_id, + tenant_id=mock_tenant_id, + file_url=image_url, + conversation_id=None, + ) + + mock_msg_file_class.assert_called_once() + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["message_id"] == mock_message_id + assert call_kwargs["type"] == FileType.IMAGE + assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE + assert call_kwargs["belongs_to"] == "assistant" + assert call_kwargs["created_by"] == mock_user_id + + # Verify independent session was used (not db.session) + mock_sm.assert_called_once_with(bind=mock_db.engine, expire_on_commit=False) + file_session.add.assert_called_once_with(mock_message_file) + mock_db.session.commit.assert_not_called() + mock_db.session.close.assert_not_called() + + # Verify event was published + mock_queue_manager.publish.assert_called_once() + publish_call = mock_queue_manager.publish.call_args + assert isinstance(publish_call[0][0], QueueMessageFileEvent) + assert publish_call[0][0].message_file_id == mock_message_file.id def test_handle_multimodal_image_content_with_base64( self, @@ -165,50 +160,44 @@ def test_handle_multimodal_image_content_with_base64( mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - # Verify tool file was created from base64 - mock_mgr.create_file_by_raw.assert_called_once() - call_kwargs = mock_mgr.create_file_by_raw.call_args[1] - assert call_kwargs["user_id"] == mock_user_id - assert call_kwargs["tenant_id"] == mock_tenant_id - assert call_kwargs["conversation_id"] is None - assert "file_binary" in call_kwargs - assert call_kwargs["mimetype"] == "image/png" - assert call_kwargs["filename"].startswith("generated_image") - assert call_kwargs["filename"].endswith(".png") - - # Verify message file was created - mock_msg_file_class.assert_called_once() - - # Verify database operations - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - mock_session.refresh.assert_called_once() - - # Verify event was published - mock_queue_manager.publish.assert_called_once() + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory): + with patch("core.app.apps.base_app_runner.db") as mock_db: + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + assert call_kwargs["user_id"] == mock_user_id + assert call_kwargs["tenant_id"] == mock_tenant_id + assert call_kwargs["conversation_id"] is None + assert "file_binary" in call_kwargs + assert call_kwargs["mimetype"] == "image/png" + assert call_kwargs["filename"].startswith("generated_image") + assert call_kwargs["filename"].endswith(".png") + + mock_msg_file_class.assert_called_once() + file_session.add.assert_called_once() + mock_db.session.commit.assert_not_called() + + mock_queue_manager.publish.assert_called_once() def test_handle_multimodal_image_content_with_base64_data_uri( self, @@ -238,33 +227,32 @@ def test_handle_multimodal_image_content_with_base64_data_uri( mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - verify that base64 data was extracted correctly (without prefix) - mock_mgr.create_file_by_raw.assert_called_once() - call_kwargs = mock_mgr.create_file_by_raw.call_args[1] - # The base64 data should be decoded, so we check the binary was passed - assert "file_binary" in call_kwargs + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory): + with patch("core.app.apps.base_app_runner.db"): + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + mock_mgr.create_file_by_raw.assert_called_once() + call_kwargs = mock_mgr.create_file_by_raw.call_args[1] + assert "file_binary" in call_kwargs def test_handle_multimodal_image_content_without_url_or_base64( self, @@ -284,9 +272,7 @@ def test_handle_multimodal_image_content_without_url_or_base64( with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - # Act - # Create a mock runner with the method bound + with patch("core.app.apps.base_app_runner.db"): runner = MagicMock() method = AppRunner._handle_multimodal_image_content runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) @@ -299,10 +285,8 @@ def test_handle_multimodal_image_content_without_url_or_base64( queue_manager=mock_queue_manager, ) - # Assert - should not create any files or publish events mock_mgr_class.assert_not_called() mock_msg_file_class.assert_not_called() - mock_session.add.assert_not_called() mock_queue_manager.publish.assert_not_called() def test_handle_multimodal_image_content_with_error( @@ -322,20 +306,16 @@ def test_handle_multimodal_image_content_with_error( ) with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: - # Setup mock to raise exception mock_mgr = MagicMock() mock_mgr.create_file_by_url.side_effect = Exception("Network error") mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - # Act - # Create a mock runner with the method bound + with patch("core.app.apps.base_app_runner.db"): runner = MagicMock() method = AppRunner._handle_multimodal_image_content runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - # Should not raise exception, just log it runner._handle_multimodal_image_content( content=content, message_id=mock_message_id, @@ -344,9 +324,7 @@ def test_handle_multimodal_image_content_with_error( queue_manager=mock_queue_manager, ) - # Assert - should not create message file or publish event on error mock_msg_file_class.assert_not_called() - mock_session.add.assert_not_called() mock_queue_manager.publish.assert_not_called() def test_handle_multimodal_image_content_debugger_mode( @@ -369,37 +347,36 @@ def test_handle_multimodal_image_content_debugger_mode( mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: - # Setup mock tool file manager mock_mgr = MagicMock() mock_mgr.create_file_by_url.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - verify created_by_role is ACCOUNT for debugger mode - call_kwargs = mock_msg_file_class.call_args[1] - assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory): + with patch("core.app.apps.base_app_runner.db"): + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT def test_handle_multimodal_image_content_service_api_mode( self, @@ -421,34 +398,33 @@ def test_handle_multimodal_image_content_service_api_mode( mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class: - # Setup mock tool file manager mock_mgr = MagicMock() mock_mgr.create_file_by_url.return_value = mock_tool_file mock_mgr_class.return_value = mock_mgr with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class: - # Setup mock message file mock_msg_file_class.return_value = mock_message_file - with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session: - mock_session.add = MagicMock() - mock_session.commit = MagicMock() - mock_session.refresh = MagicMock() - - # Act - # Create a mock runner with the method bound - runner = MagicMock() - method = AppRunner._handle_multimodal_image_content - runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs) - - runner._handle_multimodal_image_content( - content=content, - message_id=mock_message_id, - user_id=mock_user_id, - tenant_id=mock_tenant_id, - queue_manager=mock_queue_manager, - ) - - # Assert - verify created_by_role is END_USER for service API - call_kwargs = mock_msg_file_class.call_args[1] - assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER + file_session = MagicMock() + mock_session_factory = MagicMock() + mock_session_factory.begin.return_value.__enter__ = MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = MagicMock(return_value=False) + + with patch("core.app.apps.base_app_runner.sessionmaker", return_value=mock_session_factory): + with patch("core.app.apps.base_app_runner.db"): + runner = MagicMock() + method = AppRunner._handle_multimodal_image_content + runner._handle_multimodal_image_content = lambda *args, **kwargs: method( + runner, *args, **kwargs + ) + + runner._handle_multimodal_image_content( + content=content, + message_id=mock_message_id, + user_id=mock_user_id, + tenant_id=mock_tenant_id, + queue_manager=mock_queue_manager, + ) + + call_kwargs = mock_msg_file_class.call_args[1] + assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py index f23669c3c7b30f..ae78dcd7f82c01 100644 --- a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -39,6 +39,15 @@ def test_on_query_success_roles(self, mocker: MockerFixture, mock_queue_manager, # Arrange mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + file_session = mocker.MagicMock() + mock_session_factory = mocker.MagicMock() + mock_session_factory.begin.return_value.__enter__ = mocker.MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = mocker.MagicMock(return_value=False) + mocker.patch( + "core.callback_handler.index_tool_callback_handler.sessionmaker", + return_value=mock_session_factory, + ) + handler = DatasetIndexToolCallbackHandler( queue_manager=mock_queue_manager, app_id="app-1", @@ -52,14 +61,23 @@ def test_on_query_success_roles(self, mocker: MockerFixture, mock_queue_manager, # Act handler.on_query("test query", "dataset-1") - # Assert - mock_db.session.add.assert_called_once() - dataset_query = mock_db.session.add.call_args.args[0] + # Assert — independent session used, not db.session + file_session.add.assert_called_once() + dataset_query = file_session.add.call_args.args[0] assert dataset_query.created_by_role == expected_role - mock_db.session.commit.assert_called_once() + mock_db.session.commit.assert_not_called() def test_on_query_none_values(self, mocker: MockerFixture, mock_queue_manager): - mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + file_session = mocker.MagicMock() + mock_session_factory = mocker.MagicMock() + mock_session_factory.begin.return_value.__enter__ = mocker.MagicMock(return_value=file_session) + mock_session_factory.begin.return_value.__exit__ = mocker.MagicMock(return_value=False) + mocker.patch( + "core.callback_handler.index_tool_callback_handler.sessionmaker", + return_value=mock_session_factory, + ) handler = DatasetIndexToolCallbackHandler( queue_manager=mock_queue_manager, @@ -71,38 +89,56 @@ def test_on_query_none_values(self, mocker: MockerFixture, mock_queue_manager): handler.on_query(None, None) - mock_db.session.add.assert_called_once() - mock_db.session.commit.assert_called_once() + file_session.add.assert_called_once() class TestOnToolEnd: def test_on_tool_end_no_metadata(self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture): - mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + mock_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=mock_session, + ) + mock_session.__enter__ = mocker.MagicMock(return_value=mock_session) + mock_session.__exit__ = mocker.MagicMock(return_value=False) document = mocker.Mock() document.metadata = None handler.on_tool_end([document]) - mock_db.session.commit.assert_not_called() + mock_session.commit.assert_called_once() + mock_session.execute.assert_not_called() def test_on_tool_end_dataset_document_not_found( self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture ): - mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") - mock_db.session.scalar.return_value = None + mock_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=mock_session, + ) + mock_session.__enter__ = mocker.MagicMock(return_value=mock_session) + mock_session.__exit__ = mocker.MagicMock(return_value=False) + mock_session.scalar.return_value = None document = mocker.Mock() document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} handler.on_tool_end([document]) - mock_db.session.scalar.assert_called_once() + mock_session.scalar.assert_called_once() def test_on_tool_end_parent_child_index_with_child( self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture ): - mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + mock_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=mock_session, + ) + mock_session.__enter__ = mocker.MagicMock(return_value=mock_session) + mock_session.__exit__ = mocker.MagicMock(return_value=False) mock_dataset_doc = mocker.Mock() from core.callback_handler.index_tool_callback_handler import IndexStructureType @@ -114,23 +150,29 @@ def test_on_tool_end_parent_child_index_with_child( mock_child_chunk = mocker.Mock() mock_child_chunk.segment_id = "segment-1" - mock_db.session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk] + mock_session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk] document = mocker.Mock() document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} handler.on_tool_end([document]) - mock_db.session.execute.assert_called_once() - mock_db.session.commit.assert_called_once() + mock_session.execute.assert_called_once() + mock_session.commit.assert_called_once() def test_on_tool_end_non_parent_child_index(self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture): - mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + mock_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=mock_session, + ) + mock_session.__enter__ = mocker.MagicMock(return_value=mock_session) + mock_session.__exit__ = mocker.MagicMock(return_value=False) mock_dataset_doc = mocker.Mock() mock_dataset_doc.doc_form = "OTHER" - mock_db.session.scalar.return_value = mock_dataset_doc + mock_session.scalar.return_value = mock_dataset_doc document = mocker.Mock() document.metadata = { @@ -141,10 +183,18 @@ def test_on_tool_end_non_parent_child_index(self, handler: DatasetIndexToolCallb handler.on_tool_end([document]) - mock_db.session.execute.assert_called_once() - mock_db.session.commit.assert_called_once() + mock_session.execute.assert_called_once() + mock_session.commit.assert_called_once() + + def test_on_tool_end_empty_documents(self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture): + mock_session = mocker.MagicMock() + mocker.patch( + "core.callback_handler.index_tool_callback_handler.Session", + return_value=mock_session, + ) + mock_session.__enter__ = mocker.MagicMock(return_value=mock_session) + mock_session.__exit__ = mocker.MagicMock(return_value=False) - def test_on_tool_end_empty_documents(self, handler: DatasetIndexToolCallbackHandler): handler.on_tool_end([]) From 21d301e486a13bba439a5891fa9e01b533d903e5 Mon Sep 17 00:00:00 2001 From: Pranav Agarwal Date: Fri, 3 Jul 2026 00:08:45 +0530 Subject: [PATCH 2/2] Merge latest origin/main into fix/isolate-side-effect-session-writes --- .../core/callback_handler/test_index_tool_callback_handler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py index c71f4aa6691348..62c4ae9d411de0 100644 --- a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -173,9 +173,7 @@ def test_on_tool_end_parent_child_index_with_child( independent_session.commit.assert_called_once() caller_session.execute.assert_not_called() - def test_on_tool_end_non_parent_child_index( - self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture - ): + def test_on_tool_end_non_parent_child_index(self, handler: DatasetIndexToolCallbackHandler, mocker: MockerFixture): caller_session = mocker.Mock() independent_session = mocker.MagicMock()