yamlgraph 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of yamlgraph might be problematic; see the package's registry page for more details.
- examples/__init__.py +1 -0
- examples/storyboard/__init__.py +1 -0
- examples/storyboard/generate_videos.py +335 -0
- examples/storyboard/nodes/__init__.py +10 -0
- examples/storyboard/nodes/animated_character_node.py +248 -0
- examples/storyboard/nodes/animated_image_node.py +138 -0
- examples/storyboard/nodes/character_node.py +162 -0
- examples/storyboard/nodes/image_node.py +118 -0
- examples/storyboard/nodes/replicate_tool.py +238 -0
- examples/storyboard/retry_images.py +118 -0
- tests/__init__.py +1 -0
- tests/conftest.py +178 -0
- tests/integration/__init__.py +1 -0
- tests/integration/test_animated_storyboard.py +63 -0
- tests/integration/test_cli_commands.py +242 -0
- tests/integration/test_map_demo.py +50 -0
- tests/integration/test_memory_demo.py +281 -0
- tests/integration/test_pipeline_flow.py +105 -0
- tests/integration/test_providers.py +163 -0
- tests/integration/test_resume.py +75 -0
- tests/unit/__init__.py +1 -0
- tests/unit/test_agent_nodes.py +200 -0
- tests/unit/test_checkpointer.py +212 -0
- tests/unit/test_cli.py +121 -0
- tests/unit/test_cli_package.py +81 -0
- tests/unit/test_compile_graph_map.py +132 -0
- tests/unit/test_conditions_routing.py +253 -0
- tests/unit/test_config.py +93 -0
- tests/unit/test_conversation_memory.py +270 -0
- tests/unit/test_database.py +145 -0
- tests/unit/test_deprecation.py +104 -0
- tests/unit/test_executor.py +60 -0
- tests/unit/test_executor_async.py +179 -0
- tests/unit/test_export.py +150 -0
- tests/unit/test_expressions.py +178 -0
- tests/unit/test_format_prompt.py +145 -0
- tests/unit/test_generic_report.py +200 -0
- tests/unit/test_graph_commands.py +327 -0
- tests/unit/test_graph_loader.py +299 -0
- tests/unit/test_graph_schema.py +193 -0
- tests/unit/test_inline_schema.py +151 -0
- tests/unit/test_issues.py +164 -0
- tests/unit/test_jinja2_prompts.py +85 -0
- tests/unit/test_langsmith.py +319 -0
- tests/unit/test_llm_factory.py +109 -0
- tests/unit/test_llm_factory_async.py +118 -0
- tests/unit/test_loops.py +403 -0
- tests/unit/test_map_node.py +144 -0
- tests/unit/test_no_backward_compat.py +56 -0
- tests/unit/test_node_factory.py +225 -0
- tests/unit/test_prompts.py +166 -0
- tests/unit/test_python_nodes.py +198 -0
- tests/unit/test_reliability.py +298 -0
- tests/unit/test_result_export.py +234 -0
- tests/unit/test_router.py +296 -0
- tests/unit/test_sanitize.py +99 -0
- tests/unit/test_schema_loader.py +295 -0
- tests/unit/test_shell_tools.py +229 -0
- tests/unit/test_state_builder.py +331 -0
- tests/unit/test_state_builder_map.py +104 -0
- tests/unit/test_state_config.py +197 -0
- tests/unit/test_template.py +190 -0
- tests/unit/test_tool_nodes.py +129 -0
- yamlgraph/__init__.py +35 -0
- yamlgraph/builder.py +110 -0
- yamlgraph/cli/__init__.py +139 -0
- yamlgraph/cli/__main__.py +6 -0
- yamlgraph/cli/commands.py +232 -0
- yamlgraph/cli/deprecation.py +92 -0
- yamlgraph/cli/graph_commands.py +382 -0
- yamlgraph/cli/validators.py +37 -0
- yamlgraph/config.py +67 -0
- yamlgraph/constants.py +66 -0
- yamlgraph/error_handlers.py +226 -0
- yamlgraph/executor.py +275 -0
- yamlgraph/executor_async.py +122 -0
- yamlgraph/graph_loader.py +337 -0
- yamlgraph/map_compiler.py +138 -0
- yamlgraph/models/__init__.py +36 -0
- yamlgraph/models/graph_schema.py +141 -0
- yamlgraph/models/schemas.py +124 -0
- yamlgraph/models/state_builder.py +236 -0
- yamlgraph/node_factory.py +240 -0
- yamlgraph/routing.py +87 -0
- yamlgraph/schema_loader.py +160 -0
- yamlgraph/storage/__init__.py +17 -0
- yamlgraph/storage/checkpointer.py +72 -0
- yamlgraph/storage/database.py +320 -0
- yamlgraph/storage/export.py +269 -0
- yamlgraph/tools/__init__.py +1 -0
- yamlgraph/tools/agent.py +235 -0
- yamlgraph/tools/nodes.py +124 -0
- yamlgraph/tools/python_tool.py +178 -0
- yamlgraph/tools/shell.py +205 -0
- yamlgraph/utils/__init__.py +47 -0
- yamlgraph/utils/conditions.py +157 -0
- yamlgraph/utils/expressions.py +111 -0
- yamlgraph/utils/langsmith.py +308 -0
- yamlgraph/utils/llm_factory.py +118 -0
- yamlgraph/utils/llm_factory_async.py +105 -0
- yamlgraph/utils/logging.py +127 -0
- yamlgraph/utils/prompts.py +116 -0
- yamlgraph/utils/sanitize.py +98 -0
- yamlgraph/utils/template.py +102 -0
- yamlgraph/utils/validators.py +181 -0
- yamlgraph-0.1.1.dist-info/METADATA +854 -0
- yamlgraph-0.1.1.dist-info/RECORD +111 -0
- yamlgraph-0.1.1.dist-info/WHEEL +5 -0
- yamlgraph-0.1.1.dist-info/entry_points.txt +2 -0
- yamlgraph-0.1.1.dist-info/licenses/LICENSE +21 -0
- yamlgraph-0.1.1.dist-info/top_level.txt +3 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Tests for issues that were identified and fixed.
|
|
2
|
+
|
|
3
|
+
These tests verify the fixes for issues documented in docs/open-issues.md.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from unittest.mock import patch
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
|
|
10
|
+
from tests.conftest import FixtureAnalysis, FixtureGeneratedContent
|
|
11
|
+
from yamlgraph.builder import build_resume_graph
|
|
12
|
+
from yamlgraph.graph_loader import load_graph_config
|
|
13
|
+
from yamlgraph.models import create_initial_state
|
|
14
|
+
|
|
15
|
+
# =============================================================================
|
|
16
|
+
# Issue 1: Resume Logic - FIXED: skip_if_exists behavior
|
|
17
|
+
# =============================================================================
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TestResumeStartFromParameter:
    """Issue 1: Resume should skip nodes whose output already exists."""

    @patch("yamlgraph.node_factory.execute_prompt")
    def test_resume_from_analyze_skips_generate(self, mock_execute):
        """When state has 'generated', generate node should be skipped.

        Resume works via skip_if_exists: if output already in state, skip LLM call.
        """
        # State with generated content already present
        state = create_initial_state(topic="test", thread_id="issue1")
        state["generated"] = FixtureGeneratedContent(
            title="Already Generated",
            content="This was generated in a previous run",
            word_count=10,
            tags=[],
        )

        # Only mock analyze and summarize - generate should be skipped
        mock_analysis = FixtureAnalysis(
            summary="Analysis",
            key_points=["Point"],
            sentiment="neutral",
            confidence=0.8,
        )
        mock_execute.side_effect = [mock_analysis, "Final summary"]

        graph = build_resume_graph().compile()
        result = graph.invoke(state)

        # Expected: 2 calls (analyze, summarize) - generate skipped
        assert mock_execute.call_count == 2, (
            "Expected 2 LLM calls (analyze, summarize), "
            f"but got {mock_execute.call_count}. "
            "Generate should be skipped when 'generated' exists!"
        )
        # Original generated content should be preserved
        assert result["generated"].title == "Already Generated"

    @patch("yamlgraph.node_factory.execute_prompt")
    def test_resume_from_summarize_skips_generate_and_analyze(self, mock_execute):
        """When state has 'generated' and 'analysis', only summarize runs."""
        state = create_initial_state(topic="test", thread_id="issue1b")
        state["generated"] = FixtureGeneratedContent(
            title="Done",
            content="Content",
            word_count=5,
            tags=[],
        )
        state["analysis"] = FixtureAnalysis(
            summary="Done",
            key_points=["Point"],
            sentiment="positive",
            confidence=0.9,
        )

        mock_execute.return_value = "Final summary"

        graph = build_resume_graph().compile()
        result = graph.invoke(state)

        # Expected: 1 call (summarize only)
        assert mock_execute.call_count == 1, (
            "Expected 1 LLM call (summarize only), "
            f"but got {mock_execute.call_count}. "
            "Generate and analyze should be skipped!"
        )
        # Original content should be preserved
        assert result["generated"].title == "Done"
        assert result["analysis"].summary == "Done"

    @patch("yamlgraph.node_factory.execute_prompt")
    def test_resume_preserves_existing_generated_content(self, mock_execute):
        """Resuming should NOT overwrite already-generated content.

        Unlike the skip test above (which only checks the title), this pins
        every field of the pre-existing ``generated`` value.
        """
        state = create_initial_state(topic="test", thread_id="issue1c")
        state["generated"] = FixtureGeneratedContent(
            title="Keep Me",
            content="Content from an earlier run",
            word_count=5,
            tags=["kept"],
        )
        mock_execute.side_effect = [
            FixtureAnalysis(
                summary="Analysis",
                key_points=["Point"],
                sentiment="neutral",
                confidence=0.5,
            ),
            "Final summary",
        ]

        result = build_resume_graph().compile().invoke(state)

        preserved = result["generated"]
        assert preserved.title == "Keep Me"
        assert preserved.content == "Content from an earlier run"
        assert preserved.word_count == 5
        assert preserved.tags == ["kept"]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# =============================================================================
|
|
98
|
+
# Issue 2: Conditions Block is Dead Config
|
|
99
|
+
# =============================================================================
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class TestConditionsFromYAML:
    """Issue 2: the unused ``conditions`` block was removed; routing uses expressions."""

    def test_conditions_block_not_in_schema(self):
        """Loading the default graph yields a config without a ``conditions`` field."""
        from yamlgraph.config import DEFAULT_GRAPH

        loaded = load_graph_config(DEFAULT_GRAPH)
        still_has_conditions = hasattr(loaded, "conditions")

        # A surviving attribute would mean the YAML block is still parsed for nothing.
        assert still_has_conditions is False, (
            "GraphConfig should not have 'conditions' attribute - it's dead config"
        )
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# =============================================================================
|
|
118
|
+
# Issue 5: _entry_point hack
|
|
119
|
+
# =============================================================================
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class TestEntryPointHack:
    """Issue 5: Using private _entry_point is fragile."""

    @pytest.fixture
    def simple_yaml(self, tmp_path):
        """Write a minimal single-node graph definition and return its path."""
        yaml_content = """
        version: "1.0"
        name: test
        nodes:
          first:
            type: llm
            prompt: generate
            output_model: yamlgraph.models.GenericReport
            state_key: generated
        edges:
          - from: START
            to: first
          - from: first
            to: END
        """
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text(yaml_content)
        return yaml_file

    def test_entry_point_accessible_via_behavior(self, simple_yaml):
        """Entry point should be testable via graph structure, not private attrs.

        Instead of reading ``graph._entry_point``, check the public graph
        data: the node is registered and an edge runs from START to it.
        """
        from langgraph.graph import START

        from yamlgraph.graph_loader import load_and_compile

        graph = load_and_compile(simple_yaml)
        graph.compile()  # Verify it compiles

        # The node must exist on the graph...
        assert "first" in graph.nodes

        # ...and be wired directly from START, checked through the public
        # edge set rather than the private ``_entry_point`` attribute.
        assert (START, "first") in graph.edges
|
|
@@ -0,0 +1,85 @@
|
|
|
1
|
+
"""Integration test for Jinja2 prompt templates."""
|
|
2
|
+
|
|
3
|
+
from yamlgraph.executor import format_prompt, load_prompt
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_jinja2_analyze_list_prompt():
    """Render analyze_list with two items and check each Jinja2 construct it uses."""
    template = load_prompt("analyze_list")["template"]

    ai_item = {
        "title": "Introduction to AI",
        "topic": "Artificial Intelligence",
        "word_count": 500,
        "tags": ["AI", "machine learning", "technology"],
        "content": "Artificial intelligence is transforming how we interact with technology...",
    }
    ml_item = {
        "title": "Machine Learning Basics",
        "topic": "ML Fundamentals",
        "word_count": 750,
        "tags": ["ML", "algorithms", "data"],
        "content": "Machine learning involves training models on data to make predictions...",
    }
    rendered = format_prompt(
        template, {"items": [ai_item, ml_item], "min_confidence": 0.8}
    )

    expected_fragments = (
        "2 items",  # {{ items|length }} filter
        "1. Introduction to AI",  # {{ loop.index }}
        "2. Machine Learning Basics",
        "**Tags**: AI, machine learning, technology",  # join filter
        "**Tags**: ML, algorithms, data",
        "confidence >= 0.8",  # conditional rendering
        "**Content**:",  # if/else conditional
        "### 1.",  # loop counter headings
        "### 2.",
    )
    for fragment in expected_fragments:
        assert fragment in rendered, f"missing {fragment!r} in rendered prompt"
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_jinja2_prompt_with_empty_list():
    """An empty item list renders a zero count and no numbered sections."""
    rendered = format_prompt(
        load_prompt("analyze_list")["template"],
        {"items": [], "min_confidence": None},
    )

    # The length filter still renders, but the loop body never runs.
    assert "0 items" in rendered
    assert "### 1." not in rendered
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_jinja2_prompt_without_optional_fields():
    """Optional inputs (tags, min_confidence) may be empty or omitted entirely."""
    lone_item = dict(
        title="Short Content",
        topic="Brief",
        word_count=100,
        tags=[],  # Empty tags
        content="Short content without tags",
    )
    template = load_prompt("analyze_list")["template"]
    rendered = format_prompt(template, {"items": [lone_item]})

    assert "1 items" in rendered
    assert "Short Content" in rendered

    # An empty tag list must either hide the Tags line or leave it blank.
    tags_hidden = "**Tags**:" not in rendered
    tags_blank = "**Tags**: \n" in rendered
    assert tags_hidden or tags_blank

    # min_confidence was never supplied, so its note must not be rendered.
    assert "confidence >=" not in rendered
|
|
@@ -0,0 +1,319 @@
|
|
|
1
|
+
"""Unit tests for LangSmith utilities.
|
|
2
|
+
|
|
3
|
+
Tests for:
|
|
4
|
+
- share_run() - Create public share links
|
|
5
|
+
- read_run_shared_link() - Get existing share links
|
|
6
|
+
- get_client() - Client creation with env var handling
|
|
7
|
+
- is_tracing_enabled() - Tracing detection
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
from unittest.mock import MagicMock, patch
|
|
12
|
+
|
|
13
|
+
from yamlgraph.utils.langsmith import (
|
|
14
|
+
get_client,
|
|
15
|
+
get_latest_run_id,
|
|
16
|
+
get_project_name,
|
|
17
|
+
is_tracing_enabled,
|
|
18
|
+
read_run_shared_link,
|
|
19
|
+
share_run,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# =============================================================================
|
|
23
|
+
# is_tracing_enabled() tests
|
|
24
|
+
# =============================================================================
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class TestIsTracingEnabled:
    """Tests for is_tracing_enabled()."""

    def test_enabled_with_langchain_tracing_v2_true(self):
        """LANGCHAIN_TRACING_V2=true enables tracing."""
        # clear=True guarantees no stray LANGSMITH_TRACING (or any other
        # tracing variable) from the host environment affects the result.
        with patch.dict(os.environ, {"LANGCHAIN_TRACING_V2": "true"}, clear=True):
            assert is_tracing_enabled() is True

    def test_enabled_with_langsmith_tracing_true(self):
        """LANGSMITH_TRACING=true enables tracing."""
        with patch.dict(os.environ, {"LANGSMITH_TRACING": "true"}, clear=True):
            assert is_tracing_enabled() is True

    def test_disabled_when_no_env_vars(self):
        """No tracing vars means disabled."""
        with patch.dict(os.environ, {}, clear=True):
            assert is_tracing_enabled() is False

    def test_disabled_with_false_value(self):
        """Explicit false value disables tracing."""
        with patch.dict(os.environ, {"LANGCHAIN_TRACING_V2": "false"}, clear=True):
            assert is_tracing_enabled() is False

    def test_case_insensitive(self):
        """TRUE, True, true all work."""
        for value in ("TRUE", "True", "true"):
            with patch.dict(os.environ, {"LANGSMITH_TRACING": value}, clear=True):
                assert is_tracing_enabled() is True, value
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# =============================================================================
|
|
62
|
+
# get_project_name() tests
|
|
63
|
+
# =============================================================================
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class TestGetProjectName:
    """get_project_name(): environment lookup order and fallback value."""

    @staticmethod
    def _resolve(env):
        """Call get_project_name() with os.environ replaced by exactly *env*."""
        with patch.dict(os.environ, env, clear=True):
            return get_project_name()

    def test_langchain_project(self):
        """LANGCHAIN_PROJECT alone is returned as-is."""
        assert self._resolve({"LANGCHAIN_PROJECT": "my-project"}) == "my-project"

    def test_langsmith_project(self):
        """LANGSMITH_PROJECT alone is returned as-is."""
        assert self._resolve({"LANGSMITH_PROJECT": "other-project"}) == "other-project"

    def test_langchain_takes_precedence(self):
        """With both set, LANGCHAIN_PROJECT wins over LANGSMITH_PROJECT."""
        both = {"LANGCHAIN_PROJECT": "first", "LANGSMITH_PROJECT": "second"}
        assert self._resolve(both) == "first"

    def test_default_value(self):
        """An empty environment falls back to the built-in default."""
        assert self._resolve({}) == "yamlgraph"
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# =============================================================================
|
|
95
|
+
# get_client() tests
|
|
96
|
+
# =============================================================================
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class TestGetClient:
    """get_client(): building a LangSmith client from environment settings."""

    def test_returns_none_without_api_key(self):
        """Without any API key variable no client is built."""
        with patch.dict(os.environ, {}, clear=True):
            client = get_client()
        assert client is None

    def test_creates_client_with_langchain_key(self):
        """LANGCHAIN_API_KEY is enough to construct a client."""
        env = {"LANGCHAIN_API_KEY": "lsv2_test_key"}
        with patch.dict(os.environ, env, clear=True), patch(
            "langsmith.Client"
        ) as client_cls:
            client = get_client()
            client_cls.assert_called_once()
            assert client is not None

    def test_creates_client_with_langsmith_key(self):
        """LANGSMITH_API_KEY is enough to construct a client."""
        env = {"LANGSMITH_API_KEY": "lsv2_test_key"}
        with patch.dict(os.environ, env, clear=True), patch(
            "langsmith.Client"
        ) as client_cls:
            client = get_client()
            client_cls.assert_called_once()
            assert client is not None

    def test_uses_custom_endpoint(self):
        """LANGSMITH_ENDPOINT is forwarded as the client's api_url."""
        env = {
            "LANGSMITH_API_KEY": "key",
            "LANGSMITH_ENDPOINT": "https://eu.smith.langchain.com",
        }
        with patch.dict(os.environ, env, clear=True), patch(
            "langsmith.Client"
        ) as client_cls:
            get_client()
            client_cls.assert_called_with(
                api_url="https://eu.smith.langchain.com",
                api_key="key",
            )

    def test_returns_none_on_import_error(self):
        """An ImportError while building the client is swallowed into None."""
        # Simulate langsmith being unavailable at construction time.
        with patch.dict(os.environ, {"LANGSMITH_API_KEY": "key"}, clear=True), patch(
            "langsmith.Client", side_effect=ImportError("No module")
        ):
            client = get_client()
            assert client is None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# =============================================================================
|
|
159
|
+
# share_run() tests
|
|
160
|
+
# =============================================================================
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class TestShareRun:
    """Tests for share_run()."""

    def test_returns_none_when_no_client(self):
        """Returns None when client unavailable."""
        with patch("yamlgraph.utils.langsmith.get_client", return_value=None):
            result = share_run("test-run-id")
            assert result is None

    def test_shares_provided_run_id(self):
        """Shares the provided run ID."""
        mock_client = MagicMock()
        mock_client.share_run.return_value = "https://smith.langchain.com/public/abc123"

        with patch("yamlgraph.utils.langsmith.get_client", return_value=mock_client):
            result = share_run("my-run-id")

        mock_client.share_run.assert_called_once_with("my-run-id")
        assert result == "https://smith.langchain.com/public/abc123"

    def test_uses_latest_run_when_no_id(self):
        """Gets latest run ID when not provided."""
        mock_client = MagicMock()
        mock_client.share_run.return_value = "https://share.url"

        with patch("yamlgraph.utils.langsmith.get_client", return_value=mock_client):
            with patch(
                "yamlgraph.utils.langsmith.get_latest_run_id",
                return_value="latest-id",
            ):
                result = share_run()

        mock_client.share_run.assert_called_once_with("latest-id")
        assert result == "https://share.url"

    def test_returns_none_when_no_latest_run(self):
        """Returns None when no latest run found, without calling the API."""
        mock_client = MagicMock()

        with patch("yamlgraph.utils.langsmith.get_client", return_value=mock_client):
            with patch(
                "yamlgraph.utils.langsmith.get_latest_run_id",
                return_value=None,
            ):
                result = share_run()

        assert result is None
        # There is nothing to share, so the API must not be hit with a None id.
        mock_client.share_run.assert_not_called()

    def test_handles_exception_gracefully(self):
        """Returns None on error (logs warning to stderr)."""
        mock_client = MagicMock()
        mock_client.share_run.side_effect = Exception("API error")

        with patch("yamlgraph.utils.langsmith.get_client", return_value=mock_client):
            result = share_run("test-id")

        # Make sure the error path was really exercised, not short-circuited.
        mock_client.share_run.assert_called_once_with("test-id")
        assert result is None
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
# =============================================================================
|
|
221
|
+
# read_run_shared_link() tests
|
|
222
|
+
# =============================================================================
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
class TestReadRunSharedLink:
    """read_run_shared_link(): fetching an existing public link for a run."""

    def test_returns_none_when_no_client(self):
        """No client available means no link."""
        with patch("yamlgraph.utils.langsmith.get_client", return_value=None):
            link = read_run_shared_link("test-run-id")
        assert link is None

    def test_returns_existing_link(self):
        """The client's existing share URL is passed straight through."""
        client = MagicMock()
        client.read_run_shared_link.return_value = "https://existing.url"

        with patch("yamlgraph.utils.langsmith.get_client", return_value=client):
            link = read_run_shared_link("my-run-id")

        client.read_run_shared_link.assert_called_once_with("my-run-id")
        assert link == "https://existing.url"

    def test_returns_none_when_not_shared(self):
        """A client error (e.g. run never shared) is reported as None."""
        client = MagicMock()
        client.read_run_shared_link.side_effect = Exception("Not found")

        with patch("yamlgraph.utils.langsmith.get_client", return_value=client):
            link = read_run_shared_link("test-id")
        assert link is None
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# =============================================================================
|
|
256
|
+
# get_latest_run_id() tests
|
|
257
|
+
# =============================================================================
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
class TestGetLatestRunId:
    """get_latest_run_id(): looking up the newest run in a project."""

    _GET_CLIENT = "yamlgraph.utils.langsmith.get_client"

    def test_returns_none_when_no_client(self):
        """No client available means no run id."""
        with patch(self._GET_CLIENT, return_value=None):
            run_id = get_latest_run_id()
        assert run_id is None

    def test_returns_latest_run_id(self):
        """The id of the single run returned by list_runs is reported."""
        newest = MagicMock()
        newest.id = "abc-123"
        client = MagicMock()
        client.list_runs.return_value = [newest]

        with patch(self._GET_CLIENT, return_value=client), patch(
            "yamlgraph.utils.langsmith.get_project_name",
            return_value="test-project",
        ):
            run_id = get_latest_run_id()

        client.list_runs.assert_called_once_with(project_name="test-project", limit=1)
        assert run_id == "abc-123"

    def test_returns_none_when_no_runs(self):
        """An empty run listing yields None."""
        client = MagicMock()
        client.list_runs.return_value = []

        with patch(self._GET_CLIENT, return_value=client):
            run_id = get_latest_run_id()
        assert run_id is None

    def test_uses_provided_project_name(self):
        """An explicit project_name is forwarded to list_runs."""
        only_run = MagicMock()
        only_run.id = "run-id"
        client = MagicMock()
        client.list_runs.return_value = [only_run]

        with patch(self._GET_CLIENT, return_value=client):
            get_latest_run_id(project_name="custom-project")

        client.list_runs.assert_called_once_with(
            project_name="custom-project", limit=1
        )

    def test_handles_exception_gracefully(self):
        """A failing list_runs call is reported as None rather than raised."""
        client = MagicMock()
        client.list_runs.side_effect = Exception("API error")

        with patch(self._GET_CLIENT, return_value=client):
            run_id = get_latest_run_id()
        assert run_id is None
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Unit tests for LLM factory module."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from unittest.mock import patch
|
|
5
|
+
|
|
6
|
+
import pytest
|
|
7
|
+
from langchain_anthropic import ChatAnthropic
|
|
8
|
+
|
|
9
|
+
from yamlgraph.utils.llm_factory import clear_cache, create_llm
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class TestCreateLLM:
    """Test the create_llm factory function."""

    def setup_method(self):
        """Start each test with an empty LLM cache."""
        clear_cache()

    def teardown_method(self):
        """Empty the cache again after each test.

        Several tests build LLMs under patched environment variables (fake
        API keys); clearing here keeps those cached instances from being
        reused by tests in other modules.
        """
        clear_cache()

    def test_default_provider_is_anthropic(self):
        """Should use Anthropic by default."""
        # Blank out PROVIDER so the factory falls back to its default provider
        with patch.dict(os.environ, {"PROVIDER": ""}, clear=False):
            llm = create_llm(temperature=0.7)
            assert isinstance(llm, ChatAnthropic)
            assert llm.temperature == 0.7

    def test_explicit_anthropic_provider(self):
        """Should create Anthropic LLM when provider='anthropic'."""
        llm = create_llm(provider="anthropic", temperature=0.5)
        assert isinstance(llm, ChatAnthropic)
        assert llm.temperature == 0.5

    def test_mistral_provider(self):
        """Should create Mistral LLM when provider='mistral'."""
        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
            llm = create_llm(provider="mistral", temperature=0.8)
            # Compare by class name: the Mistral class is imported lazily
            assert llm.__class__.__name__ == "ChatMistralAI"
            assert llm.temperature == 0.8

    def test_openai_provider(self):
        """Should create OpenAI LLM when provider='openai'."""
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            llm = create_llm(provider="openai", temperature=0.6)
            assert llm.__class__.__name__ == "ChatOpenAI"
            assert llm.temperature == 0.6

    def test_provider_from_environment(self):
        """Should use PROVIDER env var when no provider specified."""
        with patch.dict(
            os.environ, {"PROVIDER": "mistral", "MISTRAL_API_KEY": "test-key"}
        ):
            llm = create_llm(temperature=0.7)
            assert llm.__class__.__name__ == "ChatMistralAI"

    def test_custom_model(self):
        """Should use custom model when specified."""
        with patch.dict(os.environ, {"PROVIDER": ""}, clear=False):
            llm = create_llm(model="claude-opus-4", temperature=0.5)
            assert isinstance(llm, ChatAnthropic)
            assert llm.model == "claude-opus-4"

    def test_model_override_parameter(self):
        """Should prefer model parameter over default."""
        llm = create_llm(provider="anthropic", model="claude-sonnet-4", temperature=0.7)
        assert llm.model == "claude-sonnet-4"

    def test_default_models(self):
        """Should use correct default models for each provider."""
        # Anthropic default
        llm_anthropic = create_llm(provider="anthropic", temperature=0.7)
        assert llm_anthropic.model == "claude-haiku-4-5"

        # Mistral default
        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
            llm_mistral = create_llm(provider="mistral", temperature=0.7)
            assert llm_mistral.model == "mistral-large-latest"

        # OpenAI default (uses model_name attribute)
        with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
            llm_openai = create_llm(provider="openai", temperature=0.7)
            assert llm_openai.model_name == "gpt-4o"

    def test_invalid_provider(self):
        """Should raise error for invalid provider."""
        with pytest.raises((ValueError, KeyError)):
            create_llm(provider="invalid-provider", temperature=0.7)

    def test_caching(self):
        """Should cache LLM instances for same parameters."""
        llm1 = create_llm(provider="anthropic", temperature=0.7)
        llm2 = create_llm(provider="anthropic", temperature=0.7)
        assert llm1 is llm2

        # Different temperature = different instance
        llm3 = create_llm(provider="anthropic", temperature=0.5)
        assert llm1 is not llm3

    def test_cache_key_includes_all_params(self):
        """Cache should differentiate on provider and model (temperature: see test_caching)."""
        llm1 = create_llm(
            provider="anthropic", model="claude-haiku-4-5", temperature=0.7
        )
        llm2 = create_llm(provider="anthropic", model="claude-opus-4", temperature=0.7)
        assert llm1 is not llm2

        with patch.dict(os.environ, {"MISTRAL_API_KEY": "test-key"}):
            llm3 = create_llm(provider="mistral", temperature=0.7)
            assert llm1 is not llm3
|