Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/dify_plugin/core/entities/plugin/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,15 +205,10 @@ class ModelStartPollingRequest(ModelInvokeLLMRequest):
action: ModelActions = ModelActions.StartPolling
stream: Literal[False] = False

workflow_run_id: str
node_id: str


class ModelCheckPollingRequest(PluginAccessModelRequest):
action: ModelActions = ModelActions.CheckPolling

workflow_run_id: str
node_id: str
plugin_state: dict[str, JsonValue] = Field(min_length=1)


Expand Down
4 changes: 0 additions & 4 deletions src/dify_plugin/core/plugin_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,6 @@ def start_llm_polling(
stream=data.stream,
user=data.user_id,
json_schema=data.json_schema,
workflow_run_id=data.workflow_run_id,
node_id=data.node_id,
)

def check_llm_polling(
Expand Down Expand Up @@ -328,8 +326,6 @@ def check_llm_polling(
credentials=data.credentials,
plugin_state=data.plugin_state,
user=data.user_id,
workflow_run_id=data.workflow_run_id,
node_id=data.node_id,
)

def get_llm_num_tokens(
Expand Down
19 changes: 1 addition & 18 deletions src/dify_plugin/interfaces/model/large_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def _start_polling(
stream: Literal[False] = False,
user: str | None = None,
*,
workflow_run_id: str,
node_id: str,
json_schema: dict[str, JsonValue] | None = None,
) -> LLMPollingResult:
"""Start a polling-based large language model invocation."""
Expand All @@ -103,8 +101,6 @@ def _start_polling(
stop,
stream,
user,
workflow_run_id,
node_id,
json_schema,
)
raise NotImplementedError
Expand All @@ -115,12 +111,9 @@ def _check_polling(
credentials: dict,
plugin_state: dict[str, JsonValue],
user: str | None = None,
*,
workflow_run_id: str,
node_id: str,
) -> LLMPollingResult:
"""Check a polling-based large language model invocation."""
del model, credentials, plugin_state, user, workflow_run_id, node_id
del model, credentials, plugin_state, user
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -769,9 +762,6 @@ def start_polling(
stream: Literal[False] = False,
user: str | None = None,
json_schema: dict[str, JsonValue] | None = None,
*,
workflow_run_id: str,
node_id: str,
) -> LLMPollingResult:
Comment thread
QuantumGhost marked this conversation as resolved.
"""Start a polling-based large language model invocation."""
if not self.supports_polling(model, credentials):
Expand All @@ -798,8 +788,6 @@ def start_polling(
stop=stop,
stream=stream,
user=user,
workflow_run_id=workflow_run_id,
node_id=node_id,
json_schema=json_schema,
)
except Exception as e:
Expand All @@ -811,9 +799,6 @@ def check_polling(
credentials: dict,
plugin_state: dict[str, JsonValue],
user: str | None = None,
*,
workflow_run_id: str,
node_id: str,
) -> LLMPollingResult:
"""Check a polling-based large language model invocation."""
if not self.supports_polling(model, credentials):
Expand All @@ -827,8 +812,6 @@ def check_polling(
credentials=credentials,
plugin_state=plugin_state,
user=user,
workflow_run_id=workflow_run_id,
node_id=node_id,
)
except Exception as e:
raise self._transform_invoke_error(e) from e
Expand Down
29 changes: 8 additions & 21 deletions tests/test_model_polling.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ class PollingScenario:
provider: str = "provider"
model: str = "llm"
api_key: str = "key"
workflow_run_id: str = "wr-1"
node_id: str = "node-1"
job_id: str = "job-1"
prompt_content: str = "hello"
result_content: str = "done"
Expand Down Expand Up @@ -96,8 +94,6 @@ def start_request(
"model_parameters": model_parameters or {},
"stop": [],
"tools": [],
"workflow_run_id": self.workflow_run_id,
"node_id": self.node_id,
}
if json_schema is not None:
data["json_schema"] = json_schema
Expand All @@ -117,8 +113,6 @@ def check_request(
"model_type": ModelType.LLM,
"model": self.model,
"credentials": self.credentials,
"workflow_run_id": self.workflow_run_id,
"node_id": self.node_id,
"plugin_state": plugin_state or self.plugin_state,
}
return ModelCheckPollingRequest(**data)
Expand Down Expand Up @@ -199,8 +193,6 @@ def _start_polling(
stream: Literal[False] = False,
user: str | None = None,
*,
workflow_run_id: str,
node_id: str,
json_schema: dict[str, JsonValue] | None = None,
) -> LLMPollingResult:
self.start_call = {
Expand All @@ -212,8 +204,6 @@ def _start_polling(
"stop": stop,
"stream": stream,
"user": user,
"workflow_run_id": workflow_run_id,
"node_id": node_id,
"json_schema": json_schema,
}
return LLMPollingResult(
Expand All @@ -230,17 +220,12 @@ def _check_polling(
credentials: dict,
plugin_state: dict[str, JsonValue],
user: str | None = None,
*,
workflow_run_id: str,
node_id: str,
) -> LLMPollingResult:
self.check_call = {
"model": model,
"credentials": credentials,
"plugin_state": plugin_state,
"user": user,
"workflow_run_id": workflow_run_id,
"node_id": node_id,
}
return LLMPollingResult(
status=LLMPollingStatus.SUCCEEDED,
Expand Down Expand Up @@ -275,10 +260,14 @@ def test_polling_requests_parse_daemon_payloads() -> None:
assert start_request.stream is False
assert isinstance(start_request.prompt_messages[0], UserPromptMessage)
assert start_request.json_schema == scenario.json_schema
assert "workflow_run_id" not in start_request.model_dump()
assert "node_id" not in start_request.model_dump()

check_request = scenario.check_request()
assert check_request.action == ModelActions.CheckPolling
assert check_request.plugin_state == scenario.plugin_state
assert "workflow_run_id" not in check_request.model_dump()
assert "node_id" not in check_request.model_dump()


def test_start_polling_request_rejects_streaming() -> None:
Expand All @@ -299,8 +288,6 @@ def test_check_polling_request_rejects_empty_plugin_state() -> None:
"model_type": ModelType.LLM,
"model": scenario.model,
"credentials": scenario.credentials,
"workflow_run_id": scenario.workflow_run_id,
"node_id": scenario.node_id,
"plugin_state": {},
}

Expand Down Expand Up @@ -329,10 +316,10 @@ def test_executor_starts_llm_polling() -> None:
assert response.max_attempts == scenario.max_attempts
assert model.start_call is not None
assert model.supports_polling(scenario.model, scenario.credentials)
assert model.start_call["workflow_run_id"] == scenario.workflow_run_id
assert model.start_call["node_id"] == scenario.node_id
assert model.start_call["json_schema"] == scenario.json_schema
assert model.start_call["model_parameters"] == {}
assert "workflow_run_id" not in model.start_call
assert "node_id" not in model.start_call


def test_executor_checks_llm_polling() -> None:
Expand All @@ -351,8 +338,8 @@ def test_executor_checks_llm_polling() -> None:
assert response.result.message.content == scenario.result_content
assert model.check_call is not None
assert model.check_call["plugin_state"] == scenario.plugin_state
assert model.check_call["workflow_run_id"] == scenario.workflow_run_id
assert model.check_call["node_id"] == scenario.node_id
assert "workflow_run_id" not in model.check_call
assert "node_id" not in model.check_call


def test_executor_rejects_llm_without_polling_feature() -> None:
Expand Down