agstack 1.8.3__tar.gz → 1.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {agstack-1.8.3 → agstack-1.9.0}/PKG-INFO +2 -3
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/detect_node.py +6 -4
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/llm_chat_node.py +10 -6
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/llm_embed_node.py +1 -1
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/llm_rerank_node.py +3 -2
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/tool_node.py +5 -2
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/tool.py +4 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/PKG-INFO +2 -3
- {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/requires.txt +1 -2
- {agstack-1.8.3 → agstack-1.9.0}/pyproject.toml +2 -3
- {agstack-1.8.3 → agstack-1.9.0}/tests/test_flow_io.py +225 -0
- {agstack-1.8.3 → agstack-1.9.0}/LICENSE +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/README.md +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/config/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/config/logger.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/config/manager.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/config/types.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/contexts.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/decorators.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/events.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/exceptions.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/exception.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/middleware.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/offline.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/sse.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/infra/db/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/infra/es/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/infra/kg/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/infra/mq/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/client.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/agent.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/context.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/event.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/exceptions.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/factory.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/flow.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/loader.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/agent_node.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/base.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/python_node.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/records.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/registry.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/sandbox.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/state.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/prompts.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/token.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/schema.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/security/__init__.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/security/casbin.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/security/crypt.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack/status.py +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/SOURCES.txt +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/dependency_links.txt +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/top_level.txt +0 -0
- {agstack-1.8.3 → agstack-1.9.0}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: agstack
|
|
3
|
-
Version: 1.8.3
|
|
3
|
+
Version: 1.9.0
|
|
4
4
|
Summary: Production-ready toolkit for building FastAPI and LLM applications
|
|
5
5
|
Author-email: XtraVisions <gitadmin@xtravisions.com>, Chen Hao <chenhao@xtravisions.com>
|
|
6
6
|
Maintainer-email: XtraVisions <gitadmin@xtravisions.com>, Chen Hao <chenhao@xtravisions.com>
|
|
@@ -34,8 +34,7 @@ Requires-Dist: pydantic>=2.12.4
|
|
|
34
34
|
Requires-Dist: python-multipart>=0.0.20
|
|
35
35
|
Requires-Dist: requests>=2.32.5
|
|
36
36
|
Requires-Dist: RestrictedPython>=7.0
|
|
37
|
-
Requires-Dist:
|
|
38
|
-
Requires-Dist: sqlobjects>=1.6.0
|
|
37
|
+
Requires-Dist: sqlobjects>=1.9.0
|
|
39
38
|
Requires-Dist: tiktoken>=0.12.0
|
|
40
39
|
Requires-Dist: uvicorn>=0.41.0
|
|
41
40
|
Dynamic: license-file
|
|
@@ -47,10 +47,12 @@ class DetectNodeHandler(NodeHandler):
|
|
|
47
47
|
resolved_inputs = self.resolve_inputs(config, context)
|
|
48
48
|
|
|
49
49
|
query = resolved_inputs.get("query", "")
|
|
50
|
-
instruction = config.get("instruction", "Classify the input")
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
50
|
+
instruction = resolved_inputs.get("instruction") or config.get("instruction", "Classify the input")
|
|
51
|
+
_options = resolved_inputs.get("options")
|
|
52
|
+
options = _options if _options is not None else config.get("options", [])
|
|
53
|
+
model = resolved_inputs.get("model") or config.get("model", "gpt-4o-mini")
|
|
54
|
+
_raw_temp = resolved_inputs.get("temperature")
|
|
55
|
+
temperature: float = float(_raw_temp) if _raw_temp is not None else float(config.get("temperature", 0.0))
|
|
54
56
|
|
|
55
57
|
messages = self._build_classification_prompt(instruction, options, query)
|
|
56
58
|
|
|
@@ -47,9 +47,11 @@ class LLMChatNodeHandler(NodeHandler):
|
|
|
47
47
|
resolved_inputs = self.resolve_inputs(config, context)
|
|
48
48
|
prompt_text = self._build_prompt(config.get("prompt", ""), resolved_inputs)
|
|
49
49
|
|
|
50
|
-
model = config.get("model", "gpt-4o")
|
|
51
|
-
|
|
52
|
-
|
|
50
|
+
model = resolved_inputs.get("model") or config.get("model", "gpt-4o")
|
|
51
|
+
_temp = resolved_inputs.get("temperature")
|
|
52
|
+
temperature: float = float(_temp) if _temp is not None else float(config.get("temperature", 0.7))
|
|
53
|
+
_max = resolved_inputs.get("max_tokens")
|
|
54
|
+
max_tokens = _max if _max is not None else config.get("max_tokens")
|
|
53
55
|
|
|
54
56
|
client = get_llm_client()
|
|
55
57
|
messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": prompt_text}]
|
|
@@ -103,9 +105,11 @@ class LLMChatNodeHandler(NodeHandler):
|
|
|
103
105
|
resolved_inputs = self.resolve_inputs(config, context)
|
|
104
106
|
prompt_text = self._build_prompt(config.get("prompt", ""), resolved_inputs)
|
|
105
107
|
|
|
106
|
-
model = config.get("model", "gpt-4o")
|
|
107
|
-
|
|
108
|
-
|
|
108
|
+
model = resolved_inputs.get("model") or config.get("model", "gpt-4o")
|
|
109
|
+
_temp = resolved_inputs.get("temperature")
|
|
110
|
+
temperature: float = float(_temp) if _temp is not None else float(config.get("temperature", 0.7))
|
|
111
|
+
_max = resolved_inputs.get("max_tokens")
|
|
112
|
+
max_tokens = _max if _max is not None else config.get("max_tokens")
|
|
109
113
|
|
|
110
114
|
client = get_llm_client()
|
|
111
115
|
messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": prompt_text}]
|
|
@@ -29,7 +29,7 @@ class LLMEmbedNodeHandler(NodeHandler):
|
|
|
29
29
|
if isinstance(texts, str):
|
|
30
30
|
texts = [texts]
|
|
31
31
|
|
|
32
|
-
model = config.get("model", "bge-m3")
|
|
32
|
+
model = resolved_inputs.get("model") or config.get("model", "bge-m3")
|
|
33
33
|
|
|
34
34
|
client = get_llm_client()
|
|
35
35
|
embeddings = await client.embed(texts=texts, model=model)
|
|
@@ -30,8 +30,9 @@ class LLMRerankNodeHandler(NodeHandler):
|
|
|
30
30
|
if isinstance(documents, str):
|
|
31
31
|
documents = [documents]
|
|
32
32
|
|
|
33
|
-
model = config.get("model", "bge-reranker-v2-m3")
|
|
34
|
-
|
|
33
|
+
model = resolved_inputs.get("model") or config.get("model", "bge-reranker-v2-m3")
|
|
34
|
+
_top_n = resolved_inputs.get("top_n")
|
|
35
|
+
top_n = _top_n if _top_n is not None else config.get("top_n", 10)
|
|
35
36
|
|
|
36
37
|
client = get_llm_client()
|
|
37
38
|
raw_results = await client.rerank(
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
|
|
5
5
|
from typing import TYPE_CHECKING, Any
|
|
6
6
|
|
|
7
|
-
from ..exceptions import FlowError
|
|
7
|
+
from ..exceptions import FlowError, ToolExecutionError
|
|
8
8
|
from ..registry import registry
|
|
9
9
|
from .base import NodeHandler
|
|
10
10
|
|
|
@@ -31,4 +31,7 @@ class ToolNodeHandler(NodeHandler):
|
|
|
31
31
|
config = node.get("config", {})
|
|
32
32
|
resolved = self.resolve_inputs(config, context)
|
|
33
33
|
tool = self._create_tool(config)
|
|
34
|
-
|
|
34
|
+
result = await tool.execute_async(context, inputs=resolved)
|
|
35
|
+
if not result.success:
|
|
36
|
+
raise ToolExecutionError("TOOL_EXECUTION_FAILED", args={"tool_name": tool.name, "error": result.error})
|
|
37
|
+
return result.result
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
"""工具定义和执行"""
|
|
4
4
|
|
|
5
|
+
import logging
|
|
5
6
|
from dataclasses import dataclass
|
|
6
7
|
from typing import TYPE_CHECKING, Any, Callable
|
|
7
8
|
|
|
@@ -9,6 +10,8 @@ from typing import TYPE_CHECKING, Any, Callable
|
|
|
9
10
|
if TYPE_CHECKING:
|
|
10
11
|
from .context import FlowContext
|
|
11
12
|
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
12
15
|
|
|
13
16
|
@dataclass
|
|
14
17
|
class ToolResult:
|
|
@@ -52,6 +55,7 @@ class Tool:
|
|
|
52
55
|
|
|
53
56
|
return ToolResult(name=self.name, arguments=inputs or {}, result=result, success=True)
|
|
54
57
|
except Exception as e:
|
|
58
|
+
logger.warning("Tool %s failed: %s", self.name, e, exc_info=True)
|
|
55
59
|
return ToolResult(
|
|
56
60
|
name=self.name,
|
|
57
61
|
arguments=inputs or {},
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: agstack
|
|
3
|
-
Version: 1.8.3
|
|
3
|
+
Version: 1.9.0
|
|
4
4
|
Summary: Production-ready toolkit for building FastAPI and LLM applications
|
|
5
5
|
Author-email: XtraVisions <gitadmin@xtravisions.com>, Chen Hao <chenhao@xtravisions.com>
|
|
6
6
|
Maintainer-email: XtraVisions <gitadmin@xtravisions.com>, Chen Hao <chenhao@xtravisions.com>
|
|
@@ -34,8 +34,7 @@ Requires-Dist: pydantic>=2.12.4
|
|
|
34
34
|
Requires-Dist: python-multipart>=0.0.20
|
|
35
35
|
Requires-Dist: requests>=2.32.5
|
|
36
36
|
Requires-Dist: RestrictedPython>=7.0
|
|
37
|
-
Requires-Dist:
|
|
38
|
-
Requires-Dist: sqlobjects>=1.6.0
|
|
37
|
+
Requires-Dist: sqlobjects>=1.9.0
|
|
39
38
|
Requires-Dist: tiktoken>=0.12.0
|
|
40
39
|
Requires-Dist: uvicorn>=0.41.0
|
|
41
40
|
Dynamic: license-file
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "agstack"
|
|
3
|
-
version = "1.8.3"
|
|
3
|
+
version = "1.9.0"
|
|
4
4
|
description = "Production-ready toolkit for building FastAPI and LLM applications"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
license = "MIT"
|
|
@@ -53,8 +53,7 @@ dependencies = [
|
|
|
53
53
|
"python-multipart>=0.0.20",
|
|
54
54
|
"requests>=2.32.5",
|
|
55
55
|
"RestrictedPython>=7.0",
|
|
56
|
-
"
|
|
57
|
-
"sqlobjects>=1.6.0",
|
|
56
|
+
"sqlobjects>=1.9.0",
|
|
58
57
|
"tiktoken>=0.12.0",
|
|
59
58
|
"uvicorn>=0.41.0",
|
|
60
59
|
]
|
|
@@ -257,6 +257,76 @@ class TestDetectNodeHandler:
|
|
|
257
257
|
assert isinstance(result, dict)
|
|
258
258
|
assert result == {"choice": "qa"}
|
|
259
259
|
|
|
260
|
+
@patch("agstack.llm.flow.nodes.detect_node.get_llm_client")
|
|
261
|
+
def test_dynamic_instruction_and_options(self, mock_get_client):
|
|
262
|
+
from agstack.llm.flow.nodes.detect_node import DetectNodeHandler
|
|
263
|
+
|
|
264
|
+
mock_response = MagicMock()
|
|
265
|
+
mock_choice = MagicMock()
|
|
266
|
+
mock_choice.message.content = '{"result": "billing"}'
|
|
267
|
+
mock_response.choices = [mock_choice]
|
|
268
|
+
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
|
269
|
+
|
|
270
|
+
mock_client = AsyncMock()
|
|
271
|
+
mock_client.chat = AsyncMock(return_value=mock_response)
|
|
272
|
+
mock_get_client.return_value = mock_client
|
|
273
|
+
|
|
274
|
+
handler = DetectNodeHandler()
|
|
275
|
+
ctx = FlowContext(
|
|
276
|
+
variables={
|
|
277
|
+
"my_instruction": "classify ticket type",
|
|
278
|
+
"my_options": ["billing", "technical", "general"],
|
|
279
|
+
}
|
|
280
|
+
)
|
|
281
|
+
node = {
|
|
282
|
+
"id": "detect1",
|
|
283
|
+
"type": "detect",
|
|
284
|
+
"config": {
|
|
285
|
+
"inputs": {
|
|
286
|
+
"query": "$v.user_query",
|
|
287
|
+
"instruction": "$v.my_instruction",
|
|
288
|
+
"options": "$v.my_options",
|
|
289
|
+
},
|
|
290
|
+
},
|
|
291
|
+
}
|
|
292
|
+
ctx.variables["user_query"] = "I was charged twice"
|
|
293
|
+
result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
|
|
294
|
+
assert result == {"choice": "billing"}
|
|
295
|
+
|
|
296
|
+
@patch("agstack.llm.flow.nodes.detect_node.get_llm_client")
|
|
297
|
+
def test_dynamic_model_and_temperature(self, mock_get_client):
|
|
298
|
+
from agstack.llm.flow.nodes.detect_node import DetectNodeHandler
|
|
299
|
+
|
|
300
|
+
mock_response = MagicMock()
|
|
301
|
+
mock_choice = MagicMock()
|
|
302
|
+
mock_choice.message.content = '{"result": "qa"}'
|
|
303
|
+
mock_response.choices = [mock_choice]
|
|
304
|
+
mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
|
305
|
+
|
|
306
|
+
mock_client = AsyncMock()
|
|
307
|
+
mock_client.chat = AsyncMock(return_value=mock_response)
|
|
308
|
+
mock_get_client.return_value = mock_client
|
|
309
|
+
|
|
310
|
+
handler = DetectNodeHandler()
|
|
311
|
+
ctx = FlowContext(variables={"chosen_model": "qwen2.5-72b", "temp": 0.1})
|
|
312
|
+
node = {
|
|
313
|
+
"id": "detect2",
|
|
314
|
+
"type": "detect",
|
|
315
|
+
"config": {
|
|
316
|
+
"options": ["qa", "chitchat"],
|
|
317
|
+
"inputs": {
|
|
318
|
+
"query": "hello",
|
|
319
|
+
"model": "$v.chosen_model",
|
|
320
|
+
"temperature": "$v.temp",
|
|
321
|
+
},
|
|
322
|
+
},
|
|
323
|
+
}
|
|
324
|
+
result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
|
|
325
|
+
call_args = mock_client.chat.call_args
|
|
326
|
+
assert call_args.kwargs["model"] == "qwen2.5-72b"
|
|
327
|
+
assert call_args.kwargs["temperature"] == 0.1
|
|
328
|
+
assert result == {"choice": "qa"}
|
|
329
|
+
|
|
260
330
|
|
|
261
331
|
# ── LLMChatNodeHandler ──
|
|
262
332
|
|
|
@@ -344,6 +414,47 @@ class TestLLMChatNodeHandler:
|
|
|
344
414
|
assert len(system_msg) == 1
|
|
345
415
|
assert system_msg[0]["content"] == "You speak Chinese"
|
|
346
416
|
|
|
417
|
+
@patch("agstack.llm.flow.nodes.llm_chat_node.get_llm_client")
|
|
418
|
+
def test_dynamic_model_temperature_max_tokens(self, mock_get_client):
|
|
419
|
+
from agstack.llm.flow.nodes.llm_chat_node import LLMChatNodeHandler
|
|
420
|
+
|
|
421
|
+
mock_response = MagicMock()
|
|
422
|
+
mock_choice = MagicMock()
|
|
423
|
+
mock_choice.message.content = "response"
|
|
424
|
+
mock_response.choices = [mock_choice]
|
|
425
|
+
mock_response.usage = MagicMock(prompt_tokens=5, completion_tokens=3, total_tokens=8)
|
|
426
|
+
|
|
427
|
+
mock_client = AsyncMock()
|
|
428
|
+
mock_client.chat = AsyncMock(return_value=mock_response)
|
|
429
|
+
mock_get_client.return_value = mock_client
|
|
430
|
+
|
|
431
|
+
handler = LLMChatNodeHandler()
|
|
432
|
+
ctx = FlowContext(
|
|
433
|
+
variables={
|
|
434
|
+
"chosen_model": "qwen2.5-72b",
|
|
435
|
+
"temp": 0.2,
|
|
436
|
+
"max_tok": 512,
|
|
437
|
+
}
|
|
438
|
+
)
|
|
439
|
+
node = {
|
|
440
|
+
"id": "chat1",
|
|
441
|
+
"type": "llm_chat",
|
|
442
|
+
"config": {
|
|
443
|
+
"prompt": "Hello",
|
|
444
|
+
"model": "gpt-4o",
|
|
445
|
+
"inputs": {
|
|
446
|
+
"model": "$v.chosen_model",
|
|
447
|
+
"temperature": "$v.temp",
|
|
448
|
+
"max_tokens": "$v.max_tok",
|
|
449
|
+
},
|
|
450
|
+
},
|
|
451
|
+
}
|
|
452
|
+
asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
|
|
453
|
+
call_args = mock_client.chat.call_args
|
|
454
|
+
assert call_args.kwargs["model"] == "qwen2.5-72b"
|
|
455
|
+
assert call_args.kwargs["temperature"] == 0.2
|
|
456
|
+
assert call_args.kwargs["max_tokens"] == 512
|
|
457
|
+
|
|
347
458
|
|
|
348
459
|
# ── Flow routing ──
|
|
349
460
|
|
|
@@ -433,3 +544,117 @@ class TestDataFlowIntegration:
|
|
|
433
544
|
}
|
|
434
545
|
result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
|
|
435
546
|
assert result == {"result": 30}
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
class TestLLMRerankNodeHandler:
|
|
550
|
+
"""LLM Rerank 节点动态参数测试"""
|
|
551
|
+
|
|
552
|
+
@patch("agstack.llm.flow.nodes.llm_rerank_node.get_llm_client")
|
|
553
|
+
def test_dynamic_model_and_top_n(self, mock_get_client):
|
|
554
|
+
from agstack.llm.flow.nodes.llm_rerank_node import LLMRerankNodeHandler
|
|
555
|
+
|
|
556
|
+
mock_client = AsyncMock()
|
|
557
|
+
mock_client.rerank = AsyncMock(return_value=[(0, 0.95, "doc A"), (1, 0.80, "doc B")])
|
|
558
|
+
mock_get_client.return_value = mock_client
|
|
559
|
+
|
|
560
|
+
handler = LLMRerankNodeHandler()
|
|
561
|
+
ctx = FlowContext(variables={"rerank_model": "bge-reranker-large", "topk": 2})
|
|
562
|
+
node = {
|
|
563
|
+
"id": "rerank1",
|
|
564
|
+
"type": "llm_rerank",
|
|
565
|
+
"config": {
|
|
566
|
+
"model": "bge-reranker-v2-m3",
|
|
567
|
+
"inputs": {
|
|
568
|
+
"query": "best python book",
|
|
569
|
+
"documents": ["doc A", "doc B", "doc C"],
|
|
570
|
+
"model": "$v.rerank_model",
|
|
571
|
+
"top_n": "$v.topk",
|
|
572
|
+
},
|
|
573
|
+
},
|
|
574
|
+
}
|
|
575
|
+
result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
|
|
576
|
+
call_args = mock_client.rerank.call_args
|
|
577
|
+
assert call_args.kwargs["model"] == "bge-reranker-large"
|
|
578
|
+
assert call_args.kwargs["top_n"] == 2
|
|
579
|
+
assert result == {
|
|
580
|
+
"results": [{"index": 0, "score": 0.95, "text": "doc A"}, {"index": 1, "score": 0.80, "text": "doc B"}]
|
|
581
|
+
}
|
|
582
|
+
|
|
583
|
+
@patch("agstack.llm.flow.nodes.llm_rerank_node.get_llm_client")
|
|
584
|
+
def test_static_fallback_still_works(self, mock_get_client):
|
|
585
|
+
from agstack.llm.flow.nodes.llm_rerank_node import LLMRerankNodeHandler
|
|
586
|
+
|
|
587
|
+
mock_client = AsyncMock()
|
|
588
|
+
mock_client.rerank = AsyncMock(return_value=[(0, 0.9, "doc A")])
|
|
589
|
+
mock_get_client.return_value = mock_client
|
|
590
|
+
|
|
591
|
+
handler = LLMRerankNodeHandler()
|
|
592
|
+
ctx = FlowContext()
|
|
593
|
+
node = {
|
|
594
|
+
"id": "rerank2",
|
|
595
|
+
"type": "llm_rerank",
|
|
596
|
+
"config": {
|
|
597
|
+
"model": "bge-reranker-v2-m3",
|
|
598
|
+
"top_n": 5,
|
|
599
|
+
"inputs": {
|
|
600
|
+
"query": "test",
|
|
601
|
+
"documents": ["doc A"],
|
|
602
|
+
},
|
|
603
|
+
},
|
|
604
|
+
}
|
|
605
|
+
asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
|
|
606
|
+
call_args = mock_client.rerank.call_args
|
|
607
|
+
assert call_args.kwargs["model"] == "bge-reranker-v2-m3"
|
|
608
|
+
assert call_args.kwargs["top_n"] == 5
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
class TestLLMEmbedNodeHandler:
|
|
612
|
+
"""LLM Embed 节点动态参数测试"""
|
|
613
|
+
|
|
614
|
+
@patch("agstack.llm.flow.nodes.llm_embed_node.get_llm_client")
|
|
615
|
+
def test_dynamic_model(self, mock_get_client):
|
|
616
|
+
from agstack.llm.flow.nodes.llm_embed_node import LLMEmbedNodeHandler
|
|
617
|
+
|
|
618
|
+
mock_client = AsyncMock()
|
|
619
|
+
mock_client.embed = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
|
|
620
|
+
mock_get_client.return_value = mock_client
|
|
621
|
+
|
|
622
|
+
handler = LLMEmbedNodeHandler()
|
|
623
|
+
ctx = FlowContext(variables={"embed_model": "text-embedding-3-large"})
|
|
624
|
+
node = {
|
|
625
|
+
"id": "embed1",
|
|
626
|
+
"type": "llm_embed",
|
|
627
|
+
"config": {
|
|
628
|
+
"model": "bge-m3",
|
|
629
|
+
"inputs": {
|
|
630
|
+
"texts": ["hello world"],
|
|
631
|
+
"model": "$v.embed_model",
|
|
632
|
+
},
|
|
633
|
+
},
|
|
634
|
+
}
|
|
635
|
+
result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
|
|
636
|
+
call_args = mock_client.embed.call_args
|
|
637
|
+
assert call_args.kwargs["model"] == "text-embedding-3-large"
|
|
638
|
+
assert result == {"embeddings": [[0.1, 0.2, 0.3]]}
|
|
639
|
+
|
|
640
|
+
@patch("agstack.llm.flow.nodes.llm_embed_node.get_llm_client")
|
|
641
|
+
def test_static_model_fallback(self, mock_get_client):
|
|
642
|
+
from agstack.llm.flow.nodes.llm_embed_node import LLMEmbedNodeHandler
|
|
643
|
+
|
|
644
|
+
mock_client = AsyncMock()
|
|
645
|
+
mock_client.embed = AsyncMock(return_value=[[0.1, 0.2]])
|
|
646
|
+
mock_get_client.return_value = mock_client
|
|
647
|
+
|
|
648
|
+
handler = LLMEmbedNodeHandler()
|
|
649
|
+
ctx = FlowContext()
|
|
650
|
+
node = {
|
|
651
|
+
"id": "embed2",
|
|
652
|
+
"type": "llm_embed",
|
|
653
|
+
"config": {
|
|
654
|
+
"model": "bge-m3",
|
|
655
|
+
"inputs": {"texts": ["hello"]},
|
|
656
|
+
},
|
|
657
|
+
}
|
|
658
|
+
asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
|
|
659
|
+
call_args = mock_client.embed.call_args
|
|
660
|
+
assert call_args.kwargs["model"] == "bge-m3"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|