kolega-code 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kolega_code/__init__.py +151 -0
- kolega_code/agent/__init__.py +42 -0
- kolega_code/agent/baseagent.py +998 -0
- kolega_code/agent/browseragent.py +123 -0
- kolega_code/agent/coder.py +157 -0
- kolega_code/agent/common.py +41 -0
- kolega_code/agent/compression.py +81 -0
- kolega_code/agent/context.py +112 -0
- kolega_code/agent/conversation.py +408 -0
- kolega_code/agent/generalagent.py +146 -0
- kolega_code/agent/investigationagent.py +123 -0
- kolega_code/agent/planningagent.py +187 -0
- kolega_code/agent/prompt_provider.py +196 -0
- kolega_code/agent/prompt_templates/agents/browser.j2 +102 -0
- kolega_code/agent/prompt_templates/agents/coder_cli_mode.j2 +127 -0
- kolega_code/agent/prompt_templates/agents/general.j2 +68 -0
- kolega_code/agent/prompt_templates/agents/investigation.j2 +72 -0
- kolega_code/agent/prompt_templates/common/frontend_guidance.md +36 -0
- kolega_code/agent/prompt_templates/common/kolega_md_instructions.md +14 -0
- kolega_code/agent/prompt_templates/environment_variables/workspace_env_vars.md +11 -0
- kolega_code/agent/prompt_templates/template_guidance/expo-template.md +379 -0
- kolega_code/agent/prompt_templates/template_guidance/html-website-template.md +3 -0
- kolega_code/agent/prompt_templates/template_guidance/mern-stack-template.md +3 -0
- kolega_code/agent/prompt_templates/template_guidance/react-vite-shadcdn-template.md +182 -0
- kolega_code/agent/prompts.py +192 -0
- kolega_code/agent/tests/__init__.py +0 -0
- kolega_code/agent/tests/llm/__init__.py +0 -0
- kolega_code/agent/tests/llm/test_anthropic_token_counting.py +633 -0
- kolega_code/agent/tests/llm/test_billing_openai_cache.py +74 -0
- kolega_code/agent/tests/llm/test_client.py +773 -0
- kolega_code/agent/tests/llm/test_dashscope_mapping.py +32 -0
- kolega_code/agent/tests/llm/test_error_boundary.py +322 -0
- kolega_code/agent/tests/llm/test_exceptions.py +249 -0
- kolega_code/agent/tests/llm/test_instrumented_client.py +536 -0
- kolega_code/agent/tests/llm/test_instrumented_client_integration.py +547 -0
- kolega_code/agent/tests/llm/test_langfuse_normalization.py +39 -0
- kolega_code/agent/tests/llm/test_model_specs.py +17 -0
- kolega_code/agent/tests/llm/test_openai_cached_tokens.py +58 -0
- kolega_code/agent/tests/llm/test_openai_cached_tokens_stream.py +74 -0
- kolega_code/agent/tests/llm/test_openai_message_conversion.py +30 -0
- kolega_code/agent/tests/llm/test_openai_token_counting.py +687 -0
- kolega_code/agent/tests/llm/test_tool_execution_ids.py +193 -0
- kolega_code/agent/tests/services/__init__.py +1 -0
- kolega_code/agent/tests/services/test_browser.py +447 -0
- kolega_code/agent/tests/services/test_browser_parity.py +353 -0
- kolega_code/agent/tests/services/test_file_system.py +699 -0
- kolega_code/agent/tests/services/test_sandbox_terminal_input.py +98 -0
- kolega_code/agent/tests/services/test_terminal.py +154 -0
- kolega_code/agent/tests/services/test_terminal_command_tracking.py +385 -0
- kolega_code/agent/tests/services/test_terminal_state_serializer.py +262 -0
- kolega_code/agent/tests/test_agent_tools_inventory.py +267 -0
- kolega_code/agent/tests/test_base_agent.py +1942 -0
- kolega_code/agent/tests/test_coder_attachments.py +330 -0
- kolega_code/agent/tests/test_coder_prompt_extensions.py +61 -0
- kolega_code/agent/tests/test_commands.py +179 -0
- kolega_code/agent/tests/test_duplicate_tool_results.py +556 -0
- kolega_code/agent/tests/test_empty_message_handling.py +48 -0
- kolega_code/agent/tests/test_general_agent.py +242 -0
- kolega_code/agent/tests/test_html.py +320 -0
- kolega_code/agent/tests/test_parallel_tool_calls.py +291 -0
- kolega_code/agent/tests/test_planning_agent.py +227 -0
- kolega_code/agent/tests/test_prompt_provider.py +271 -0
- kolega_code/agent/tests/test_tool_registry.py +102 -0
- kolega_code/agent/tests/test_tools.py +549 -0
- kolega_code/agent/tests/tool_backend/__init__.py +0 -0
- kolega_code/agent/tests/tool_backend/test_agent_tool.py +356 -0
- kolega_code/agent/tests/tool_backend/test_base_tool.py +147 -0
- kolega_code/agent/tests/tool_backend/test_browser_tool.py +335 -0
- kolega_code/agent/tests/tool_backend/test_build_tool.py +93 -0
- kolega_code/agent/tests/tool_backend/test_create_file_tool.py +115 -0
- kolega_code/agent/tests/tool_backend/test_glob_tool.py +196 -0
- kolega_code/agent/tests/tool_backend/test_glob_tool_sandbox_parity.py +230 -0
- kolega_code/agent/tests/tool_backend/test_list_directory_tool.py +292 -0
- kolega_code/agent/tests/tool_backend/test_read_file_tool.py +173 -0
- kolega_code/agent/tests/tool_backend/test_replace_entire_file_tool.py +115 -0
- kolega_code/agent/tests/tool_backend/test_replace_lines_tool.py +141 -0
- kolega_code/agent/tests/tool_backend/test_search_and_replace_tool.py +174 -0
- kolega_code/agent/tests/tool_backend/test_search_codebase_tool.py +228 -0
- kolega_code/agent/tests/tool_backend/test_terminal_tool.py +482 -0
- kolega_code/agent/tests/tool_backend/test_think_hard_integration.py +189 -0
- kolega_code/agent/tests/tool_backend/test_think_hard_streaming.py +445 -0
- kolega_code/agent/tests/tool_backend/test_web_fetch_tool.py +194 -0
- kolega_code/agent/tool_backend/agent_tool.py +414 -0
- kolega_code/agent/tool_backend/apply_edit_tool.py +98 -0
- kolega_code/agent/tool_backend/apply_patch_tool.py +514 -0
- kolega_code/agent/tool_backend/base_tool.py +217 -0
- kolega_code/agent/tool_backend/browser_tool.py +271 -0
- kolega_code/agent/tool_backend/build_tool.py +93 -0
- kolega_code/agent/tool_backend/create_file_tool.py +52 -0
- kolega_code/agent/tool_backend/glob_tool.py +323 -0
- kolega_code/agent/tool_backend/list_directory_tool.py +300 -0
- kolega_code/agent/tool_backend/memory_tool.py +79 -0
- kolega_code/agent/tool_backend/read_file_tool.py +119 -0
- kolega_code/agent/tool_backend/replace_entire_file_tool.py +40 -0
- kolega_code/agent/tool_backend/replace_lines_tool.py +97 -0
- kolega_code/agent/tool_backend/search_and_replace_tool.py +146 -0
- kolega_code/agent/tool_backend/search_codebase_tool.py +377 -0
- kolega_code/agent/tool_backend/streaming_tool.py +47 -0
- kolega_code/agent/tool_backend/terminal_tool.py +643 -0
- kolega_code/agent/tool_backend/think_hard_tool.py +211 -0
- kolega_code/agent/tool_backend/web_fetch_tool.py +205 -0
- kolega_code/agent/tools.py +1704 -0
- kolega_code/agent/utils/commands.py +94 -0
- kolega_code/cli/__init__.py +1 -0
- kolega_code/cli/app.py +2756 -0
- kolega_code/cli/config.py +280 -0
- kolega_code/cli/connection.py +49 -0
- kolega_code/cli/file_index.py +147 -0
- kolega_code/cli/main.py +564 -0
- kolega_code/cli/mentions.py +155 -0
- kolega_code/cli/messages.py +89 -0
- kolega_code/cli/provider_registry.py +96 -0
- kolega_code/cli/session_store.py +207 -0
- kolega_code/cli/settings.py +87 -0
- kolega_code/cli/skills.py +409 -0
- kolega_code/cli/slash_commands.py +108 -0
- kolega_code/cli/tests/__init__.py +1 -0
- kolega_code/cli/tests/test_app.py +4251 -0
- kolega_code/cli/tests/test_cli_config.py +171 -0
- kolega_code/cli/tests/test_connection.py +26 -0
- kolega_code/cli/tests/test_file_index.py +103 -0
- kolega_code/cli/tests/test_main.py +455 -0
- kolega_code/cli/tests/test_mentions.py +108 -0
- kolega_code/cli/tests/test_session_store.py +67 -0
- kolega_code/cli/tests/test_settings.py +62 -0
- kolega_code/cli/tests/test_skills.py +157 -0
- kolega_code/cli/tests/test_slash_commands.py +88 -0
- kolega_code/cli/theme.py +180 -0
- kolega_code/config.py +154 -0
- kolega_code/events.py +202 -0
- kolega_code/llm/client.py +300 -0
- kolega_code/llm/exceptions.py +285 -0
- kolega_code/llm/instrumented_client.py +520 -0
- kolega_code/llm/models.py +1368 -0
- kolega_code/llm/providers/__init__.py +0 -0
- kolega_code/llm/providers/anthropic.py +387 -0
- kolega_code/llm/providers/base.py +71 -0
- kolega_code/llm/providers/google.py +157 -0
- kolega_code/llm/providers/models.py +37 -0
- kolega_code/llm/providers/openai.py +363 -0
- kolega_code/llm/ratelimit.py +40 -0
- kolega_code/llm/specs.py +67 -0
- kolega_code/llm/tool_execution_ids.py +18 -0
- kolega_code/models/__init__.py +9 -0
- kolega_code/models/sandbox_terminal_state.py +47 -0
- kolega_code/runtime.py +50 -0
- kolega_code/sandbox/README.md +200 -0
- kolega_code/sandbox/__init__.py +21 -0
- kolega_code/sandbox/async_filesystem.py +475 -0
- kolega_code/sandbox/base.py +297 -0
- kolega_code/sandbox/browser.py +25 -0
- kolega_code/sandbox/event_loop.py +43 -0
- kolega_code/sandbox/filesystem.py +341 -0
- kolega_code/sandbox/local.py +118 -0
- kolega_code/sandbox/serializer.py +175 -0
- kolega_code/sandbox/terminal.py +868 -0
- kolega_code/sandbox/utils.py +216 -0
- kolega_code/services/base.py +255 -0
- kolega_code/services/browser.py +444 -0
- kolega_code/services/file_system.py +749 -0
- kolega_code/services/html.py +221 -0
- kolega_code/services/terminal.py +903 -0
- kolega_code/tools/__init__.py +22 -0
- kolega_code/tools/core.py +33 -0
- kolega_code/tools/definitions.py +81 -0
- kolega_code/tools/registry.py +73 -0
- kolega_code-0.1.0.dist-info/METADATA +157 -0
- kolega_code-0.1.0.dist-info/RECORD +171 -0
- kolega_code-0.1.0.dist-info/WHEEL +4 -0
- kolega_code-0.1.0.dist-info/entry_points.txt +2 -0
- kolega_code-0.1.0.dist-info/licenses/LICENSE +21 -0
|
File without changes
|
|
@@ -0,0 +1,387 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any, AsyncContextManager, Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
import tiktoken
|
|
6
|
+
from anthropic import Anthropic, AsyncAnthropic
|
|
7
|
+
|
|
8
|
+
from ..models import Message, MessageChunk, MessageHistory, ToolDefinition
|
|
9
|
+
from ..specs import get_model_specs
|
|
10
|
+
from ..tool_execution_ids import ToolExecutionIdRegistry
|
|
11
|
+
from .base import BaseLLMProvider
|
|
12
|
+
from .models import GenerationParams, ReasoningEffort, ThinkingConfig, TokenCount
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AnthropicStreamWrapper:
    """Wrap an Anthropic ``messages.stream(...)`` manager and yield ``MessageChunk`` objects.

    The wrapper must be entered with ``async with`` before it is iterated. While
    iterating, tool-use content blocks are tracked so that every streamed tool call
    carries a stable ``execution_id`` (from a per-stream ``ToolExecutionIdRegistry``)
    and its ``input_json`` fragments are accumulated in ``current_tool_calls``.
    """

    def __init__(self, anthropic_stream, provider_name: str = "anthropic"):
        """
        Args:
            anthropic_stream: The not-yet-entered stream manager (an object exposing
                ``__aenter__``/``__aexit__``) whose entered value is iterated.
            provider_name: Written to ``usage_metadata["provider"]`` on the final message.
        """
        self.anthropic_stream = anthropic_stream
        self.provider_name = provider_name
        self.generator = None  # set by __aenter__ to the entered SDK stream
        self._closed = False

        # Track tool calls being streamed
        self.tool_execution_ids = ToolExecutionIdRegistry()
        self.current_tool_calls: Dict[str, Dict[str, Any]] = {}  # tool_call_id -> accumulated data
        self.tool_call_order: List[str] = []  # tool_call_ids in the order their blocks started
        self.current_block_index = None  # index of the most recently started tool_use block
        # content-block index -> tool_call_id, so each input_json delta is an O(1)
        # lookup instead of a scan over every tool call seen so far.
        self._tool_id_by_block_index: Dict[Any, str] = {}

    async def __aenter__(self):
        self.generator = await self.anthropic_stream.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            return await self.anthropic_stream.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            self._closed = True

    def __aiter__(self):
        self._ensure_entered()
        return self

    async def __anext__(self):
        """Return the next chunk as a ``MessageChunk``.

        ``tool_use_start`` chunks are enriched with the tool call's ``execution_id``.
        ``StopAsyncIteration`` from the underlying stream propagates unchanged.

        Raises:
            RuntimeError: if the wrapper was not entered with ``async with``.
        """
        self._ensure_entered()
        chunk = await self.generator.__anext__()
        self._track_tool_use(chunk)

        message_chunk = MessageChunk.from_anthropic(chunk)
        if message_chunk.type == "tool_use_start" and message_chunk.tool_call_delta:
            tool_data = self.current_tool_calls.get(message_chunk.tool_call_delta.get("id"))
            if tool_data and tool_data.get("execution_id"):
                message_chunk.tool_call_delta["execution_id"] = tool_data["execution_id"]

        return message_chunk

    async def get_final_message(self):
        """Return the complete assistant ``Message`` assembled by the SDK stream.

        Tool calls reuse the execution ids assigned while streaming. When usage
        metadata is present, its ``provider`` entry is set to ``provider_name``.

        Raises:
            RuntimeError: if the wrapper was never entered with ``async with``.
        """
        self._ensure_entered()
        message = Message.from_anthropic(
            await self.generator.get_final_message(),
            tool_execution_ids=self.tool_execution_ids,
        )
        if message.usage_metadata:
            message.usage_metadata["provider"] = self.provider_name
        return message

    def _ensure_entered(self) -> None:
        """Raise ``RuntimeError`` unless ``__aenter__`` has produced the SDK stream."""
        if self.generator is None:
            raise RuntimeError("Must use 'async with' before iterating")

    def _track_tool_use(self, chunk) -> None:
        """Record tool_use block starts and accumulate their streamed JSON input."""
        if chunk.type == "content_block_start" and hasattr(chunk, "content_block"):
            if chunk.content_block.type != "tool_use":
                return
            tool_id = chunk.content_block.id
            has_index = hasattr(chunk, "index")
            block_index = chunk.index if has_index else len(self.tool_call_order)
            self.current_tool_calls[tool_id] = {
                "id": tool_id,
                "name": chunk.content_block.name,
                "input_json": "",
                "block_index": block_index,
                "execution_id": self.tool_execution_ids.get_or_create(tool_id),
            }
            self.tool_call_order.append(tool_id)
            # setdefault keeps the first tool registered at an index, i.e. the same
            # tool a first-match scan over current_tool_calls would find.
            self._tool_id_by_block_index.setdefault(block_index, tool_id)
            self.current_block_index = chunk.index if has_index else None

        elif chunk.type == "content_block_delta" and hasattr(chunk, "delta"):
            if chunk.delta.type == "input_json_delta" and hasattr(chunk, "index"):
                tool_data = self.current_tool_calls.get(self._tool_id_by_block_index.get(chunk.index))
                # Re-check the stored index in case the tool id was re-registered elsewhere.
                if tool_data is not None and tool_data["block_index"] == chunk.index:
                    tool_data["input_json"] += chunk.delta.partial_json
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class AnthropicProvider(BaseLLMProvider):
    """LLM provider backed by the Anthropic Messages API or an Anthropic-shaped endpoint.

    ``provider_name`` distinguishes real Anthropic from Anthropic-compatible third
    parties (e.g. ``"moonshot"``, ``"deepseek"``). It is written to response usage
    metadata, restricts model-spec sanitisation to ``"anthropic"``, and enables
    local token counting for the third-party providers.
    """

    # Heuristic overheads (in tokens) used only by the local tiktoken-based estimate.
    SYSTEM_OVERHEAD = 4
    MESSAGE_OVERHEAD = 3
    TOOL_DEFINITION_OVERHEAD = 65

    # Beta header value requested for claude-3-7 models.
    OUTPUT_128K_BETA = "output-128k-2025-02-19"

    def __init__(
        self,
        api_key: str,
        max_retries: int = 3,
        requests_per_minute: Optional[int] = None,
        tokens_per_minute: Optional[int] = None,
        base_url: Optional[str] = None,
        provider_name: str = "anthropic",
    ):
        super().__init__(api_key, max_retries, requests_per_minute, tokens_per_minute, base_url)
        self.provider_name = provider_name
        self.async_client = AsyncAnthropic(api_key=api_key, base_url=base_url)
        self.sync_client = Anthropic(api_key=api_key, base_url=base_url)

        # OpenAI-compatible Anthropic-shaped APIs do not expose messages/count_tokens,
        # so local counting is only a preflight context-size estimate for those models.
        # Billing/accounting must use provider response usage metadata instead.
        self.use_local_token_counting = (
            provider_name in {"moonshot", "deepseek"}
            or os.getenv('ANTHROPIC_USE_LOCAL_TOKEN_COUNTING', 'false').lower() == 'true'
        )

    @property
    def retry_decorator(self):
        """Get retry decorator with configured max retries"""
        return self.get_retry_decorator()

    def _prepare_thinking_params(self, thinking: Optional[Union[ThinkingConfig, ReasoningEffort]]) -> Dict[str, Any]:
        """Convert thinking parameters to Anthropic's ``{"type": "enabled", "budget_tokens": N}`` form.

        NOTE(review): only ``thinking.budget_tokens`` is read, so this assumes a
        ``ThinkingConfig``. A ``ReasoningEffort`` value (allowed by the annotation)
        would have to expose ``budget_tokens`` as well. Confirm against callers.
        """
        return {"type": "enabled", "budget_tokens": thinking.budget_tokens}

    def _prepare_generation_params(self, params: Optional[GenerationParams] = None) -> Dict[str, Any]:
        """Convert common parameters to Anthropic request keyword arguments.

        Starts from a default model and ``max_tokens`` and overlays whatever
        ``params`` sets. Tools, when given, are sent with ``tool_choice`` ``auto``.
        """
        generation_params = {
            "model": "claude-opus-4-7",  # Default model
            "max_tokens": 1024,  # Default max tokens
        }

        if params:
            if params.temperature is not None:
                generation_params["temperature"] = params.temperature
            if params.max_completion_tokens is not None:
                generation_params["max_tokens"] = params.max_completion_tokens
            if params.tools:
                generation_params["tools"] = [t.to_anthropic() for t in params.tools]
                generation_params["tool_choice"] = {"type": "auto"}
            if params.thinking:
                generation_params["thinking"] = self._prepare_thinking_params(params.thinking)

        return generation_params

    def _sanitize_generation_params(self, generation_params: Dict[str, Any]) -> Dict[str, Any]:
        """Remove parameters unsupported by the selected Anthropic model.

        Only applies to ``provider_name == "anthropic"`` with a string model that
        ``get_model_specs`` recognises. Currently this drops ``temperature`` when
        the spec sets ``supports_temperature`` to False. Mutates and returns the dict.
        """
        model = generation_params.get("model")
        if self.provider_name != "anthropic" or not isinstance(model, str):
            return generation_params

        try:
            model_specs = get_model_specs(self.provider_name, model)
        except ValueError:
            return generation_params

        if model_specs.get("supports_temperature", True) is False:
            generation_params.pop("temperature", None)

        return generation_params

    def _build_request_params(self, params: Optional[GenerationParams], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Build the keyword arguments shared by ``stream`` and ``generate``.

        Explicit ``kwargs`` override values derived from ``params``. For claude-3-7
        models the 128k-output beta is added to ``anthropic-beta``. Any caller-supplied
        ``extra_headers`` are preserved, and existing comma-separated betas are
        extended rather than replaced.
        """
        generation_params = self._prepare_generation_params(params)
        generation_params.update(kwargs)
        generation_params = self._sanitize_generation_params(generation_params)

        model = generation_params.get("model")
        if isinstance(model, str) and model.startswith("claude-3-7"):
            headers = dict(generation_params.get("extra_headers") or {})
            betas = [b.strip() for b in str(headers.get("anthropic-beta", "")).split(",") if b.strip()]
            if self.OUTPUT_128K_BETA not in betas:
                betas.append(self.OUTPUT_128K_BETA)
            headers["anthropic-beta"] = ",".join(betas)
            generation_params["extra_headers"] = headers

        return generation_params

    @staticmethod
    def _system_kwargs(system: Optional[Message]) -> Dict[str, Any]:
        """Return ``{"system": [...]}`` for a provided system message, else ``{}``.

        Omitting the keyword is how a request without a system prompt is sent.
        Dereferencing ``system.content`` on ``None`` would raise ``AttributeError``.
        """
        if system is None:
            return {}
        return {"system": [c.to_anthropic() for c in system.content]}

    async def count_tokens(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        model: Optional[str] = None,
        tools: List[ToolDefinition] = None,
        **kwargs,
    ) -> TokenCount:
        """Count input tokens for a prospective request.

        Uses a local tiktoken estimate when ``use_local_token_counting`` is set.
        Otherwise it calls the ``messages.count_tokens`` endpoint after acquiring
        the rate limiter. ``output_tokens`` is always ``None``.
        """
        tools = tools or []
        if self.use_local_token_counting:
            # Use local tiktoken-based counting (no API call). This is an
            # estimate for context management, not authoritative billing usage.
            return self._count_tokens_local(messages, system, model, tools)

        # Use Anthropic API for token counting
        await self.rate_limiter.acquire()
        count = await self.async_client.messages.count_tokens(
            messages=messages.to_anthropic(),
            model=model,
            tools=[t.to_anthropic() for t in tools],
            **self._system_kwargs(system),
            **kwargs,
        )

        # The API now only returns input_tokens
        return TokenCount(
            input_tokens=count.input_tokens,
            output_tokens=None,  # MessageTokensCount no longer includes output_tokens
        )

    def _count_tokens_local(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        model: Optional[str] = None,
        tools: List[ToolDefinition] = None,
    ) -> TokenCount:
        """Count tokens locally using tiktoken with p50k_base encoding.

        This provides a fast approximation without making an API call.
        Fixed per-system, per-message and per-tool overheads are added on top of the
        encoded content. Images are estimated from their base64 data length.

        Args:
            messages: Message history to count tokens for
            system: Optional system message
            model: Optional model name (not used for local counting)
            tools: Optional tool definitions

        Returns:
            TokenCount object with estimated input token count
        """
        encoding = tiktoken.get_encoding("p50k_base")
        num_tokens = 0
        tools = tools or []

        if system:
            num_tokens += self.SYSTEM_OVERHEAD
            num_tokens += self._count_message_content_tokens(encoding, system.content)

        for message in messages:
            num_tokens += self.MESSAGE_OVERHEAD
            num_tokens += self._count_message_content_tokens(encoding, message.content)

        for tool in tools:
            num_tokens += self._count_value_tokens(encoding, tool.to_anthropic())
            num_tokens += self.TOOL_DEFINITION_OVERHEAD

        return TokenCount(input_tokens=num_tokens, output_tokens=None)

    def _count_message_content_tokens(self, encoding, content: Any) -> int:
        """Count a message's content: plain string, list of content blocks, or any other value."""
        if isinstance(content, str):
            return len(encoding.encode(content))

        if isinstance(content, list):
            return sum(self._count_content_block_tokens(encoding, block) for block in content)

        return self._count_value_tokens(encoding, content)

    def _count_content_block_tokens(self, encoding, block: Any) -> int:
        """Count one content block.

        Checks, in order: text, ``image_url`` with string data, ``tool_result``
        (recursing into its content), thinking, raw ``data``, ``to_anthropic()``,
        and finally the generic value counter.
        """
        if hasattr(block, "text"):
            return len(encoding.encode(block.text))

        if getattr(block, "type", None) == "image_url":
            data = getattr(block, "data", None)
            if isinstance(data, str):
                return self._estimate_image_tokens(len(data))

        if getattr(block, "type", None) == "tool_result":
            content = getattr(block, "content", "")
            return self._count_message_content_tokens(encoding, content)

        if hasattr(block, "thinking"):
            return len(encoding.encode(block.thinking))

        if hasattr(block, "data"):
            return len(encoding.encode(str(block.data)))

        if hasattr(block, "to_anthropic"):
            return self._count_value_tokens(encoding, block.to_anthropic())

        return self._count_value_tokens(encoding, block)

    def _count_value_tokens(self, encoding, value: Any) -> int:
        """Recursively estimate tokens for a JSON-like value.

        Lists and dicts add 2 tokens of structural overhead. Dict keys are encoded as
        strings. Anthropic ``image`` dicts with base64 ``source.data`` use the image
        estimate. Anything else is JSON-encoded.
        """
        if value is None:
            return 0

        if isinstance(value, str):
            return len(encoding.encode(value))

        if isinstance(value, (int, float, bool)):
            return len(encoding.encode(str(value)))

        if isinstance(value, list):
            return 2 + sum(self._count_value_tokens(encoding, item) for item in value)

        if isinstance(value, dict):
            if value.get("type") == "image":
                source = value.get("source") or {}
                data = source.get("data")
                if isinstance(data, str):
                    return self._estimate_image_tokens(len(data))

            total = 2
            for key, item in value.items():
                total += len(encoding.encode(str(key)))
                total += self._count_value_tokens(encoding, item)
            return total

        return len(encoding.encode(json.dumps(value, ensure_ascii=False, default=str)))

    def _estimate_image_tokens(self, base64_data_length: int) -> int:
        """Estimate image token cost based on base64 data length.

        Anthropic charges for images based on their dimensions after resizing.
        Since we don't decode images (performance), we estimate based on data size.

        Empirically observed from tests:
        - Tiny images (96 chars base64, 1x1 px): ~25 tokens
        - Small images (~50-200KB base64): ~200-800 tokens
        - Medium images (~200-800KB base64): ~800-2000 tokens
        - Large images (~800KB+ base64): ~2000-4000 tokens

        Formula uses square root scaling across sizes:
            tokens = 20 + int(sqrt(base64_length * 6))

        This gives (results sit above the small/medium ranges above, i.e. it errs high):
        - 96 chars    -> 20 + int(sqrt(576))       = 44 tokens (~25 actual)
        - 68K chars   -> 20 + int(sqrt(408,000))   = 658 tokens
        - 273K chars  -> 20 + int(sqrt(1,638,000)) = 1299 tokens
        - 1.1M chars  -> 20 + int(sqrt(6,600,000)) = 2589 tokens

        Args:
            base64_data_length: Length of base64 encoded image data

        Returns:
            Estimated token count for the image
        """
        import math

        # Base cost of 20 tokens + sqrt scaling
        estimated_tokens = 20 + int(math.sqrt(base64_data_length * 6))

        return estimated_tokens

    async def stream(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> AsyncContextManager:
        """Generate a streaming response from Anthropic

        Returns an ``AnthropicStreamWrapper``. Enter it with ``async with`` to obtain
        an async iterator of ``MessageChunk`` objects. After streaming, its
        ``get_final_message()`` returns the complete message. ``system`` may be
        ``None``, in which case no system prompt is sent.
        """
        generation_params = self._build_request_params(params, kwargs)

        await self.rate_limiter.acquire()

        # Return the stream context manager
        return AnthropicStreamWrapper(
            self.async_client.messages.stream(
                messages=messages.to_anthropic(),
                **self._system_kwargs(system),
                **generation_params,
            ),
            provider_name=self.provider_name,
        )

    async def generate(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> Message:
        """Generate a complete (non-streaming) response from Anthropic.

        ``system`` may be ``None``, in which case no system prompt is sent. When usage
        metadata is returned, its ``provider`` entry is set to ``provider_name``.
        """
        generation_params = self._build_request_params(params, kwargs)

        await self.rate_limiter.acquire()
        response = await self.async_client.messages.create(
            messages=messages.to_anthropic(),
            **self._system_kwargs(system),
            **generation_params,
        )
        message = Message.from_anthropic(response)
        if message.usage_metadata:
            message.usage_metadata["provider"] = self.provider_name
        return message
|
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any, AsyncContextManager, Dict, List, Optional, Union
|
|
3
|
+
|
|
4
|
+
from anthropic import APIError as AnthropicAPIError
|
|
5
|
+
from google.genai.errors import APIError as GeminiAPIError
|
|
6
|
+
from openai import APIError as OpenAIAPIError
|
|
7
|
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
|
8
|
+
|
|
9
|
+
from ..models import Message, MessageHistory, ToolDefinition
|
|
10
|
+
from ..ratelimit import RateLimiter
|
|
11
|
+
from .models import GenerationParams, ReasoningEffort, ThinkingConfig, TokenCount
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BaseLLMProvider(ABC):
    """Abstract interface shared by all LLM providers.

    Holds the API key, retry budget and optional base URL. It also owns a
    ``RateLimiter`` built from the per-minute request and token limits. Concrete
    providers implement token counting, streaming and one-shot generation.
    """

    def __init__(
        self,
        api_key: str,
        max_retries: int = 3,
        requests_per_minute: Optional[int] = None,
        tokens_per_minute: Optional[int] = None,
        base_url: Optional[str] = None,
    ):
        self.api_key = api_key
        self.max_retries = max_retries
        self.rate_limiter = RateLimiter(requests_per_minute, tokens_per_minute)
        self.base_url = base_url

    @abstractmethod
    async def count_tokens(
        self,
        messages: MessageHistory,
        system: Message = None,
        model: Optional[str] = None,
        tools: List[ToolDefinition] = None,
        **kwargs,
    ) -> TokenCount:
        """Return the input-token count for the given messages, system prompt and tools."""

    @abstractmethod
    def stream(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> AsyncContextManager:
        """Start a streaming generation and return an async context manager over its chunks.

        NOTE(review): declared as a plain ``def``, while the Anthropic and Google
        providers in this package implement it as ``async def``. Callers presumably
        ``await`` it. Confirm before relying on either form.
        """

    @abstractmethod
    async def generate(
        self, messages: MessageHistory, system: Message = None, params: Optional[GenerationParams] = None, **kwargs
    ) -> Message:
        """Run a non-streaming generation and return the complete assistant message."""

    def _prepare_generation_params(self, params: Optional[GenerationParams] = None) -> Dict[str, Any]:
        """Translate common generation params to provider kwargs (base: no params)."""
        return dict()

    def _prepare_thinking_params(self, thinking: Optional[Union[ThinkingConfig, ReasoningEffort]]) -> Dict[str, Any]:
        """Translate a thinking/reasoning setting to provider kwargs (base: no params)."""
        return dict()

    def get_retry_decorator(self):
        """Build a tenacity ``retry`` decorator with exponential backoff.

        Retries on Anthropic, OpenAI and Gemini ``APIError`` subclasses. Waits are
        exponential, clamped to 4-10 seconds. Attempts stop after ``max_retries``.

        NOTE(review): per tenacity docs, ``stop_after_attempt`` counts total attempts
        including the first call, so ``max_retries=3`` allows at most 3 calls. Without
        ``reraise=True``, the final failure surfaces as ``tenacity.RetryError``
        rather than the original API error. Confirm callers expect both.
        """
        retryable_errors = (AnthropicAPIError, OpenAIAPIError, GeminiAPIError)
        backoff = wait_exponential(multiplier=1, min=4, max=10)
        return retry(
            stop=stop_after_attempt(self.max_retries),
            wait=backoff,
            retry=retry_if_exception_type(retryable_errors),
        )
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
from typing import AsyncContextManager, List, Optional
|
|
2
|
+
|
|
3
|
+
from google.genai import Client as genai_client
|
|
4
|
+
from google.genai import types as genai_types
|
|
5
|
+
|
|
6
|
+
from ..models import Message, MessageChunk, MessageHistory, ToolDefinition
|
|
7
|
+
from ..tool_execution_ids import ToolExecutionIdRegistry
|
|
8
|
+
from .base import BaseLLMProvider
|
|
9
|
+
from .models import GenerationParams, TokenCount
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GoogleStreamWrapper:
    """Adapt a google-genai async content stream to ``MessageChunk`` objects.

    Text and function calls are accumulated while iterating so that
    ``get_final_message`` can assemble the complete assistant ``Message``.
    Entering with ``async with`` is optional for iteration. Exiting closes the
    underlying stream when it exposes ``aclose``.
    """

    def __init__(self, gemini_stream):
        self.gemini_stream = gemini_stream
        self.final_content = ""
        # Running position -> function call, across every chunk of the stream.
        self.final_tool_calls = {}
        self.stop_reason = None
        self.tool_execution_ids = ToolExecutionIdRegistry()

        self._closed = False

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if hasattr(self.gemini_stream, "aclose"):
                await self.gemini_stream.aclose()
        finally:
            # Mark closed even if aclose() raised, so iteration stops afterwards.
            self._closed = True
        return False

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next chunk as a ``MessageChunk``, accumulating text, calls and stop reason.

        ``StopAsyncIteration`` from the underlying stream propagates unchanged.
        """
        if self._closed:
            raise StopAsyncIteration

        chunk = await self.gemini_stream.__anext__()

        self.final_content += chunk.text or ""

        # Each chunk's function_calls list restarts at position 0. Keying by that
        # per-chunk position overwrote calls from earlier chunks, so use a running index.
        for function_call in chunk.function_calls or []:
            self.final_tool_calls[len(self.final_tool_calls)] = function_call

        # Chunks may carry no candidates (e.g. usage-only chunks); keep the last
        # finish reason actually reported instead of indexing into an empty list.
        candidates = chunk.candidates or []
        if candidates and candidates[0].finish_reason:
            self.stop_reason = candidates[0].finish_reason.value

        return MessageChunk.from_google(chunk)

    async def get_final_message(self):
        """Build the complete assistant ``Message`` from everything streamed so far."""
        return Message.from_google_stream(
            role="assistant",
            content=self.final_content,
            tool_calls=self.final_tool_calls,
            stop_reason=self.stop_reason,
            tool_execution_ids=self.tool_execution_ids,
        )
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class GoogleProvider(BaseLLMProvider):
    """LLM provider that calls Gemini models through the Google Gen AI SDK client."""

    def __init__(
        self,
        api_key: str,
        max_retries: int = 3,
        requests_per_minute: Optional[int] = None,
        tokens_per_minute: Optional[int] = None,
        base_url: Optional[str] = None,
    ):
        super().__init__(api_key, max_retries, requests_per_minute, tokens_per_minute, base_url)
        # NOTE(review): base_url goes to the base class only. The Gemini client
        # below is created with the API key alone; confirm this is intended.
        self.async_client = genai_client(api_key=api_key)

    @property
    def retry_decorator(self):
        """Get retry decorator with configured max retries"""
        return self.get_retry_decorator()

    @staticmethod
    def _build_config(system: Optional[Message], params: Optional[GenerationParams]):
        """Build the ``GenerateContentConfig`` shared by ``stream`` and ``generate``.

        Both ``system`` and ``params`` are optional in the public signatures.
        Missing values fall back to no system instruction and a default
        ``GenerationParams()`` instead of raising ``AttributeError``.
        """
        if params is None:
            params = GenerationParams()
        # The system instruction is the text of the system message's first content part.
        system_instruction = (
            system.content[0].text if system is not None and system.content else None
        )
        return genai_types.GenerateContentConfig(
            system_instruction=system_instruction,
            temperature=params.temperature,
            max_output_tokens=params.max_completion_tokens,
            tools=[t.to_google() for t in params.tools] if params.tools else None,
            # NOTE(review): passed through unchanged. This assumes callers supply a
            # value GenerateContentConfig accepts for thinking_config.
            thinking_config=params.thinking,
        )

    async def count_tokens(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        model: Optional[str] = None,
        tools: Optional[List[ToolDefinition]] = None,
        **kwargs,
    ) -> TokenCount:
        """Count input tokens for ``messages`` with the Gemini ``count_tokens`` API.

        Args:
            messages: Conversation history to count tokens for.
            system: Accepted for interface compatibility; not included in the count.
            model: Gemini model name, forwarded to the API as-is.
            tools: Accepted for interface compatibility; not included in the count.

        Returns:
            TokenCount whose ``input_tokens`` is the API's ``total_tokens``;
            ``output_tokens`` is None.
        """
        count = await self.async_client.aio.models.count_tokens(
            model=model,
            contents=messages.to_google(),
        )

        return TokenCount(input_tokens=count.total_tokens, output_tokens=None)

    async def stream(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> AsyncContextManager:
        """Start a streaming generation with Gemini.

        Args:
            messages: Conversation history, sent as ``contents``.
            system: Optional system message; its first content part's text is
                used as the system instruction.
            params: Generation parameters; defaults are used when omitted.
            **kwargs: Must include ``model``, the Gemini model name.

        Returns:
            A ``GoogleStreamWrapper``, usable as an async context manager and as
            an async iterator of ``MessageChunk`` objects.
        """
        config = self._build_config(system, params)

        await self.rate_limiter.acquire()

        gemini_stream = await self.async_client.aio.models.generate_content_stream(
            model=kwargs["model"], contents=messages.to_google(), config=config
        )
        return GoogleStreamWrapper(gemini_stream)

    async def generate(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> Message:
        """Generate a complete (non-streaming) response from Gemini.

        Args:
            messages: Conversation history, sent as ``contents``.
            system: Optional system message; its first content part's text is
                used as the system instruction.
            params: Generation parameters; defaults are used when omitted.
            **kwargs: Must include ``model``, the Gemini model name.

        Returns:
            The response converted with ``Message.from_google``.
        """
        config = self._build_config(system, params)

        await self.rate_limiter.acquire()

        response = await self.async_client.aio.models.generate_content(
            model=kwargs["model"], contents=messages.to_google(), config=config
        )

        return Message.from_google(response)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Any, Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ReasoningEffort(Enum):
    """Three coarse reasoning-effort levels; each member's value is its lowercase name."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class ThinkingConfig:
    """Configuration for a model's thinking depth, expressed as a token budget.

    Attributes:
        budget_tokens: Thinking token budget; 4096 unless overridden.
    """

    budget_tokens: int = 4096
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class GeminiThinkingConfig:
    """Thinking options intended for Gemini models.

    Attributes:
        include_thoughts: Whether thoughts should be included; True by default.
    """

    include_thoughts: bool = True
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
class TokenCount:
    """Token counts reported for a request.

    Attributes:
        input_tokens: Number of input tokens (required).
        output_tokens: Number of output tokens, or None when not available.
    """

    input_tokens: int
    output_tokens: Optional[int] = None
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class GenerationParams:
    """Provider-agnostic settings for a text-generation request.

    Attributes:
        temperature: Sampling temperature; 1.0 by default.
        max_completion_tokens: Limit on generated tokens, or None when unspecified.
        tools: Tools offered to the model, or None.
        thinking: Optional thinking/reasoning configuration in one of the
            supported forms (token budget, effort level, or Gemini options).
    """

    temperature: float = 1.0
    max_completion_tokens: Optional[int] = None
    # NOTE(review): annotated as a list of dicts, but GoogleProvider calls
    # ``t.to_google()`` on each entry, so tool objects appear to be passed here.
    # Confirm the intended type.
    tools: Optional[List[Dict[str, Any]]] = None
    thinking: Optional[Union[ThinkingConfig, ReasoningEffort, GeminiThinkingConfig]] = None
|