kolega-code 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kolega_code/__init__.py +151 -0
- kolega_code/agent/__init__.py +42 -0
- kolega_code/agent/baseagent.py +998 -0
- kolega_code/agent/browseragent.py +123 -0
- kolega_code/agent/coder.py +157 -0
- kolega_code/agent/common.py +41 -0
- kolega_code/agent/compression.py +81 -0
- kolega_code/agent/context.py +112 -0
- kolega_code/agent/conversation.py +408 -0
- kolega_code/agent/generalagent.py +146 -0
- kolega_code/agent/investigationagent.py +123 -0
- kolega_code/agent/planningagent.py +187 -0
- kolega_code/agent/prompt_provider.py +196 -0
- kolega_code/agent/prompt_templates/agents/browser.j2 +102 -0
- kolega_code/agent/prompt_templates/agents/coder_cli_mode.j2 +127 -0
- kolega_code/agent/prompt_templates/agents/general.j2 +68 -0
- kolega_code/agent/prompt_templates/agents/investigation.j2 +72 -0
- kolega_code/agent/prompt_templates/common/frontend_guidance.md +36 -0
- kolega_code/agent/prompt_templates/common/kolega_md_instructions.md +14 -0
- kolega_code/agent/prompt_templates/environment_variables/workspace_env_vars.md +11 -0
- kolega_code/agent/prompt_templates/template_guidance/expo-template.md +379 -0
- kolega_code/agent/prompt_templates/template_guidance/html-website-template.md +3 -0
- kolega_code/agent/prompt_templates/template_guidance/mern-stack-template.md +3 -0
- kolega_code/agent/prompt_templates/template_guidance/react-vite-shadcdn-template.md +182 -0
- kolega_code/agent/prompts.py +192 -0
- kolega_code/agent/tests/__init__.py +0 -0
- kolega_code/agent/tests/llm/__init__.py +0 -0
- kolega_code/agent/tests/llm/test_anthropic_token_counting.py +633 -0
- kolega_code/agent/tests/llm/test_billing_openai_cache.py +74 -0
- kolega_code/agent/tests/llm/test_client.py +773 -0
- kolega_code/agent/tests/llm/test_dashscope_mapping.py +32 -0
- kolega_code/agent/tests/llm/test_error_boundary.py +322 -0
- kolega_code/agent/tests/llm/test_exceptions.py +249 -0
- kolega_code/agent/tests/llm/test_instrumented_client.py +536 -0
- kolega_code/agent/tests/llm/test_instrumented_client_integration.py +547 -0
- kolega_code/agent/tests/llm/test_langfuse_normalization.py +39 -0
- kolega_code/agent/tests/llm/test_model_specs.py +17 -0
- kolega_code/agent/tests/llm/test_openai_cached_tokens.py +58 -0
- kolega_code/agent/tests/llm/test_openai_cached_tokens_stream.py +74 -0
- kolega_code/agent/tests/llm/test_openai_message_conversion.py +30 -0
- kolega_code/agent/tests/llm/test_openai_token_counting.py +687 -0
- kolega_code/agent/tests/llm/test_tool_execution_ids.py +193 -0
- kolega_code/agent/tests/services/__init__.py +1 -0
- kolega_code/agent/tests/services/test_browser.py +447 -0
- kolega_code/agent/tests/services/test_browser_parity.py +353 -0
- kolega_code/agent/tests/services/test_file_system.py +699 -0
- kolega_code/agent/tests/services/test_sandbox_terminal_input.py +98 -0
- kolega_code/agent/tests/services/test_terminal.py +154 -0
- kolega_code/agent/tests/services/test_terminal_command_tracking.py +385 -0
- kolega_code/agent/tests/services/test_terminal_state_serializer.py +262 -0
- kolega_code/agent/tests/test_agent_tools_inventory.py +267 -0
- kolega_code/agent/tests/test_base_agent.py +1942 -0
- kolega_code/agent/tests/test_coder_attachments.py +330 -0
- kolega_code/agent/tests/test_coder_prompt_extensions.py +61 -0
- kolega_code/agent/tests/test_commands.py +179 -0
- kolega_code/agent/tests/test_duplicate_tool_results.py +556 -0
- kolega_code/agent/tests/test_empty_message_handling.py +48 -0
- kolega_code/agent/tests/test_general_agent.py +242 -0
- kolega_code/agent/tests/test_html.py +320 -0
- kolega_code/agent/tests/test_parallel_tool_calls.py +291 -0
- kolega_code/agent/tests/test_planning_agent.py +227 -0
- kolega_code/agent/tests/test_prompt_provider.py +271 -0
- kolega_code/agent/tests/test_tool_registry.py +102 -0
- kolega_code/agent/tests/test_tools.py +549 -0
- kolega_code/agent/tests/tool_backend/__init__.py +0 -0
- kolega_code/agent/tests/tool_backend/test_agent_tool.py +356 -0
- kolega_code/agent/tests/tool_backend/test_base_tool.py +147 -0
- kolega_code/agent/tests/tool_backend/test_browser_tool.py +335 -0
- kolega_code/agent/tests/tool_backend/test_build_tool.py +93 -0
- kolega_code/agent/tests/tool_backend/test_create_file_tool.py +115 -0
- kolega_code/agent/tests/tool_backend/test_glob_tool.py +196 -0
- kolega_code/agent/tests/tool_backend/test_glob_tool_sandbox_parity.py +230 -0
- kolega_code/agent/tests/tool_backend/test_list_directory_tool.py +292 -0
- kolega_code/agent/tests/tool_backend/test_read_file_tool.py +173 -0
- kolega_code/agent/tests/tool_backend/test_replace_entire_file_tool.py +115 -0
- kolega_code/agent/tests/tool_backend/test_replace_lines_tool.py +141 -0
- kolega_code/agent/tests/tool_backend/test_search_and_replace_tool.py +174 -0
- kolega_code/agent/tests/tool_backend/test_search_codebase_tool.py +228 -0
- kolega_code/agent/tests/tool_backend/test_terminal_tool.py +482 -0
- kolega_code/agent/tests/tool_backend/test_think_hard_integration.py +189 -0
- kolega_code/agent/tests/tool_backend/test_think_hard_streaming.py +445 -0
- kolega_code/agent/tests/tool_backend/test_web_fetch_tool.py +194 -0
- kolega_code/agent/tool_backend/agent_tool.py +414 -0
- kolega_code/agent/tool_backend/apply_edit_tool.py +98 -0
- kolega_code/agent/tool_backend/apply_patch_tool.py +514 -0
- kolega_code/agent/tool_backend/base_tool.py +217 -0
- kolega_code/agent/tool_backend/browser_tool.py +271 -0
- kolega_code/agent/tool_backend/build_tool.py +93 -0
- kolega_code/agent/tool_backend/create_file_tool.py +52 -0
- kolega_code/agent/tool_backend/glob_tool.py +323 -0
- kolega_code/agent/tool_backend/list_directory_tool.py +300 -0
- kolega_code/agent/tool_backend/memory_tool.py +79 -0
- kolega_code/agent/tool_backend/read_file_tool.py +119 -0
- kolega_code/agent/tool_backend/replace_entire_file_tool.py +40 -0
- kolega_code/agent/tool_backend/replace_lines_tool.py +97 -0
- kolega_code/agent/tool_backend/search_and_replace_tool.py +146 -0
- kolega_code/agent/tool_backend/search_codebase_tool.py +377 -0
- kolega_code/agent/tool_backend/streaming_tool.py +47 -0
- kolega_code/agent/tool_backend/terminal_tool.py +643 -0
- kolega_code/agent/tool_backend/think_hard_tool.py +211 -0
- kolega_code/agent/tool_backend/web_fetch_tool.py +205 -0
- kolega_code/agent/tools.py +1704 -0
- kolega_code/agent/utils/commands.py +94 -0
- kolega_code/cli/__init__.py +1 -0
- kolega_code/cli/app.py +2756 -0
- kolega_code/cli/config.py +280 -0
- kolega_code/cli/connection.py +49 -0
- kolega_code/cli/file_index.py +147 -0
- kolega_code/cli/main.py +564 -0
- kolega_code/cli/mentions.py +155 -0
- kolega_code/cli/messages.py +89 -0
- kolega_code/cli/provider_registry.py +96 -0
- kolega_code/cli/session_store.py +207 -0
- kolega_code/cli/settings.py +87 -0
- kolega_code/cli/skills.py +409 -0
- kolega_code/cli/slash_commands.py +108 -0
- kolega_code/cli/tests/__init__.py +1 -0
- kolega_code/cli/tests/test_app.py +4251 -0
- kolega_code/cli/tests/test_cli_config.py +171 -0
- kolega_code/cli/tests/test_connection.py +26 -0
- kolega_code/cli/tests/test_file_index.py +103 -0
- kolega_code/cli/tests/test_main.py +455 -0
- kolega_code/cli/tests/test_mentions.py +108 -0
- kolega_code/cli/tests/test_session_store.py +67 -0
- kolega_code/cli/tests/test_settings.py +62 -0
- kolega_code/cli/tests/test_skills.py +157 -0
- kolega_code/cli/tests/test_slash_commands.py +88 -0
- kolega_code/cli/theme.py +180 -0
- kolega_code/config.py +154 -0
- kolega_code/events.py +202 -0
- kolega_code/llm/client.py +300 -0
- kolega_code/llm/exceptions.py +285 -0
- kolega_code/llm/instrumented_client.py +520 -0
- kolega_code/llm/models.py +1368 -0
- kolega_code/llm/providers/__init__.py +0 -0
- kolega_code/llm/providers/anthropic.py +387 -0
- kolega_code/llm/providers/base.py +71 -0
- kolega_code/llm/providers/google.py +157 -0
- kolega_code/llm/providers/models.py +37 -0
- kolega_code/llm/providers/openai.py +363 -0
- kolega_code/llm/ratelimit.py +40 -0
- kolega_code/llm/specs.py +67 -0
- kolega_code/llm/tool_execution_ids.py +18 -0
- kolega_code/models/__init__.py +9 -0
- kolega_code/models/sandbox_terminal_state.py +47 -0
- kolega_code/runtime.py +50 -0
- kolega_code/sandbox/README.md +200 -0
- kolega_code/sandbox/__init__.py +21 -0
- kolega_code/sandbox/async_filesystem.py +475 -0
- kolega_code/sandbox/base.py +297 -0
- kolega_code/sandbox/browser.py +25 -0
- kolega_code/sandbox/event_loop.py +43 -0
- kolega_code/sandbox/filesystem.py +341 -0
- kolega_code/sandbox/local.py +118 -0
- kolega_code/sandbox/serializer.py +175 -0
- kolega_code/sandbox/terminal.py +868 -0
- kolega_code/sandbox/utils.py +216 -0
- kolega_code/services/base.py +255 -0
- kolega_code/services/browser.py +444 -0
- kolega_code/services/file_system.py +749 -0
- kolega_code/services/html.py +221 -0
- kolega_code/services/terminal.py +903 -0
- kolega_code/tools/__init__.py +22 -0
- kolega_code/tools/core.py +33 -0
- kolega_code/tools/definitions.py +81 -0
- kolega_code/tools/registry.py +73 -0
- kolega_code-0.1.0.dist-info/METADATA +157 -0
- kolega_code-0.1.0.dist-info/RECORD +171 -0
- kolega_code-0.1.0.dist-info/WHEEL +4 -0
- kolega_code-0.1.0.dist-info/entry_points.txt +2 -0
- kolega_code-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Tests for the InstrumentedLLMClient class.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import pytest
|
|
7
|
+
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
|
8
|
+
|
|
9
|
+
from kolega_code.llm.models import Message, MessageHistory, TextBlock
|
|
10
|
+
from kolega_code.llm.instrumented_client import (
|
|
11
|
+
InstrumentedLLMClient,
|
|
12
|
+
MinimalLangfuseStreamWrapper,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Detect a CI run: either the generic ``CI`` flag or GitLab's ``GITLAB_CI``.
# Any non-empty value counts as "set"; the result is a plain bool.
SKIP_IN_CI = any(os.getenv(env_name) for env_name in ("CI", "GITLAB_CI"))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TestInstrumentedLLMClient:
    """Test the InstrumentedLLMClient class."""

    @staticmethod
    def _hello_messages():
        """Return a single-turn user history ("Hello") shared by the generate/stream tests."""
        return MessageHistory([Message(role="user", content=[TextBlock(text="Hello")])])

    @staticmethod
    def _anthropic_response():
        """Return a mock generate() result with a dict payload and Anthropic-style usage metadata."""
        response = Mock()
        response.to_dict = Mock(return_value={"content": "test response"})
        response.usage_metadata = {
            "provider": "anthropic",
            "input_tokens": 10,
            "output_tokens": 5,
        }
        return response

    @pytest.fixture
    def mock_langfuse(self):
        """Create mock Langfuse client with generation tracking (v3 API).

        Returns a ``(langfuse, generation)`` tuple: ``langfuse.start_span()``
        returns a trace mock whose ``start_generation()`` returns ``generation``.
        """
        langfuse = MagicMock()

        # Create a mock generation that tracks calls
        generation = MagicMock()
        generation.update = MagicMock()
        generation.end = MagicMock()

        # Create a mock trace/span that returns the generation
        trace = MagicMock()
        trace.update_trace = MagicMock()
        trace.update = MagicMock()
        trace.end = MagicMock()
        trace.start_generation = MagicMock(return_value=generation)

        # Make langfuse.start_span() return the trace
        langfuse.start_span = MagicMock(return_value=trace)

        return langfuse, generation

    @pytest.fixture
    def instrumented_client(self, mock_langfuse):
        """Create an instrumented LLM client with mocked Langfuse."""
        langfuse, _ = mock_langfuse
        return InstrumentedLLMClient(
            provider="anthropic",
            api_key="test-key",
            langfuse_client=langfuse,
            workspace_id="workspace-123",
            thread_id="thread-456",
            agent_type="test-agent",
            environment="test",
            user_id="user-789",
            user_email="test@example.com",
        )

    def test_init_with_langfuse(self, mock_langfuse):
        """Test initialization with Langfuse client."""
        langfuse, _ = mock_langfuse
        client = InstrumentedLLMClient(
            provider="anthropic",
            api_key="test-key",
            langfuse_client=langfuse,
            workspace_id="workspace-123",
            thread_id="thread-456",
            agent_type="test-agent",
            environment="production",
            user_id="user-789",
            user_email="test@example.com",
        )

        assert client.langfuse == langfuse
        assert client.workspace_id == "workspace-123"
        assert client.thread_id == "thread-456"
        assert client.agent_type == "test-agent"
        assert client.environment == "production"
        assert client.user_id == "user-789"
        assert client.user_email == "test@example.com"

    def test_init_without_langfuse(self):
        """Test initialization without Langfuse client.

        The process environment is cleared while the client is constructed so
        the assertion on the default ``environment`` does not depend on what
        the developer's shell or the CI runner happens to export.
        """
        with patch.dict(os.environ, clear=True):
            client = InstrumentedLLMClient(
                provider="anthropic",
                api_key="test-key",
            )

        assert client.langfuse is None
        assert client.workspace_id is None
        assert client.thread_id is None
        assert client.agent_type is None
        # Fallback used when no environment override is set in os.environ.
        assert client.environment == "development"

    def test_create_generation_metadata(self, instrumented_client):
        """Test metadata creation for Langfuse generation."""
        metadata = instrumented_client._create_generation_metadata(
            custom_field="value",
            another_field=123,
        )

        assert metadata["provider"] == "anthropic"
        assert metadata["workspace_id"] == "workspace-123"
        assert metadata["thread_id"] == "thread-456"
        assert metadata["agent_type"] == "test-agent"
        assert metadata["environment"] == "test"
        assert metadata["user_id"] == "user-789"
        assert metadata["user_email"] == "test@example.com"
        assert metadata["custom_field"] == "value"
        assert metadata["another_field"] == 123
        assert "timestamp" in metadata

    def test_extract_usage_details_anthropic(self, instrumented_client):
        """Test extraction of usage details from Anthropic response."""
        # Mock Message with usage_metadata
        mock_response = Mock()
        mock_response.usage_metadata = {
            "provider": "anthropic",
            "input_tokens": 100,
            "output_tokens": 50,
            "cache_read_input_tokens": 25,
            "cache_creation_input_tokens": 5,
            "cache_write_input_tokens": 5,  # Add this field that the code expects
        }

        usage = instrumented_client._extract_usage_details(mock_response)

        assert usage["input_tokens"] == 100
        assert usage["output_tokens"] == 50
        assert usage["cache_read_input_tokens"] == 25
        assert usage["cache_write_input_tokens"] == 5

    def test_extract_usage_details_openai(self, instrumented_client):
        """Test extraction of usage details from OpenAI response."""
        # Mock Message with usage_metadata
        mock_response = Mock()
        mock_response.usage_metadata = {
            "provider": "openai",
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
        }

        usage = instrumented_client._extract_usage_details(mock_response)

        assert usage["prompt_tokens"] == 100
        assert usage["completion_tokens"] == 50
        assert usage["total_tokens"] == 150

    def test_extract_usage_details_google(self, instrumented_client):
        """Test extraction of usage details from Google response."""
        # Mock Message with usage_metadata
        mock_response = Mock()
        mock_response.usage_metadata = {
            "provider": "google",
            "prompt_token_count": 100,
            "candidates_token_count": 50,
            "total_token_count": 150,
        }

        usage = instrumented_client._extract_usage_details(mock_response)

        assert usage["prompt_token_count"] == 100
        assert usage["candidates_token_count"] == 50
        assert usage["total_token_count"] == 150

    def test_normalize_usage_data_deepseek(self, instrumented_client):
        """Test DeepSeek usage data keeps its token counts and is tagged with provider and model."""
        usage = instrumented_client._normalize_usage_data(
            {
                "provider": "deepseek",
                "input_tokens": 100,
                "output_tokens": 50,
                "cache_read_input_tokens": 25,
                "cache_write_input_tokens": 5,
            },
            model="deepseek-v4-pro",
        )

        assert usage["provider"] == "deepseek"
        assert usage["model"] == "deepseek-v4-pro"
        assert usage["input_tokens"] == 100
        assert usage["output_tokens"] == 50
        assert usage["cache_read_input_tokens"] == 25
        assert usage["cache_write_input_tokens"] == 5

    def test_extract_usage_details_no_metadata(self, instrumented_client):
        """Test extraction returns empty dict when no metadata."""
        # Mock Message without usage_metadata
        mock_response = Mock(spec=[])

        usage = instrumented_client._extract_usage_details(mock_response)
        assert usage == {}

        # Mock Message with empty usage_metadata
        mock_response = Mock()
        mock_response.usage_metadata = {}

        usage = instrumented_client._extract_usage_details(mock_response)
        assert usage == {}

    @pytest.mark.asyncio
    async def test_generate_with_langfuse(self, instrumented_client, mock_langfuse):
        """Test generate method with Langfuse tracing."""
        langfuse, generation = mock_langfuse
        trace = langfuse.start_span.return_value

        # Use patch on the parent class method
        with patch(
            "kolega_code.llm.client.LLMClient.generate",
            AsyncMock(return_value=self._anthropic_response()),
        ):
            await instrumented_client.generate(
                messages=self._hello_messages(),
                model="claude-3-opus",
                temperature=0.5,
                max_completion_tokens=100,
            )

        # Verify Langfuse trace was called correctly
        langfuse.start_span.assert_called_once()
        span_args = langfuse.start_span.call_args

        assert span_args.kwargs["name"] == "test-agent-llm-call"

        # Verify trace attributes were updated
        trace.update_trace.assert_called_once()
        trace_update_args = trace.update_trace.call_args
        assert trace_update_args.kwargs["user_id"] == "user-789"  # Now uses actual user_id
        assert trace_update_args.kwargs["session_id"] == "workspace-123/thread-456"
        assert "test" in trace_update_args.kwargs["tags"]
        assert "user:user-789" in trace_update_args.kwargs["tags"]

        # Verify generation was called on the trace
        trace.start_generation.assert_called_once()
        gen_args = trace.start_generation.call_args
        assert gen_args.kwargs["name"] == "test-agent-llm-generation"
        assert gen_args.kwargs["model"] == "claude-3-opus"
        assert gen_args.kwargs["model_parameters"]["temperature"] == 0.5

        # Verify generation.update and end were called with success
        generation.update.assert_called_once()
        update_args = generation.update.call_args
        assert update_args.kwargs["level"] == "DEFAULT"
        assert update_args.kwargs["status_message"] == "Success"
        assert update_args.kwargs["usage_details"]["input"] == 10
        assert update_args.kwargs["usage_details"]["output"] == 5
        generation.end.assert_called_once()
        trace.end.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_with_user_tracking(self):
        """Test generate method with user tracking information."""
        mock_langfuse = MagicMock()
        generation = MagicMock()
        trace = MagicMock()
        trace.start_generation = MagicMock(return_value=generation)
        mock_langfuse.start_span = MagicMock(return_value=trace)

        # Create client with user information
        client = InstrumentedLLMClient(
            provider="anthropic",
            api_key="test-key",
            langfuse_client=mock_langfuse,
            workspace_id="workspace-123",
            thread_id="thread-456",
            agent_type="test-agent",
            environment="test",
            user_id="user-789",
            user_email="test@example.com",
        )

        # Use patch on the parent class method
        with patch(
            "kolega_code.llm.client.LLMClient.generate",
            AsyncMock(return_value=self._anthropic_response()),
        ):
            await client.generate(
                messages=self._hello_messages(),
                model="claude-3-opus",
                temperature=0.5,
                max_completion_tokens=100,
            )

        # Verify trace attributes include user information
        trace.update_trace.assert_called_once()
        trace_update_args = trace.update_trace.call_args
        assert trace_update_args.kwargs["user_id"] == "user-789"  # Uses actual user_id, not workspace
        assert trace_update_args.kwargs["session_id"] == "workspace-123/thread-456"  # No email in session name
        assert "user:user-789" in trace_update_args.kwargs["tags"]

        # Verify metadata includes user information
        span_args = mock_langfuse.start_span.call_args
        metadata = span_args.kwargs["metadata"]
        assert metadata["user_id"] == "user-789"
        assert metadata["user_email"] == "test@example.com"

    @pytest.mark.asyncio
    async def test_generate_without_langfuse(self):
        """Test generate falls back to parent when no Langfuse."""
        client = InstrumentedLLMClient(
            provider="anthropic",
            api_key="test-key",
            langfuse_client=None,
        )

        # Mock parent generate
        mock_response = Mock()
        with patch("kolega_code.llm.client.LLMClient.generate", AsyncMock(return_value=mock_response)):
            result = await client.generate(messages=self._hello_messages())
            assert result == mock_response

    @pytest.mark.asyncio
    async def test_error_handling(self, instrumented_client, mock_langfuse):
        """Test error handling in generate method."""
        langfuse, generation = mock_langfuse
        trace = langfuse.start_span.return_value

        error_msg = "API Error"
        with patch("kolega_code.llm.client.LLMClient.generate", AsyncMock(side_effect=Exception(error_msg))):
            with pytest.raises(Exception) as exc_info:
                await instrumented_client.generate(messages=self._hello_messages())

            assert str(exc_info.value) == error_msg

        # Verify generation.update was called with error
        generation.update.assert_called_once()
        update_args = generation.update.call_args
        assert update_args.kwargs["level"] == "ERROR"
        assert update_args.kwargs["status_message"] == error_msg

        # Verify generation.end and trace.end were still called
        generation.end.assert_called_once()
        trace.update.assert_called_once()
        trace.end.assert_called_once()

    @pytest.mark.asyncio
    async def test_stream_with_langfuse(self, instrumented_client, mock_langfuse):
        """Test stream method with Langfuse tracing."""
        langfuse, generation = mock_langfuse
        trace = langfuse.start_span.return_value

        # Mock stream context manager
        mock_stream = AsyncMock()
        mock_stream.__aenter__ = AsyncMock(return_value=mock_stream)
        mock_stream.__aexit__ = AsyncMock(return_value=None)

        with patch("kolega_code.llm.client.LLMClient.stream", MagicMock(return_value=mock_stream)):
            # stream() returns a coroutine; Langfuse tracing is created when it is awaited
            result_coro = instrumented_client.stream(messages=self._hello_messages(), model="claude-3-opus")
            result = await result_coro

            # Verify Langfuse trace was called
            langfuse.start_span.assert_called_once()
            span_args = langfuse.start_span.call_args
            assert span_args.kwargs["name"] == "test-agent-llm-stream"

            # Verify trace attributes were updated
            trace.update_trace.assert_called_once()
            trace_update_args = trace.update_trace.call_args
            assert "streaming" in trace_update_args.kwargs["tags"]

            # Verify generation was called on the trace
            trace.start_generation.assert_called_once()
            gen_args = trace.start_generation.call_args
            assert gen_args.kwargs["name"] == "test-agent-llm-stream-generation"
            assert gen_args.kwargs["model"] == "claude-3-opus"

            # Should return an instrumented wrapper
            assert isinstance(result, MinimalLangfuseStreamWrapper)
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
class TestStreamWrappers:
    """Test the instrumented stream wrapper classes."""

    @pytest.mark.asyncio
    async def test_stream_wrapper_accumulates_content(self):
        """Test that stream wrapper accumulates content for Langfuse."""
        mock_stream = AsyncMock()
        mock_generation = MagicMock()
        mock_trace = MagicMock()
        mock_client = MagicMock()
        mock_client._record_usage = AsyncMock()
        model = "claude-3-opus"

        wrapper = MinimalLangfuseStreamWrapper(mock_stream, mock_generation, mock_trace, mock_client, model)

        # Create mock chunks with get_text_content method
        chunks = []
        for text in ["Hello", " ", "world"]:
            chunk = MagicMock()
            chunk.get_text_content.return_value = text
            chunks.append(chunk)

        mock_stream.__anext__ = AsyncMock(side_effect=chunks + [StopAsyncIteration])

        # Create mock final message
        final_message = MagicMock()
        final_message.get_text_content.return_value = "Hello world"
        final_message.usage_metadata = {"provider": "anthropic", "input_tokens": 10, "output_tokens": 2}
        mock_stream.get_final_message = AsyncMock(return_value=final_message)

        # Consume stream
        collected = []
        async with wrapper:
            async for chunk in wrapper:
                collected.append(chunk)

        assert len(collected) == 3

        # Verify generation was updated with final data
        mock_generation.update.assert_called_once()
        update_call = mock_generation.update.call_args
        assert update_call.kwargs["output"] == "Hello world"
        assert update_call.kwargs["usage_details"] == {
            "input": 10,
            "output": 2,
            "total": 12,
            "cache_read_input_tokens": 0,
            "cache_creation_input_tokens": 0,
        }
        mock_generation.end.assert_called_once()
        mock_trace.end.assert_called_once()

        mock_client._record_usage.assert_awaited_once()
        usage_call = mock_client._record_usage.call_args
        assert usage_call.args[0] == {"provider": "anthropic", "input_tokens": 10, "output_tokens": 2}
        assert usage_call.args[1] == model

    @pytest.mark.asyncio
    @pytest.mark.skipif(SKIP_IN_CI, reason="Skipping slow test in CI environment")
    async def test_stream_wrapper_provider_selection(self):
        """Test correct stream wrapper is selected based on provider.

        Drives ``InstrumentedLLMClient.stream()`` for every provider (with the
        parent ``LLMClient.stream`` patched out) and checks the type of the
        wrapper it returns. Building the wrapper by hand would make the
        isinstance check a tautology, so the client itself must choose it.
        """
        test_cases = [
            ("anthropic", MinimalLangfuseStreamWrapper),
            ("openai", MinimalLangfuseStreamWrapper),
            ("groq", MinimalLangfuseStreamWrapper),
            ("google", MinimalLangfuseStreamWrapper),
        ]

        for provider, expected_wrapper in test_cases:
            # v3 Langfuse API: start_span() -> trace, trace.start_generation() -> generation
            trace = MagicMock()
            trace.start_generation = MagicMock(return_value=MagicMock())
            mock_langfuse = MagicMock()
            mock_langfuse.start_span = MagicMock(return_value=trace)

            # Create instrumented client with specific provider
            instrumented_client = InstrumentedLLMClient(
                provider=provider,
                api_key="test-key",
                langfuse_client=mock_langfuse,
                workspace_id="workspace-123",
                thread_id="thread-456",
                agent_type="test-agent",
                environment="test",
            )

            mock_stream = AsyncMock()
            with patch("kolega_code.llm.client.LLMClient.stream", MagicMock(return_value=mock_stream)):
                messages = MessageHistory([Message(role="user", content=[TextBlock(text="Hello")])])
                wrapper = await instrumented_client.stream(messages=messages, model="test-model")

            assert isinstance(wrapper, expected_wrapper), f"unexpected wrapper for provider {provider!r}"

    @pytest.mark.asyncio
    async def test_stream_wrapper_handles_exceptions(self):
        """Test stream wrapper handles exceptions gracefully."""
        mock_stream = AsyncMock()
        mock_generation = MagicMock()
        mock_trace = MagicMock()
        mock_client = MagicMock()
        mock_client._record_usage = AsyncMock()
        model = "test-model"

        wrapper = MinimalLangfuseStreamWrapper(mock_stream, mock_generation, mock_trace, mock_client, model)

        # Mock stream that raises exception
        mock_stream.__anext__ = AsyncMock(side_effect=Exception("Stream error"))
        mock_stream.get_final_message = AsyncMock(side_effect=Exception("No final message"))

        # Should propagate exception but still end generation
        with pytest.raises(Exception, match="Stream error"):
            async with wrapper:
                async for _ in wrapper:
                    pass

        # Generation should still be ended
        mock_generation.end.assert_called_once()

    @pytest.mark.asyncio
    async def test_stream_wrapper_without_usage_data(self):
        """Test stream wrapper handles missing usage data gracefully."""
        mock_stream = AsyncMock()
        mock_generation = MagicMock()
        mock_trace = MagicMock()
        mock_client = MagicMock()
        mock_client._record_usage = AsyncMock()
        model = "test-model"

        wrapper = MinimalLangfuseStreamWrapper(mock_stream, mock_generation, mock_trace, mock_client, model)

        # Mock final message without usage metadata
        final_message = MagicMock()
        final_message.get_text_content.return_value = "Response"
        final_message.usage_metadata = {}
        mock_stream.get_final_message = AsyncMock(return_value=final_message)
        mock_stream.__anext__ = AsyncMock(side_effect=StopAsyncIteration)

        async with wrapper:
            pass

        # Should update without usage data
        mock_generation.update.assert_called_once()
        update_call = mock_generation.update.call_args
        assert update_call.kwargs["output"] == "Response"
        # usage_details key should not be present when no usage data
        assert "usage_details" not in update_call.kwargs
        mock_generation.end.assert_called_once()
        mock_trace.end.assert_called_once()
|