yamlgraph 0.3.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. examples/__init__.py +1 -0
  2. examples/codegen/__init__.py +5 -0
  3. examples/codegen/models/__init__.py +13 -0
  4. examples/codegen/models/schemas.py +76 -0
  5. examples/codegen/tests/__init__.py +1 -0
  6. examples/codegen/tests/test_ai_helpers.py +235 -0
  7. examples/codegen/tests/test_ast_analysis.py +174 -0
  8. examples/codegen/tests/test_code_analysis.py +134 -0
  9. examples/codegen/tests/test_code_context.py +301 -0
  10. examples/codegen/tests/test_code_nav.py +89 -0
  11. examples/codegen/tests/test_dependency_tools.py +119 -0
  12. examples/codegen/tests/test_example_tools.py +185 -0
  13. examples/codegen/tests/test_git_tools.py +112 -0
  14. examples/codegen/tests/test_impl_agent_schemas.py +193 -0
  15. examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
  16. examples/codegen/tests/test_jedi_analysis.py +226 -0
  17. examples/codegen/tests/test_meta_tools.py +250 -0
  18. examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
  19. examples/codegen/tests/test_syntax_tools.py +85 -0
  20. examples/codegen/tests/test_synthesize_prompt.py +94 -0
  21. examples/codegen/tests/test_template_tools.py +244 -0
  22. examples/codegen/tools/__init__.py +80 -0
  23. examples/codegen/tools/ai_helpers.py +420 -0
  24. examples/codegen/tools/ast_analysis.py +92 -0
  25. examples/codegen/tools/code_context.py +180 -0
  26. examples/codegen/tools/code_nav.py +52 -0
  27. examples/codegen/tools/dependency_tools.py +120 -0
  28. examples/codegen/tools/example_tools.py +188 -0
  29. examples/codegen/tools/git_tools.py +151 -0
  30. examples/codegen/tools/impl_executor.py +614 -0
  31. examples/codegen/tools/jedi_analysis.py +311 -0
  32. examples/codegen/tools/meta_tools.py +202 -0
  33. examples/codegen/tools/syntax_tools.py +26 -0
  34. examples/codegen/tools/template_tools.py +356 -0
  35. examples/fastapi_interview.py +167 -0
  36. examples/npc/api/__init__.py +1 -0
  37. examples/npc/api/app.py +100 -0
  38. examples/npc/api/routes/__init__.py +5 -0
  39. examples/npc/api/routes/encounter.py +182 -0
  40. examples/npc/api/session.py +330 -0
  41. examples/npc/demo.py +387 -0
  42. examples/npc/nodes/__init__.py +5 -0
  43. examples/npc/nodes/image_node.py +92 -0
  44. examples/npc/run_encounter.py +230 -0
  45. examples/shared/__init__.py +0 -0
  46. examples/shared/replicate_tool.py +238 -0
  47. examples/storyboard/__init__.py +1 -0
  48. examples/storyboard/generate_videos.py +335 -0
  49. examples/storyboard/nodes/__init__.py +12 -0
  50. examples/storyboard/nodes/animated_character_node.py +248 -0
  51. examples/storyboard/nodes/animated_image_node.py +138 -0
  52. examples/storyboard/nodes/character_node.py +162 -0
  53. examples/storyboard/nodes/image_node.py +118 -0
  54. examples/storyboard/nodes/replicate_tool.py +49 -0
  55. examples/storyboard/retry_images.py +118 -0
  56. scripts/demo_async_executor.py +212 -0
  57. scripts/demo_interview_e2e.py +200 -0
  58. scripts/demo_streaming.py +140 -0
  59. scripts/run_interview_demo.py +94 -0
  60. scripts/test_interrupt_fix.py +26 -0
  61. tests/__init__.py +1 -0
  62. tests/conftest.py +178 -0
  63. tests/integration/__init__.py +1 -0
  64. tests/integration/test_animated_storyboard.py +63 -0
  65. tests/integration/test_cli_commands.py +242 -0
  66. tests/integration/test_colocated_prompts.py +139 -0
  67. tests/integration/test_map_demo.py +50 -0
  68. tests/integration/test_memory_demo.py +283 -0
  69. tests/integration/test_npc_api/__init__.py +1 -0
  70. tests/integration/test_npc_api/test_routes.py +357 -0
  71. tests/integration/test_npc_api/test_session.py +216 -0
  72. tests/integration/test_pipeline_flow.py +105 -0
  73. tests/integration/test_providers.py +163 -0
  74. tests/integration/test_resume.py +75 -0
  75. tests/integration/test_subgraph_integration.py +295 -0
  76. tests/integration/test_subgraph_interrupt.py +106 -0
  77. tests/unit/__init__.py +1 -0
  78. tests/unit/test_agent_nodes.py +355 -0
  79. tests/unit/test_async_executor.py +346 -0
  80. tests/unit/test_checkpointer.py +212 -0
  81. tests/unit/test_checkpointer_factory.py +212 -0
  82. tests/unit/test_cli.py +121 -0
  83. tests/unit/test_cli_package.py +81 -0
  84. tests/unit/test_compile_graph_map.py +132 -0
  85. tests/unit/test_conditions_routing.py +253 -0
  86. tests/unit/test_config.py +93 -0
  87. tests/unit/test_conversation_memory.py +276 -0
  88. tests/unit/test_database.py +145 -0
  89. tests/unit/test_deprecation.py +104 -0
  90. tests/unit/test_executor.py +172 -0
  91. tests/unit/test_executor_async.py +179 -0
  92. tests/unit/test_export.py +149 -0
  93. tests/unit/test_expressions.py +178 -0
  94. tests/unit/test_feature_brainstorm.py +194 -0
  95. tests/unit/test_format_prompt.py +145 -0
  96. tests/unit/test_generic_report.py +200 -0
  97. tests/unit/test_graph_commands.py +327 -0
  98. tests/unit/test_graph_linter.py +627 -0
  99. tests/unit/test_graph_loader.py +357 -0
  100. tests/unit/test_graph_schema.py +193 -0
  101. tests/unit/test_inline_schema.py +151 -0
  102. tests/unit/test_interrupt_node.py +182 -0
  103. tests/unit/test_issues.py +164 -0
  104. tests/unit/test_jinja2_prompts.py +85 -0
  105. tests/unit/test_json_extract.py +134 -0
  106. tests/unit/test_langsmith.py +600 -0
  107. tests/unit/test_langsmith_tools.py +204 -0
  108. tests/unit/test_llm_factory.py +109 -0
  109. tests/unit/test_llm_factory_async.py +118 -0
  110. tests/unit/test_loops.py +403 -0
  111. tests/unit/test_map_node.py +144 -0
  112. tests/unit/test_no_backward_compat.py +56 -0
  113. tests/unit/test_node_factory.py +348 -0
  114. tests/unit/test_passthrough_node.py +126 -0
  115. tests/unit/test_prompts.py +324 -0
  116. tests/unit/test_python_nodes.py +198 -0
  117. tests/unit/test_reliability.py +298 -0
  118. tests/unit/test_result_export.py +234 -0
  119. tests/unit/test_router.py +296 -0
  120. tests/unit/test_sanitize.py +99 -0
  121. tests/unit/test_schema_loader.py +295 -0
  122. tests/unit/test_shell_tools.py +229 -0
  123. tests/unit/test_state_builder.py +331 -0
  124. tests/unit/test_state_builder_map.py +104 -0
  125. tests/unit/test_state_config.py +197 -0
  126. tests/unit/test_streaming.py +307 -0
  127. tests/unit/test_subgraph.py +596 -0
  128. tests/unit/test_template.py +190 -0
  129. tests/unit/test_tool_call_integration.py +164 -0
  130. tests/unit/test_tool_call_node.py +178 -0
  131. tests/unit/test_tool_nodes.py +129 -0
  132. tests/unit/test_websearch.py +234 -0
  133. yamlgraph/__init__.py +35 -0
  134. yamlgraph/builder.py +110 -0
  135. yamlgraph/cli/__init__.py +159 -0
  136. yamlgraph/cli/__main__.py +6 -0
  137. yamlgraph/cli/commands.py +231 -0
  138. yamlgraph/cli/deprecation.py +92 -0
  139. yamlgraph/cli/graph_commands.py +541 -0
  140. yamlgraph/cli/validators.py +37 -0
  141. yamlgraph/config.py +67 -0
  142. yamlgraph/constants.py +70 -0
  143. yamlgraph/error_handlers.py +227 -0
  144. yamlgraph/executor.py +290 -0
  145. yamlgraph/executor_async.py +288 -0
  146. yamlgraph/graph_loader.py +451 -0
  147. yamlgraph/map_compiler.py +150 -0
  148. yamlgraph/models/__init__.py +36 -0
  149. yamlgraph/models/graph_schema.py +181 -0
  150. yamlgraph/models/schemas.py +124 -0
  151. yamlgraph/models/state_builder.py +236 -0
  152. yamlgraph/node_factory.py +768 -0
  153. yamlgraph/routing.py +87 -0
  154. yamlgraph/schema_loader.py +240 -0
  155. yamlgraph/storage/__init__.py +20 -0
  156. yamlgraph/storage/checkpointer.py +72 -0
  157. yamlgraph/storage/checkpointer_factory.py +123 -0
  158. yamlgraph/storage/database.py +320 -0
  159. yamlgraph/storage/export.py +269 -0
  160. yamlgraph/tools/__init__.py +1 -0
  161. yamlgraph/tools/agent.py +320 -0
  162. yamlgraph/tools/graph_linter.py +388 -0
  163. yamlgraph/tools/langsmith_tools.py +125 -0
  164. yamlgraph/tools/nodes.py +126 -0
  165. yamlgraph/tools/python_tool.py +179 -0
  166. yamlgraph/tools/shell.py +205 -0
  167. yamlgraph/tools/websearch.py +242 -0
  168. yamlgraph/utils/__init__.py +48 -0
  169. yamlgraph/utils/conditions.py +157 -0
  170. yamlgraph/utils/expressions.py +245 -0
  171. yamlgraph/utils/json_extract.py +104 -0
  172. yamlgraph/utils/langsmith.py +416 -0
  173. yamlgraph/utils/llm_factory.py +118 -0
  174. yamlgraph/utils/llm_factory_async.py +105 -0
  175. yamlgraph/utils/logging.py +104 -0
  176. yamlgraph/utils/prompts.py +171 -0
  177. yamlgraph/utils/sanitize.py +98 -0
  178. yamlgraph/utils/template.py +102 -0
  179. yamlgraph/utils/validators.py +181 -0
  180. yamlgraph-0.3.9.dist-info/METADATA +1105 -0
  181. yamlgraph-0.3.9.dist-info/RECORD +185 -0
  182. yamlgraph-0.3.9.dist-info/WHEEL +5 -0
  183. yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
  184. yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
  185. yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
@@ -0,0 +1,204 @@
1
+ """Tests for LangSmith tool wrappers.
2
+
3
+ Tests for the agent-facing tool wrappers in yamlgraph.tools.langsmith_tools.
4
+ """
5
+
6
+ from unittest.mock import patch
7
+
8
+ # =============================================================================
9
+ # get_run_details_tool tests
10
+ # =============================================================================
11
+
12
+
13
class TestGetRunDetailsTool:
    """Tests for get_run_details_tool()."""

    def test_returns_success_with_run_details(self):
        """Returns run details with success flag."""
        from yamlgraph.tools.langsmith_tools import get_run_details_tool

        details = {
            "id": "run-123",
            "name": "test_pipeline",
            "status": "success",
            "error": None,
            "start_time": "2026-01-18T10:00:00",
            "end_time": "2026-01-18T10:01:00",
            "inputs": {"topic": "AI"},
            "outputs": {"result": "done"},
            "run_type": "chain",
        }

        with patch("yamlgraph.tools.langsmith_tools.get_run_details", return_value=details):
            outcome = get_run_details_tool("run-123")

        assert outcome["success"] is True
        assert outcome["id"] == "run-123"
        assert outcome["status"] == "success"

    def test_returns_error_when_no_details(self):
        """Returns error dict when details not available."""
        from yamlgraph.tools.langsmith_tools import get_run_details_tool

        with patch("yamlgraph.tools.langsmith_tools.get_run_details", return_value=None):
            outcome = get_run_details_tool("run-123")

        assert outcome["success"] is False
        assert "error" in outcome

    def test_passes_run_id_to_underlying_function(self):
        """Passes run_id parameter correctly."""
        from yamlgraph.tools.langsmith_tools import get_run_details_tool

        with patch(
            "yamlgraph.tools.langsmith_tools.get_run_details",
            return_value={"id": "test"},
        ) as fetch_mock:
            get_run_details_tool("specific-run-id")
            fetch_mock.assert_called_once_with("specific-run-id")

    def test_uses_latest_when_no_run_id(self):
        """Uses latest run when no ID provided."""
        from yamlgraph.tools.langsmith_tools import get_run_details_tool

        with patch(
            "yamlgraph.tools.langsmith_tools.get_run_details",
            return_value={"id": "latest"},
        ) as fetch_mock:
            get_run_details_tool()
            # None signals "use the most recent run" downstream.
            fetch_mock.assert_called_once_with(None)
76
+
77
+
78
+ # =============================================================================
79
+ # get_run_errors_tool tests
80
+ # =============================================================================
81
+
82
+
83
class TestGetRunErrorsTool:
    """Tests for get_run_errors_tool()."""

    def test_returns_errors_with_count(self):
        """Returns errors with count and has_errors flag."""
        from yamlgraph.tools.langsmith_tools import get_run_errors_tool

        fake_errors = [
            {"node": "generate", "error": "API failed", "run_type": "llm"},
            {"node": "analyze", "error": "Timeout", "run_type": "llm"},
        ]

        with patch("yamlgraph.tools.langsmith_tools.get_run_errors", return_value=fake_errors):
            outcome = get_run_errors_tool("run-123")

        assert outcome["success"] is True
        assert outcome["error_count"] == 2
        assert outcome["has_errors"] is True
        assert len(outcome["errors"]) == 2

    def test_returns_empty_when_no_errors(self):
        """Returns empty list when no errors."""
        from yamlgraph.tools.langsmith_tools import get_run_errors_tool

        with patch("yamlgraph.tools.langsmith_tools.get_run_errors", return_value=[]):
            outcome = get_run_errors_tool("run-123")

        assert outcome["success"] is True
        assert outcome["error_count"] == 0
        assert outcome["has_errors"] is False
        assert outcome["errors"] == []

    def test_passes_run_id_to_underlying_function(self):
        """Passes run_id parameter correctly."""
        from yamlgraph.tools.langsmith_tools import get_run_errors_tool

        with patch(
            "yamlgraph.tools.langsmith_tools.get_run_errors",
            return_value=[],
        ) as fetch_mock:
            get_run_errors_tool("specific-run-id")
            fetch_mock.assert_called_once_with("specific-run-id")
131
+
132
+
133
+ # =============================================================================
134
+ # get_failed_runs_tool tests
135
+ # =============================================================================
136
+
137
+
138
class TestGetFailedRunsTool:
    """Tests for get_failed_runs_tool()."""

    def test_returns_failed_runs_with_count(self):
        """Returns failed runs with count."""
        from yamlgraph.tools.langsmith_tools import get_failed_runs_tool

        fake_runs = [
            {
                "id": "run-1",
                "name": "pipe1",
                "error": "Err1",
                "start_time": "2026-01-18T10:00:00",
            },
            {
                "id": "run-2",
                "name": "pipe2",
                "error": "Err2",
                "start_time": "2026-01-18T11:00:00",
            },
        ]

        with patch("yamlgraph.tools.langsmith_tools.get_failed_runs", return_value=fake_runs):
            outcome = get_failed_runs_tool(limit=5)

        assert outcome["success"] is True
        assert outcome["failed_count"] == 2
        assert len(outcome["runs"]) == 2

    def test_returns_empty_when_no_failures(self):
        """Returns empty list when no failures."""
        from yamlgraph.tools.langsmith_tools import get_failed_runs_tool

        with patch("yamlgraph.tools.langsmith_tools.get_failed_runs", return_value=[]):
            outcome = get_failed_runs_tool()

        assert outcome["success"] is True
        assert outcome["failed_count"] == 0
        assert outcome["runs"] == []

    def test_passes_parameters_correctly(self):
        """Passes limit and project_name to underlying function."""
        from yamlgraph.tools.langsmith_tools import get_failed_runs_tool

        with patch(
            "yamlgraph.tools.langsmith_tools.get_failed_runs",
            return_value=[],
        ) as fetch_mock:
            get_failed_runs_tool(limit=5, project_name="custom-project")
            fetch_mock.assert_called_once_with(project_name="custom-project", limit=5)

    def test_uses_defaults_when_no_params(self):
        """Uses default limit when not provided."""
        from yamlgraph.tools.langsmith_tools import get_failed_runs_tool

        with patch(
            "yamlgraph.tools.langsmith_tools.get_failed_runs",
            return_value=[],
        ) as fetch_mock:
            get_failed_runs_tool()
            # Default fan-out is project_name=None, limit=10.
            fetch_mock.assert_called_once_with(project_name=None, limit=10)
@@ -0,0 +1,109 @@
1
+ """Unit tests for LLM factory module."""
2
+
3
+ import os
4
+ from unittest.mock import patch
5
+
6
+ import pytest
7
+ from langchain_anthropic import ChatAnthropic
8
+
9
+ from yamlgraph.utils.llm_factory import clear_cache, create_llm
10
+
11
+
12
class TestCreateLLM:
    """Test the create_llm factory function."""

    def setup_method(self):
        """Clear cache and environment before each test."""
        clear_cache()

    def test_default_provider_is_anthropic(self):
        """Should use Anthropic by default."""
        # An empty PROVIDER forces the factory's built-in default.
        with patch.dict(os.environ, {"PROVIDER": ""}, clear=False):
            instance = create_llm(temperature=0.7)
            assert isinstance(instance, ChatAnthropic)
            assert instance.temperature == 0.7

    def test_explicit_anthropic_provider(self):
        """Should create Anthropic LLM when provider='anthropic'."""
        instance = create_llm(provider="anthropic", temperature=0.5)
        assert isinstance(instance, ChatAnthropic)
        assert instance.temperature == 0.5

    def test_mistral_provider(self):
        """Should create Mistral LLM when provider='mistral'."""
        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
            instance = create_llm(provider="mistral", temperature=0.8)
            # Class-name check avoids importing langchain_mistralai here.
            assert instance.__class__.__name__ == "ChatMistralAI"
            assert instance.temperature == 0.8

    def test_openai_provider(self):
        """Should create OpenAI LLM when provider='openai'."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            instance = create_llm(provider="openai", temperature=0.6)
            assert instance.__class__.__name__ == "ChatOpenAI"
            assert instance.temperature == 0.6

    def test_provider_from_environment(self):
        """Should use PROVIDER env var when no provider specified."""
        env = {"PROVIDER": "mistral", "MISTRAL_API_KEY": "test-key"}
        with patch.dict(os.environ, env):
            instance = create_llm(temperature=0.7)
            assert instance.__class__.__name__ == "ChatMistralAI"

    def test_custom_model(self):
        """Should use custom model when specified."""
        with patch.dict(os.environ, {"PROVIDER": ""}, clear=False):
            instance = create_llm(model="claude-opus-4", temperature=0.5)
            assert isinstance(instance, ChatAnthropic)
            assert instance.model == "claude-opus-4"

    def test_model_override_parameter(self):
        """Should prefer model parameter over default."""
        instance = create_llm(provider="anthropic", model="claude-sonnet-4", temperature=0.7)
        assert instance.model == "claude-sonnet-4"

    def test_default_models(self):
        """Should use correct default models for each provider."""
        assert create_llm(provider="anthropic", temperature=0.7).model == "claude-haiku-4-5"

        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
            assert create_llm(provider="mistral", temperature=0.7).model == "mistral-large-latest"

        # OpenAI exposes the model via `model_name` rather than `model`.
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            assert create_llm(provider="openai", temperature=0.7).model_name == "gpt-4o"

    def test_invalid_provider(self):
        """Should raise error for invalid provider."""
        with pytest.raises((ValueError, KeyError)):
            create_llm(provider="invalid-provider", temperature=0.7)

    def test_caching(self):
        """Should cache LLM instances for same parameters."""
        first = create_llm(provider="anthropic", temperature=0.7)
        second = create_llm(provider="anthropic", temperature=0.7)
        assert first is second

        # A different temperature must produce a distinct instance.
        third = create_llm(provider="anthropic", temperature=0.5)
        assert first is not third

    def test_cache_key_includes_all_params(self):
        """Cache should differentiate on provider, model, temperature."""
        haiku = create_llm(provider="anthropic", model="claude-haiku-4-5", temperature=0.7)
        opus = create_llm(provider="anthropic", model="claude-opus-4", temperature=0.7)
        assert haiku is not opus

        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
            mistral = create_llm(provider="mistral", temperature=0.7)
        assert haiku is not mistral
@@ -0,0 +1,118 @@
1
+ """Unit tests for async LLM factory module."""
2
+
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import pytest
6
+
7
+ from yamlgraph.utils.llm_factory_async import (
8
+ create_llm_async,
9
+ get_executor,
10
+ invoke_async,
11
+ shutdown_executor,
12
+ )
13
+
14
+
15
class TestGetExecutor:
    """Tests for get_executor function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    def test_creates_executor(self):
        """Should create a ThreadPoolExecutor."""
        assert get_executor() is not None

    def test_returns_same_executor(self):
        """Should return the same executor on subsequent calls."""
        first = get_executor()
        assert get_executor() is first
32
+
33
+
34
class TestShutdownExecutor:
    """Tests for shutdown_executor function."""

    def test_shutdown_cleans_up(self):
        """Shutdown should clean up executor."""
        before = get_executor()
        assert before is not None

        shutdown_executor()

        # After shutdown the factory must hand out a fresh executor.
        after = get_executor()
        assert after is not before

    def test_shutdown_when_none(self):
        """Shutdown when no executor should not raise."""
        shutdown_executor()  # ensure clean state
        shutdown_executor()  # second call must be a no-op
54
+
55
+
56
class TestCreateLLMAsync:
    """Tests for create_llm_async function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    @pytest.mark.asyncio
    async def test_creates_llm(self):
        """Should create an LLM instance."""
        instance = await create_llm_async(provider="anthropic", temperature=0.5)
        assert instance is not None
        assert instance.temperature == 0.5

    @pytest.mark.asyncio
    async def test_uses_default_provider(self):
        """Should use default provider when not specified."""
        with patch.dict("os.environ", {"PROVIDER": ""}, clear=False):
            instance = await create_llm_async(temperature=0.7)
            # Anthropic is the factory default.
            assert "anthropic" in instance.__class__.__name__.lower()
77
+
78
+
79
class TestInvokeAsync:
    """Tests for invoke_async function."""

    def teardown_method(self):
        """Clean up executor after each test."""
        shutdown_executor()

    @pytest.mark.asyncio
    async def test_invoke_returns_string(self):
        """Should return string content when no output model."""
        fake_llm = MagicMock()
        fake_response = MagicMock()
        fake_response.content = "Hello, world!"
        fake_llm.invoke.return_value = fake_response

        prompt = [MagicMock()]
        outcome = await invoke_async(fake_llm, prompt)

        assert outcome == "Hello, world!"
        fake_llm.invoke.assert_called_once_with(prompt)

    @pytest.mark.asyncio
    async def test_invoke_with_output_model(self):
        """Should use structured output when model provided."""
        from pydantic import BaseModel

        class TestOutput(BaseModel):
            value: str

        fake_llm = MagicMock()
        structured = MagicMock()
        fake_llm.with_structured_output.return_value = structured
        structured.invoke.return_value = TestOutput(value="test")

        prompt = [MagicMock()]
        outcome = await invoke_async(fake_llm, prompt, output_model=TestOutput)

        assert isinstance(outcome, TestOutput)
        assert outcome.value == "test"
        fake_llm.with_structured_output.assert_called_once_with(TestOutput)