yamlgraph-0.3.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- examples/__init__.py +1 -0
- examples/codegen/__init__.py +5 -0
- examples/codegen/models/__init__.py +13 -0
- examples/codegen/models/schemas.py +76 -0
- examples/codegen/tests/__init__.py +1 -0
- examples/codegen/tests/test_ai_helpers.py +235 -0
- examples/codegen/tests/test_ast_analysis.py +174 -0
- examples/codegen/tests/test_code_analysis.py +134 -0
- examples/codegen/tests/test_code_context.py +301 -0
- examples/codegen/tests/test_code_nav.py +89 -0
- examples/codegen/tests/test_dependency_tools.py +119 -0
- examples/codegen/tests/test_example_tools.py +185 -0
- examples/codegen/tests/test_git_tools.py +112 -0
- examples/codegen/tests/test_impl_agent_schemas.py +193 -0
- examples/codegen/tests/test_impl_agent_v4_graph.py +94 -0
- examples/codegen/tests/test_jedi_analysis.py +226 -0
- examples/codegen/tests/test_meta_tools.py +250 -0
- examples/codegen/tests/test_plan_discovery_prompt.py +98 -0
- examples/codegen/tests/test_syntax_tools.py +85 -0
- examples/codegen/tests/test_synthesize_prompt.py +94 -0
- examples/codegen/tests/test_template_tools.py +244 -0
- examples/codegen/tools/__init__.py +80 -0
- examples/codegen/tools/ai_helpers.py +420 -0
- examples/codegen/tools/ast_analysis.py +92 -0
- examples/codegen/tools/code_context.py +180 -0
- examples/codegen/tools/code_nav.py +52 -0
- examples/codegen/tools/dependency_tools.py +120 -0
- examples/codegen/tools/example_tools.py +188 -0
- examples/codegen/tools/git_tools.py +151 -0
- examples/codegen/tools/impl_executor.py +614 -0
- examples/codegen/tools/jedi_analysis.py +311 -0
- examples/codegen/tools/meta_tools.py +202 -0
- examples/codegen/tools/syntax_tools.py +26 -0
- examples/codegen/tools/template_tools.py +356 -0
- examples/fastapi_interview.py +167 -0
- examples/npc/api/__init__.py +1 -0
- examples/npc/api/app.py +100 -0
- examples/npc/api/routes/__init__.py +5 -0
- examples/npc/api/routes/encounter.py +182 -0
- examples/npc/api/session.py +330 -0
- examples/npc/demo.py +387 -0
- examples/npc/nodes/__init__.py +5 -0
- examples/npc/nodes/image_node.py +92 -0
- examples/npc/run_encounter.py +230 -0
- examples/shared/__init__.py +0 -0
- examples/shared/replicate_tool.py +238 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +12 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +49 -0
- examples/storyboard/retry_images.py +118 -0
- scripts/demo_async_executor.py +212 -0
- scripts/demo_interview_e2e.py +200 -0
- scripts/demo_streaming.py +140 -0
- scripts/run_interview_demo.py +94 -0
- scripts/test_interrupt_fix.py +26 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_colocated_prompts.py +139 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +283 -0
- tests/integration/test_npc_api/__init__.py +1 -0
- tests/integration/test_npc_api/test_routes.py +357 -0
- tests/integration/test_npc_api/test_session.py +216 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/integration/test_subgraph_integration.py +295 -0
- tests/integration/test_subgraph_interrupt.py +106 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +355 -0
- tests/unit/test_async_executor.py +346 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_checkpointer_factory.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +276 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +172 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +149 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_feature_brainstorm.py +194 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_linter.py +627 -0
- tests/unit/test_graph_loader.py +357 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_interrupt_node.py +182 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_json_extract.py +134 -0
- tests/unit/test_langsmith.py +600 -0
- tests/unit/test_langsmith_tools.py +204 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +348 -0
- tests/unit/test_passthrough_node.py +126 -0
- tests/unit/test_prompts.py +324 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_streaming.py +307 -0
- tests/unit/test_subgraph.py +596 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_call_integration.py +164 -0
- tests/unit/test_tool_call_node.py +178 -0
- tests/unit/test_tool_nodes.py +129 -0
- tests/unit/test_websearch.py +234 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +159 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +231 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +541 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +70 -0
- yamlgraph/error_handlers.py +227 -0
- yamlgraph/executor.py +290 -0
- yamlgraph/executor_async.py +288 -0
- yamlgraph/graph_loader.py +451 -0
- yamlgraph/map_compiler.py +150 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +181 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +768 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +240 -0
- yamlgraph/storage/__init__.py +20 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/checkpointer_factory.py +123 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +320 -0
- yamlgraph/tools/graph_linter.py +388 -0
- yamlgraph/tools/langsmith_tools.py +125 -0
- yamlgraph/tools/nodes.py +126 -0
- yamlgraph/tools/python_tool.py +179 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/tools/websearch.py +242 -0
- yamlgraph/utils/__init__.py +48 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +245 -0
- yamlgraph/utils/json_extract.py +104 -0
- yamlgraph/utils/langsmith.py +416 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +104 -0
- yamlgraph/utils/prompts.py +171 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.3.9.dist-info/METADATA +1105 -0
- yamlgraph-0.3.9.dist-info/RECORD +185 -0
- yamlgraph-0.3.9.dist-info/WHEEL +5 -0
- yamlgraph-0.3.9.dist-info/entry_points.txt +2 -0
- yamlgraph-0.3.9.dist-info/licenses/LICENSE +33 -0
- yamlgraph-0.3.9.dist-info/top_level.txt +4 -0
tests/unit/test_langsmith_tools.py
@@ -0,0 +1,204 @@
+"""Tests for LangSmith tool wrappers.
+
+Tests for the agent-facing tool wrappers in yamlgraph.tools.langsmith_tools.
+"""
+
+from unittest.mock import patch
+
+# =============================================================================
+# get_run_details_tool tests
+# =============================================================================
+
+
+class TestGetRunDetailsTool:
+    """Tests for get_run_details_tool()."""
+
+    def test_returns_success_with_run_details(self):
+        """Returns run details with success flag."""
+        from yamlgraph.tools.langsmith_tools import get_run_details_tool
+
+        mock_details = {
+            "id": "run-123",
+            "name": "test_pipeline",
+            "status": "success",
+            "error": None,
+            "start_time": "2026-01-18T10:00:00",
+            "end_time": "2026-01-18T10:01:00",
+            "inputs": {"topic": "AI"},
+            "outputs": {"result": "done"},
+            "run_type": "chain",
+        }
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_run_details",
+            return_value=mock_details,
+        ):
+            result = get_run_details_tool("run-123")
+
+        assert result["success"] is True
+        assert result["id"] == "run-123"
+        assert result["status"] == "success"
+
+    def test_returns_error_when_no_details(self):
+        """Returns error dict when details not available."""
+        from yamlgraph.tools.langsmith_tools import get_run_details_tool
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_run_details",
+            return_value=None,
+        ):
+            result = get_run_details_tool("run-123")
+
+        assert result["success"] is False
+        assert "error" in result
+
+    def test_passes_run_id_to_underlying_function(self):
+        """Passes run_id parameter correctly."""
+        from yamlgraph.tools.langsmith_tools import get_run_details_tool
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_run_details",
+            return_value={"id": "test"},
+        ) as mock_get:
+            get_run_details_tool("specific-run-id")
+            mock_get.assert_called_once_with("specific-run-id")
+
+    def test_uses_latest_when_no_run_id(self):
+        """Uses latest run when no ID provided."""
+        from yamlgraph.tools.langsmith_tools import get_run_details_tool
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_run_details",
+            return_value={"id": "latest"},
+        ) as mock_get:
+            get_run_details_tool()
+            mock_get.assert_called_once_with(None)
+
+
+# =============================================================================
+# get_run_errors_tool tests
+# =============================================================================
+
+
+class TestGetRunErrorsTool:
+    """Tests for get_run_errors_tool()."""
+
+    def test_returns_errors_with_count(self):
+        """Returns errors with count and has_errors flag."""
+        from yamlgraph.tools.langsmith_tools import get_run_errors_tool
+
+        mock_errors = [
+            {"node": "generate", "error": "API failed", "run_type": "llm"},
+            {"node": "analyze", "error": "Timeout", "run_type": "llm"},
+        ]
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_run_errors",
+            return_value=mock_errors,
+        ):
+            result = get_run_errors_tool("run-123")
+
+        assert result["success"] is True
+        assert result["error_count"] == 2
+        assert result["has_errors"] is True
+        assert len(result["errors"]) == 2
+
+    def test_returns_empty_when_no_errors(self):
+        """Returns empty list when no errors."""
+        from yamlgraph.tools.langsmith_tools import get_run_errors_tool
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_run_errors",
+            return_value=[],
+        ):
+            result = get_run_errors_tool("run-123")
+
+        assert result["success"] is True
+        assert result["error_count"] == 0
+        assert result["has_errors"] is False
+        assert result["errors"] == []
+
+    def test_passes_run_id_to_underlying_function(self):
+        """Passes run_id parameter correctly."""
+        from yamlgraph.tools.langsmith_tools import get_run_errors_tool
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_run_errors",
+            return_value=[],
+        ) as mock_get:
+            get_run_errors_tool("specific-run-id")
+            mock_get.assert_called_once_with("specific-run-id")
+
+
+# =============================================================================
+# get_failed_runs_tool tests
+# =============================================================================
+
+
+class TestGetFailedRunsTool:
+    """Tests for get_failed_runs_tool()."""
+
+    def test_returns_failed_runs_with_count(self):
+        """Returns failed runs with count."""
+        from yamlgraph.tools.langsmith_tools import get_failed_runs_tool
+
+        mock_runs = [
+            {
+                "id": "run-1",
+                "name": "pipe1",
+                "error": "Err1",
+                "start_time": "2026-01-18T10:00:00",
+            },
+            {
+                "id": "run-2",
+                "name": "pipe2",
+                "error": "Err2",
+                "start_time": "2026-01-18T11:00:00",
+            },
+        ]
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_failed_runs",
+            return_value=mock_runs,
+        ):
+            result = get_failed_runs_tool(limit=5)
+
+        assert result["success"] is True
+        assert result["failed_count"] == 2
+        assert len(result["runs"]) == 2
+
+    def test_returns_empty_when_no_failures(self):
+        """Returns empty list when no failures."""
+        from yamlgraph.tools.langsmith_tools import get_failed_runs_tool
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_failed_runs",
+            return_value=[],
+        ):
+            result = get_failed_runs_tool()
+
+        assert result["success"] is True
+        assert result["failed_count"] == 0
+        assert result["runs"] == []
+
+    def test_passes_parameters_correctly(self):
+        """Passes limit and project_name to underlying function."""
+        from yamlgraph.tools.langsmith_tools import get_failed_runs_tool
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_failed_runs",
+            return_value=[],
+        ) as mock_get:
+            get_failed_runs_tool(limit=5, project_name="custom-project")
+            mock_get.assert_called_once_with(project_name="custom-project", limit=5)
+
+    def test_uses_defaults_when_no_params(self):
+        """Uses default limit when not provided."""
+        from yamlgraph.tools.langsmith_tools import get_failed_runs_tool
+
+        with patch(
+            "yamlgraph.tools.langsmith_tools.get_failed_runs",
+            return_value=[],
+        ) as mock_get:
+            get_failed_runs_tool()
+            mock_get.assert_called_once_with(project_name=None, limit=10)
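The tests above pin down the wrapper contract: each agent-facing tool returns a plain dict carrying a success flag and delegates to a lower-level LangSmith helper, which every test patches out. For orientation, here is a minimal sketch consistent with those assertions; it is a hypothetical reconstruction, not the shipped 125-line yamlgraph/tools/langsmith_tools.py, and the helper import path (yamlgraph/utils/langsmith.py appears in the file listing) is an assumption.

# Hypothetical sketch inferred from the test assertions; not the shipped source.
from yamlgraph.utils.langsmith import (  # assumed helper location
    get_failed_runs,
    get_run_details,
    get_run_errors,
)


def get_run_details_tool(run_id: str | None = None) -> dict:
    """Return run details plus a success flag; None means the latest run."""
    details = get_run_details(run_id)
    if details is None:
        # Error message wording is illustrative; tests only check the keys.
        return {"success": False, "error": f"No run details found for {run_id!r}"}
    return {**details, "success": True}


def get_run_errors_tool(run_id: str | None = None) -> dict:
    """Return the errors recorded for a run, with count and has_errors flag."""
    errors = get_run_errors(run_id)
    return {
        "success": True,
        "errors": errors,
        "error_count": len(errors),
        "has_errors": bool(errors),
    }


def get_failed_runs_tool(limit: int = 10, project_name: str | None = None) -> dict:
    """Return up to `limit` failed runs for a project, with a count."""
    runs = get_failed_runs(project_name=project_name, limit=limit)
    return {"success": True, "runs": runs, "failed_count": len(runs)}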
tests/unit/test_llm_factory.py
@@ -0,0 +1,109 @@
+"""Unit tests for LLM factory module."""
+
+import os
+from unittest.mock import patch
+
+import pytest
+from langchain_anthropic import ChatAnthropic
+
+from yamlgraph.utils.llm_factory import clear_cache, create_llm
+
+
+class TestCreateLLM:
+    """Test the create_llm factory function."""
+
+    def setup_method(self):
+        """Clear cache and environment before each test."""
+        clear_cache()
+
+    def test_default_provider_is_anthropic(self):
+        """Should use Anthropic by default."""
+        # Clear PROVIDER from environment to ensure default behavior
+        with patch.dict(os.environ, {"PROVIDER": ""}, clear=False):
+            llm = create_llm(temperature=0.7)
+            assert isinstance(llm, ChatAnthropic)
+            assert llm.temperature == 0.7
+
+    def test_explicit_anthropic_provider(self):
+        """Should create Anthropic LLM when provider='anthropic'."""
+        llm = create_llm(provider="anthropic", temperature=0.5)
+        assert isinstance(llm, ChatAnthropic)
+        assert llm.temperature == 0.5
+
+    def test_mistral_provider(self):
+        """Should create Mistral LLM when provider='mistral'."""
+        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
+            llm = create_llm(provider="mistral", temperature=0.8)
+            # Check it's the right class (will import on first call)
+            assert llm.__class__.__name__ == "ChatMistralAI"
+            assert llm.temperature == 0.8
+
+    def test_openai_provider(self):
+        """Should create OpenAI LLM when provider='openai'."""
+        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
+            llm = create_llm(provider="openai", temperature=0.6)
+            assert llm.__class__.__name__ == "ChatOpenAI"
+            assert llm.temperature == 0.6
+
+    def test_provider_from_environment(self):
+        """Should use PROVIDER env var when no provider specified."""
+        with patch.dict(
+            os.environ, {"PROVIDER": "mistral", "MISTRAL_API_KEY": "test-key"}
+        ):
+            llm = create_llm(temperature=0.7)
+            assert llm.__class__.__name__ == "ChatMistralAI"
+
+    def test_custom_model(self):
+        """Should use custom model when specified."""
+        with patch.dict(os.environ, {"PROVIDER": ""}, clear=False):
+            llm = create_llm(model="claude-opus-4", temperature=0.5)
+            assert isinstance(llm, ChatAnthropic)
+            assert llm.model == "claude-opus-4"
+
+    def test_model_override_parameter(self):
+        """Should prefer model parameter over default."""
+        llm = create_llm(provider="anthropic", model="claude-sonnet-4", temperature=0.7)
+        assert llm.model == "claude-sonnet-4"
+
+    def test_default_models(self):
+        """Should use correct default models for each provider."""
+        # Anthropic default
+        llm_anthropic = create_llm(provider="anthropic", temperature=0.7)
+        assert llm_anthropic.model == "claude-haiku-4-5"
+
+        # Mistral default
+        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
+            llm_mistral = create_llm(provider="mistral", temperature=0.7)
+            assert llm_mistral.model == "mistral-large-latest"
+
+        # OpenAI default (uses model_name attribute)
+        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
+            llm_openai = create_llm(provider="openai", temperature=0.7)
+            assert llm_openai.model_name == "gpt-4o"
+
+    def test_invalid_provider(self):
+        """Should raise error for invalid provider."""
+        with pytest.raises((ValueError, KeyError)):
+            create_llm(provider="invalid-provider", temperature=0.7)
+
+    def test_caching(self):
+        """Should cache LLM instances for same parameters."""
+        llm1 = create_llm(provider="anthropic", temperature=0.7)
+        llm2 = create_llm(provider="anthropic", temperature=0.7)
+        assert llm1 is llm2
+
+        # Different temperature = different instance
+        llm3 = create_llm(provider="anthropic", temperature=0.5)
+        assert llm1 is not llm3
+
+    def test_cache_key_includes_all_params(self):
+        """Cache should differentiate on provider, model, temperature."""
+        llm1 = create_llm(
+            provider="anthropic", model="claude-haiku-4-5", temperature=0.7
+        )
+        llm2 = create_llm(provider="anthropic", model="claude-opus-4", temperature=0.7)
+        assert llm1 is not llm2
+
+        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
+            llm3 = create_llm(provider="mistral", temperature=0.7)
+            assert llm1 is not llm3
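These factory tests encode a resolution order (explicit argument, then the PROVIDER environment variable, then Anthropic), per-provider default models, lazy provider imports, and an instance cache keyed on (provider, model, temperature). A minimal sketch consistent with those assertions follows, assuming a plain module-level dict cache; the shipped yamlgraph/utils/llm_factory.py may organize this differently.

# Hypothetical sketch inferred from the tests; the real factory may differ.
import os

_DEFAULT_MODELS = {
    "anthropic": "claude-haiku-4-5",
    "mistral": "mistral-large-latest",
    "openai": "gpt-4o",
}
_cache: dict[tuple, object] = {}


def clear_cache() -> None:
    """Drop all cached LLM instances (used by test setup)."""
    _cache.clear()


def create_llm(provider=None, model=None, temperature=0.7):
    """Create (or return a cached) chat model for the given settings."""
    # Empty PROVIDER is falsy, so it falls through to the anthropic default.
    provider = provider or os.environ.get("PROVIDER") or "anthropic"
    model = model or _DEFAULT_MODELS[provider]  # KeyError on unknown provider
    key = (provider, model, temperature)
    if key not in _cache:
        if provider == "anthropic":
            from langchain_anthropic import ChatAnthropic

            _cache[key] = ChatAnthropic(model=model, temperature=temperature)
        elif provider == "mistral":
            # Imported lazily, on first use, as the test comment notes.
            from langchain_mistralai import ChatMistralAI

            _cache[key] = ChatMistralAI(model=model, temperature=temperature)
        else:
            from langchain_openai import ChatOpenAI

            _cache[key] = ChatOpenAI(model=model, temperature=temperature)
    return _cache[key]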
tests/unit/test_llm_factory_async.py
@@ -0,0 +1,118 @@
+"""Unit tests for async LLM factory module."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from yamlgraph.utils.llm_factory_async import (
+    create_llm_async,
+    get_executor,
+    invoke_async,
+    shutdown_executor,
+)
+
+
+class TestGetExecutor:
+    """Tests for get_executor function."""
+
+    def teardown_method(self):
+        """Clean up executor after each test."""
+        shutdown_executor()
+
+    def test_creates_executor(self):
+        """Should create a ThreadPoolExecutor."""
+        executor = get_executor()
+        assert executor is not None
+
+    def test_returns_same_executor(self):
+        """Should return the same executor on subsequent calls."""
+        executor1 = get_executor()
+        executor2 = get_executor()
+        assert executor1 is executor2
+
+
+class TestShutdownExecutor:
+    """Tests for shutdown_executor function."""
+
+    def test_shutdown_cleans_up(self):
+        """Shutdown should clean up executor."""
+        # Create an executor
+        executor1 = get_executor()
+        assert executor1 is not None
+
+        # Shutdown
+        shutdown_executor()
+
+        # Next call should create a new executor
+        executor2 = get_executor()
+        assert executor2 is not executor1
+
+    def test_shutdown_when_none(self):
+        """Shutdown when no executor should not raise."""
+        shutdown_executor()  # Ensure clean state
+        shutdown_executor()  # Should not raise
+
+
+class TestCreateLLMAsync:
+    """Tests for create_llm_async function."""
+
+    def teardown_method(self):
+        """Clean up executor after each test."""
+        shutdown_executor()
+
+    @pytest.mark.asyncio
+    async def test_creates_llm(self):
+        """Should create an LLM instance."""
+        llm = await create_llm_async(provider="anthropic", temperature=0.5)
+        assert llm is not None
+        assert llm.temperature == 0.5
+
+    @pytest.mark.asyncio
+    async def test_uses_default_provider(self):
+        """Should use default provider when not specified."""
+        with patch.dict("os.environ", {"PROVIDER": ""}, clear=False):
+            llm = await create_llm_async(temperature=0.7)
+            # Default is anthropic
+            assert "anthropic" in llm.__class__.__name__.lower()
+
+
+class TestInvokeAsync:
+    """Tests for invoke_async function."""
+
+    def teardown_method(self):
+        """Clean up executor after each test."""
+        shutdown_executor()
+
+    @pytest.mark.asyncio
+    async def test_invoke_returns_string(self):
+        """Should return string content when no output model."""
+        mock_llm = MagicMock()
+        mock_response = MagicMock()
+        mock_response.content = "Hello, world!"
+        mock_llm.invoke.return_value = mock_response
+
+        messages = [MagicMock()]
+        result = await invoke_async(mock_llm, messages)
+
+        assert result == "Hello, world!"
+        mock_llm.invoke.assert_called_once_with(messages)
+
+    @pytest.mark.asyncio
+    async def test_invoke_with_output_model(self):
+        """Should use structured output when model provided."""
+        from pydantic import BaseModel
+
+        class TestOutput(BaseModel):
+            value: str
+
+        mock_llm = MagicMock()
+        mock_structured_llm = MagicMock()
+        mock_llm.with_structured_output.return_value = mock_structured_llm
+        mock_structured_llm.invoke.return_value = TestOutput(value="test")
+
+        messages = [MagicMock()]
+        result = await invoke_async(mock_llm, messages, output_model=TestOutput)
+
+        assert isinstance(result, TestOutput)
+        assert result.value == "test"
+        mock_llm.with_structured_output.assert_called_once_with(TestOutput)