agstack 1.8.3.tar.gz → 1.9.0.tar.gz

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
Files changed (60)
  1. {agstack-1.8.3 → agstack-1.9.0}/PKG-INFO +2 -3
  2. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/detect_node.py +6 -4
  3. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/llm_chat_node.py +10 -6
  4. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/llm_embed_node.py +1 -1
  5. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/llm_rerank_node.py +3 -2
  6. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/tool_node.py +5 -2
  7. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/tool.py +4 -0
  8. {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/PKG-INFO +2 -3
  9. {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/requires.txt +1 -2
  10. {agstack-1.8.3 → agstack-1.9.0}/pyproject.toml +2 -3
  11. {agstack-1.8.3 → agstack-1.9.0}/tests/test_flow_io.py +225 -0
  12. {agstack-1.8.3 → agstack-1.9.0}/LICENSE +0 -0
  13. {agstack-1.8.3 → agstack-1.9.0}/README.md +0 -0
  14. {agstack-1.8.3 → agstack-1.9.0}/agstack/__init__.py +0 -0
  15. {agstack-1.8.3 → agstack-1.9.0}/agstack/config/__init__.py +0 -0
  16. {agstack-1.8.3 → agstack-1.9.0}/agstack/config/logger.py +0 -0
  17. {agstack-1.8.3 → agstack-1.9.0}/agstack/config/manager.py +0 -0
  18. {agstack-1.8.3 → agstack-1.9.0}/agstack/config/types.py +0 -0
  19. {agstack-1.8.3 → agstack-1.9.0}/agstack/contexts.py +0 -0
  20. {agstack-1.8.3 → agstack-1.9.0}/agstack/decorators.py +0 -0
  21. {agstack-1.8.3 → agstack-1.9.0}/agstack/events.py +0 -0
  22. {agstack-1.8.3 → agstack-1.9.0}/agstack/exceptions.py +0 -0
  23. {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/__init__.py +0 -0
  24. {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/exception.py +0 -0
  25. {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/middleware.py +0 -0
  26. {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/offline.py +0 -0
  27. {agstack-1.8.3 → agstack-1.9.0}/agstack/fastapi/sse.py +0 -0
  28. {agstack-1.8.3 → agstack-1.9.0}/agstack/infra/db/__init__.py +0 -0
  29. {agstack-1.8.3 → agstack-1.9.0}/agstack/infra/es/__init__.py +0 -0
  30. {agstack-1.8.3 → agstack-1.9.0}/agstack/infra/kg/__init__.py +0 -0
  31. {agstack-1.8.3 → agstack-1.9.0}/agstack/infra/mq/__init__.py +0 -0
  32. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/__init__.py +0 -0
  33. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/client.py +0 -0
  34. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/__init__.py +0 -0
  35. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/agent.py +0 -0
  36. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/context.py +0 -0
  37. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/event.py +0 -0
  38. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/exceptions.py +0 -0
  39. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/factory.py +0 -0
  40. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/flow.py +0 -0
  41. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/loader.py +0 -0
  42. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/__init__.py +0 -0
  43. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/agent_node.py +0 -0
  44. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/base.py +0 -0
  45. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/nodes/python_node.py +0 -0
  46. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/records.py +0 -0
  47. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/registry.py +0 -0
  48. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/sandbox.py +0 -0
  49. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/flow/state.py +0 -0
  50. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/prompts.py +0 -0
  51. {agstack-1.8.3 → agstack-1.9.0}/agstack/llm/token.py +0 -0
  52. {agstack-1.8.3 → agstack-1.9.0}/agstack/schema.py +0 -0
  53. {agstack-1.8.3 → agstack-1.9.0}/agstack/security/__init__.py +0 -0
  54. {agstack-1.8.3 → agstack-1.9.0}/agstack/security/casbin.py +0 -0
  55. {agstack-1.8.3 → agstack-1.9.0}/agstack/security/crypt.py +0 -0
  56. {agstack-1.8.3 → agstack-1.9.0}/agstack/status.py +0 -0
  57. {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/SOURCES.txt +0 -0
  58. {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/dependency_links.txt +0 -0
  59. {agstack-1.8.3 → agstack-1.9.0}/agstack.egg-info/top_level.txt +0 -0
  60. {agstack-1.8.3 → agstack-1.9.0}/setup.cfg +0 -0
--- agstack-1.8.3/PKG-INFO
+++ agstack-1.9.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: agstack
-Version: 1.8.3
+Version: 1.9.0
 Summary: Production-ready toolkit for building FastAPI and LLM applications
 Author-email: XtraVisions <gitadmin@xtravisions.com>, Chen Hao <chenhao@xtravisions.com>
 Maintainer-email: XtraVisions <gitadmin@xtravisions.com>, Chen Hao <chenhao@xtravisions.com>
@@ -34,8 +34,7 @@ Requires-Dist: pydantic>=2.12.4
 Requires-Dist: python-multipart>=0.0.20
 Requires-Dist: requests>=2.32.5
 Requires-Dist: RestrictedPython>=7.0
-Requires-Dist: sqlalchemy[asyncio]>=2.0.48
-Requires-Dist: sqlobjects>=1.6.0
+Requires-Dist: sqlobjects>=1.9.0
 Requires-Dist: tiktoken>=0.12.0
 Requires-Dist: uvicorn>=0.41.0
 Dynamic: license-file
--- agstack-1.8.3/agstack/llm/flow/nodes/detect_node.py
+++ agstack-1.9.0/agstack/llm/flow/nodes/detect_node.py
@@ -47,10 +47,12 @@ class DetectNodeHandler(NodeHandler):
         resolved_inputs = self.resolve_inputs(config, context)

         query = resolved_inputs.get("query", "")
-        instruction = config.get("instruction", "Classify the input")
-        options = config.get("options", [])
-        model = config.get("model", "gpt-4o-mini")
-        temperature = config.get("temperature", 0.0)
+        instruction = resolved_inputs.get("instruction") or config.get("instruction", "Classify the input")
+        _options = resolved_inputs.get("options")
+        options = _options if _options is not None else config.get("options", [])
+        model = resolved_inputs.get("model") or config.get("model", "gpt-4o-mini")
+        _raw_temp = resolved_inputs.get("temperature")
+        temperature: float = float(_raw_temp) if _raw_temp is not None else float(config.get("temperature", 0.0))

         messages = self._build_classification_prompt(instruction, options, query)

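Note on the pattern: all four node handlers now apply the same precedence rule: a value resolved at runtime from the node's "inputs" mapping wins over the static "config" value, which in turn falls back to a built-in default. A minimal node config sketch illustrating the override (shapes taken from the new tests below; $v.* references resolve against FlowContext variables):

    node = {
        "id": "detect1",
        "type": "detect",
        "config": {
            "options": ["qa", "chitchat"],   # static fallback
            "inputs": {
                "query": "$v.user_query",
                "model": "$v.chosen_model",  # overrides config["model"] at runtime
                "temperature": "$v.temp",    # coerced to float by the handler
            },
        },
    }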
--- agstack-1.8.3/agstack/llm/flow/nodes/llm_chat_node.py
+++ agstack-1.9.0/agstack/llm/flow/nodes/llm_chat_node.py
@@ -47,9 +47,11 @@ class LLMChatNodeHandler(NodeHandler):
         resolved_inputs = self.resolve_inputs(config, context)
         prompt_text = self._build_prompt(config.get("prompt", ""), resolved_inputs)

-        model = config.get("model", "gpt-4o")
-        temperature = config.get("temperature", 0.7)
-        max_tokens = config.get("max_tokens")
+        model = resolved_inputs.get("model") or config.get("model", "gpt-4o")
+        _temp = resolved_inputs.get("temperature")
+        temperature: float = float(_temp) if _temp is not None else float(config.get("temperature", 0.7))
+        _max = resolved_inputs.get("max_tokens")
+        max_tokens = _max if _max is not None else config.get("max_tokens")

         client = get_llm_client()
         messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": prompt_text}]
@@ -103,9 +105,11 @@ class LLMChatNodeHandler(NodeHandler):
         resolved_inputs = self.resolve_inputs(config, context)
         prompt_text = self._build_prompt(config.get("prompt", ""), resolved_inputs)

-        model = config.get("model", "gpt-4o")
-        temperature = config.get("temperature", 0.7)
-        max_tokens = config.get("max_tokens")
+        model = resolved_inputs.get("model") or config.get("model", "gpt-4o")
+        _temp = resolved_inputs.get("temperature")
+        temperature: float = float(_temp) if _temp is not None else float(config.get("temperature", 0.7))
+        _max = resolved_inputs.get("max_tokens")
+        max_tokens = _max if _max is not None else config.get("max_tokens")

         client = get_llm_client()
         messages: list[ChatCompletionMessageParam] = [{"role": "user", "content": prompt_text}]
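The same override block now appears verbatim in both chat paths (and in lighter form in the embed and rerank handlers below); it could be factored into a single helper. A sketch only; resolve_param is hypothetical and not part of agstack:

    def resolve_param(resolved: dict, config: dict, key: str, default=None, cast=None):
        """Hypothetical helper: prefer a runtime-resolved input over static config."""
        value = resolved.get(key)
        if value is None:
            value = config.get(key, default)
        return cast(value) if cast is not None and value is not None else value

    # temperature = resolve_param(resolved_inputs, config, "temperature", 0.7, float)
    # max_tokens = resolve_param(resolved_inputs, config, "max_tokens")

One subtlety such a helper would have to preserve: the shipped code uses "or" for model (an empty string falls back to config) but explicit is-not-None checks for temperature, max_tokens, and top_n (so 0 and 0.0 are honored).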
--- agstack-1.8.3/agstack/llm/flow/nodes/llm_embed_node.py
+++ agstack-1.9.0/agstack/llm/flow/nodes/llm_embed_node.py
@@ -29,7 +29,7 @@ class LLMEmbedNodeHandler(NodeHandler):
         if isinstance(texts, str):
             texts = [texts]

-        model = config.get("model", "bge-m3")
+        model = resolved_inputs.get("model") or config.get("model", "bge-m3")

         client = get_llm_client()
         embeddings = await client.embed(texts=texts, model=model)
--- agstack-1.8.3/agstack/llm/flow/nodes/llm_rerank_node.py
+++ agstack-1.9.0/agstack/llm/flow/nodes/llm_rerank_node.py
@@ -30,8 +30,9 @@ class LLMRerankNodeHandler(NodeHandler):
         if isinstance(documents, str):
             documents = [documents]

-        model = config.get("model", "bge-reranker-v2-m3")
-        top_n = config.get("top_n", 10)
+        model = resolved_inputs.get("model") or config.get("model", "bge-reranker-v2-m3")
+        _top_n = resolved_inputs.get("top_n")
+        top_n = _top_n if _top_n is not None else config.get("top_n", 10)

         client = get_llm_client()
         raw_results = await client.rerank(
--- agstack-1.8.3/agstack/llm/flow/nodes/tool_node.py
+++ agstack-1.9.0/agstack/llm/flow/nodes/tool_node.py
@@ -4,7 +4,7 @@

 from typing import TYPE_CHECKING, Any

-from ..exceptions import FlowError
+from ..exceptions import FlowError, ToolExecutionError
 from ..registry import registry
 from .base import NodeHandler

@@ -31,4 +31,7 @@ class ToolNodeHandler(NodeHandler):
         config = node.get("config", {})
         resolved = self.resolve_inputs(config, context)
         tool = self._create_tool(config)
-        return await tool.run(context, inputs=resolved)
+        result = await tool.execute_async(context, inputs=resolved)
+        if not result.success:
+            raise ToolExecutionError("TOOL_EXECUTION_FAILED", args={"tool_name": tool.name, "error": result.error})
+        return result.result
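Behavioral note: a failing tool no longer hands its ToolResult back to the flow; the node raises, so failures propagate like any other node error. A hedged caller-side sketch (the wrapper below is hypothetical; handler.execute mirrors the call shape used in the tests):

    from agstack.llm.flow.exceptions import ToolExecutionError

    async def run_tool_node(handler, node, context):
        try:
            return await handler.execute(node, context)
        except ToolExecutionError as exc:
            # Raised with the code "TOOL_EXECUTION_FAILED"; its args carry the
            # tool name and the error string from the failed ToolResult.
            print(f"tool failed: {exc}")
            return None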
--- agstack-1.8.3/agstack/llm/flow/tool.py
+++ agstack-1.9.0/agstack/llm/flow/tool.py
@@ -2,6 +2,7 @@

 """Tool definition and execution"""

+import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable

@@ -9,6 +10,8 @@ from typing import TYPE_CHECKING, Any, Callable
 if TYPE_CHECKING:
     from .context import FlowContext

+logger = logging.getLogger(__name__)
+

 @dataclass
 class ToolResult:
@@ -52,6 +55,7 @@ class Tool:

             return ToolResult(name=self.name, arguments=inputs or {}, result=result, success=True)
         except Exception as e:
+            logger.warning("Tool %s failed: %s", self.name, e, exc_info=True)
             return ToolResult(
                 name=self.name,
                 arguments=inputs or {},
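Previously the except clause swallowed the exception into ToolResult.error with no trace. With the module-level logger, configuring stdlib logging is enough to surface failures; a minimal sketch (the logger name follows from the module path agstack/llm/flow/tool.py):

    import logging

    logging.basicConfig(level=logging.WARNING)
    # A failing tool now logs under "agstack.llm.flow.tool", roughly:
    #   WARNING:agstack.llm.flow.tool:Tool search failed: timeout
    # with the full traceback attached via exc_info=True.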
--- agstack-1.8.3/agstack.egg-info/PKG-INFO
+++ agstack-1.9.0/agstack.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: agstack
-Version: 1.8.3
+Version: 1.9.0
 Summary: Production-ready toolkit for building FastAPI and LLM applications
 Author-email: XtraVisions <gitadmin@xtravisions.com>, Chen Hao <chenhao@xtravisions.com>
 Maintainer-email: XtraVisions <gitadmin@xtravisions.com>, Chen Hao <chenhao@xtravisions.com>
@@ -34,8 +34,7 @@ Requires-Dist: pydantic>=2.12.4
 Requires-Dist: python-multipart>=0.0.20
 Requires-Dist: requests>=2.32.5
 Requires-Dist: RestrictedPython>=7.0
-Requires-Dist: sqlalchemy[asyncio]>=2.0.48
-Requires-Dist: sqlobjects>=1.6.0
+Requires-Dist: sqlobjects>=1.9.0
 Requires-Dist: tiktoken>=0.12.0
 Requires-Dist: uvicorn>=0.41.0
 Dynamic: license-file
--- agstack-1.8.3/agstack.egg-info/requires.txt
+++ agstack-1.9.0/agstack.egg-info/requires.txt
@@ -12,7 +12,6 @@ pydantic>=2.12.4
 python-multipart>=0.0.20
 requests>=2.32.5
 RestrictedPython>=7.0
-sqlalchemy[asyncio]>=2.0.48
-sqlobjects>=1.6.0
+sqlobjects>=1.9.0
 tiktoken>=0.12.0
 uvicorn>=0.41.0
--- agstack-1.8.3/pyproject.toml
+++ agstack-1.9.0/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "agstack"
-version = "1.8.3"
+version = "1.9.0"
 description = "Production-ready toolkit for building FastAPI and LLM applications"
 readme = "README.md"
 license = "MIT"
@@ -53,8 +53,7 @@ dependencies = [
     "python-multipart>=0.0.20",
     "requests>=2.32.5",
     "RestrictedPython>=7.0",
-    "sqlalchemy[asyncio]>=2.0.48",
-    "sqlobjects>=1.6.0",
+    "sqlobjects>=1.9.0",
     "tiktoken>=0.12.0",
     "uvicorn>=0.41.0",
 ]
--- agstack-1.8.3/tests/test_flow_io.py
+++ agstack-1.9.0/tests/test_flow_io.py
@@ -257,6 +257,76 @@ class TestDetectNodeHandler:
         assert isinstance(result, dict)
         assert result == {"choice": "qa"}

+    @patch("agstack.llm.flow.nodes.detect_node.get_llm_client")
+    def test_dynamic_instruction_and_options(self, mock_get_client):
+        from agstack.llm.flow.nodes.detect_node import DetectNodeHandler
+
+        mock_response = MagicMock()
+        mock_choice = MagicMock()
+        mock_choice.message.content = '{"result": "billing"}'
+        mock_response.choices = [mock_choice]
+        mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15)
+
+        mock_client = AsyncMock()
+        mock_client.chat = AsyncMock(return_value=mock_response)
+        mock_get_client.return_value = mock_client
+
+        handler = DetectNodeHandler()
+        ctx = FlowContext(
+            variables={
+                "my_instruction": "classify ticket type",
+                "my_options": ["billing", "technical", "general"],
+            }
+        )
+        node = {
+            "id": "detect1",
+            "type": "detect",
+            "config": {
+                "inputs": {
+                    "query": "$v.user_query",
+                    "instruction": "$v.my_instruction",
+                    "options": "$v.my_options",
+                },
+            },
+        }
+        ctx.variables["user_query"] = "I was charged twice"
+        result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
+        assert result == {"choice": "billing"}
+
+    @patch("agstack.llm.flow.nodes.detect_node.get_llm_client")
+    def test_dynamic_model_and_temperature(self, mock_get_client):
+        from agstack.llm.flow.nodes.detect_node import DetectNodeHandler
+
+        mock_response = MagicMock()
+        mock_choice = MagicMock()
+        mock_choice.message.content = '{"result": "qa"}'
+        mock_response.choices = [mock_choice]
+        mock_response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15)
+
+        mock_client = AsyncMock()
+        mock_client.chat = AsyncMock(return_value=mock_response)
+        mock_get_client.return_value = mock_client
+
+        handler = DetectNodeHandler()
+        ctx = FlowContext(variables={"chosen_model": "qwen2.5-72b", "temp": 0.1})
+        node = {
+            "id": "detect2",
+            "type": "detect",
+            "config": {
+                "options": ["qa", "chitchat"],
+                "inputs": {
+                    "query": "hello",
+                    "model": "$v.chosen_model",
+                    "temperature": "$v.temp",
+                },
+            },
+        }
+        result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
+        call_args = mock_client.chat.call_args
+        assert call_args.kwargs["model"] == "qwen2.5-72b"
+        assert call_args.kwargs["temperature"] == 0.1
+        assert result == {"choice": "qa"}
+

 # ── LLMChatNodeHandler ──

@@ -344,6 +414,47 @@ class TestLLMChatNodeHandler:
         assert len(system_msg) == 1
         assert system_msg[0]["content"] == "You speak Chinese"

+    @patch("agstack.llm.flow.nodes.llm_chat_node.get_llm_client")
+    def test_dynamic_model_temperature_max_tokens(self, mock_get_client):
+        from agstack.llm.flow.nodes.llm_chat_node import LLMChatNodeHandler
+
+        mock_response = MagicMock()
+        mock_choice = MagicMock()
+        mock_choice.message.content = "response"
+        mock_response.choices = [mock_choice]
+        mock_response.usage = MagicMock(prompt_tokens=5, completion_tokens=3, total_tokens=8)
+
+        mock_client = AsyncMock()
+        mock_client.chat = AsyncMock(return_value=mock_response)
+        mock_get_client.return_value = mock_client
+
+        handler = LLMChatNodeHandler()
+        ctx = FlowContext(
+            variables={
+                "chosen_model": "qwen2.5-72b",
+                "temp": 0.2,
+                "max_tok": 512,
+            }
+        )
+        node = {
+            "id": "chat1",
+            "type": "llm_chat",
+            "config": {
+                "prompt": "Hello",
+                "model": "gpt-4o",
+                "inputs": {
+                    "model": "$v.chosen_model",
+                    "temperature": "$v.temp",
+                    "max_tokens": "$v.max_tok",
+                },
+            },
+        }
+        asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
+        call_args = mock_client.chat.call_args
+        assert call_args.kwargs["model"] == "qwen2.5-72b"
+        assert call_args.kwargs["temperature"] == 0.2
+        assert call_args.kwargs["max_tokens"] == 512
+

 # ── Flow routing ──

@@ -433,3 +544,117 @@ class TestDataFlowIntegration:
         }
         result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
         assert result == {"result": 30}
+
+
+class TestLLMRerankNodeHandler:
+    """Dynamic parameter tests for the LLM Rerank node"""
+
+    @patch("agstack.llm.flow.nodes.llm_rerank_node.get_llm_client")
+    def test_dynamic_model_and_top_n(self, mock_get_client):
+        from agstack.llm.flow.nodes.llm_rerank_node import LLMRerankNodeHandler
+
+        mock_client = AsyncMock()
+        mock_client.rerank = AsyncMock(return_value=[(0, 0.95, "doc A"), (1, 0.80, "doc B")])
+        mock_get_client.return_value = mock_client
+
+        handler = LLMRerankNodeHandler()
+        ctx = FlowContext(variables={"rerank_model": "bge-reranker-large", "topk": 2})
+        node = {
+            "id": "rerank1",
+            "type": "llm_rerank",
+            "config": {
+                "model": "bge-reranker-v2-m3",
+                "inputs": {
+                    "query": "best python book",
+                    "documents": ["doc A", "doc B", "doc C"],
+                    "model": "$v.rerank_model",
+                    "top_n": "$v.topk",
+                },
+            },
+        }
+        result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
+        call_args = mock_client.rerank.call_args
+        assert call_args.kwargs["model"] == "bge-reranker-large"
+        assert call_args.kwargs["top_n"] == 2
+        assert result == {
+            "results": [{"index": 0, "score": 0.95, "text": "doc A"}, {"index": 1, "score": 0.80, "text": "doc B"}]
+        }
+
+    @patch("agstack.llm.flow.nodes.llm_rerank_node.get_llm_client")
+    def test_static_fallback_still_works(self, mock_get_client):
+        from agstack.llm.flow.nodes.llm_rerank_node import LLMRerankNodeHandler
+
+        mock_client = AsyncMock()
+        mock_client.rerank = AsyncMock(return_value=[(0, 0.9, "doc A")])
+        mock_get_client.return_value = mock_client
+
+        handler = LLMRerankNodeHandler()
+        ctx = FlowContext()
+        node = {
+            "id": "rerank2",
+            "type": "llm_rerank",
+            "config": {
+                "model": "bge-reranker-v2-m3",
+                "top_n": 5,
+                "inputs": {
+                    "query": "test",
+                    "documents": ["doc A"],
+                },
+            },
+        }
+        asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
+        call_args = mock_client.rerank.call_args
+        assert call_args.kwargs["model"] == "bge-reranker-v2-m3"
+        assert call_args.kwargs["top_n"] == 5
+
+
+class TestLLMEmbedNodeHandler:
+    """Dynamic parameter tests for the LLM Embed node"""
+
+    @patch("agstack.llm.flow.nodes.llm_embed_node.get_llm_client")
+    def test_dynamic_model(self, mock_get_client):
+        from agstack.llm.flow.nodes.llm_embed_node import LLMEmbedNodeHandler
+
+        mock_client = AsyncMock()
+        mock_client.embed = AsyncMock(return_value=[[0.1, 0.2, 0.3]])
+        mock_get_client.return_value = mock_client
+
+        handler = LLMEmbedNodeHandler()
+        ctx = FlowContext(variables={"embed_model": "text-embedding-3-large"})
+        node = {
+            "id": "embed1",
+            "type": "llm_embed",
+            "config": {
+                "model": "bge-m3",
+                "inputs": {
+                    "texts": ["hello world"],
+                    "model": "$v.embed_model",
+                },
+            },
+        }
+        result = asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
+        call_args = mock_client.embed.call_args
+        assert call_args.kwargs["model"] == "text-embedding-3-large"
+        assert result == {"embeddings": [[0.1, 0.2, 0.3]]}
+
+    @patch("agstack.llm.flow.nodes.llm_embed_node.get_llm_client")
+    def test_static_model_fallback(self, mock_get_client):
+        from agstack.llm.flow.nodes.llm_embed_node import LLMEmbedNodeHandler
+
+        mock_client = AsyncMock()
+        mock_client.embed = AsyncMock(return_value=[[0.1, 0.2]])
+        mock_get_client.return_value = mock_client
+
+        handler = LLMEmbedNodeHandler()
+        ctx = FlowContext()
+        node = {
+            "id": "embed2",
+            "type": "llm_embed",
+            "config": {
+                "model": "bge-m3",
+                "inputs": {"texts": ["hello"]},
+            },
+        }
+        asyncio.get_event_loop().run_until_complete(handler.execute(node, ctx))
+        call_args = mock_client.embed.call_args
+        assert call_args.kwargs["model"] == "bge-m3"