kolega-code 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. kolega_code/__init__.py +151 -0
  2. kolega_code/agent/__init__.py +42 -0
  3. kolega_code/agent/baseagent.py +998 -0
  4. kolega_code/agent/browseragent.py +123 -0
  5. kolega_code/agent/coder.py +157 -0
  6. kolega_code/agent/common.py +41 -0
  7. kolega_code/agent/compression.py +81 -0
  8. kolega_code/agent/context.py +112 -0
  9. kolega_code/agent/conversation.py +408 -0
  10. kolega_code/agent/generalagent.py +146 -0
  11. kolega_code/agent/investigationagent.py +123 -0
  12. kolega_code/agent/planningagent.py +187 -0
  13. kolega_code/agent/prompt_provider.py +196 -0
  14. kolega_code/agent/prompt_templates/agents/browser.j2 +102 -0
  15. kolega_code/agent/prompt_templates/agents/coder_cli_mode.j2 +127 -0
  16. kolega_code/agent/prompt_templates/agents/general.j2 +68 -0
  17. kolega_code/agent/prompt_templates/agents/investigation.j2 +72 -0
  18. kolega_code/agent/prompt_templates/common/frontend_guidance.md +36 -0
  19. kolega_code/agent/prompt_templates/common/kolega_md_instructions.md +14 -0
  20. kolega_code/agent/prompt_templates/environment_variables/workspace_env_vars.md +11 -0
  21. kolega_code/agent/prompt_templates/template_guidance/expo-template.md +379 -0
  22. kolega_code/agent/prompt_templates/template_guidance/html-website-template.md +3 -0
  23. kolega_code/agent/prompt_templates/template_guidance/mern-stack-template.md +3 -0
  24. kolega_code/agent/prompt_templates/template_guidance/react-vite-shadcdn-template.md +182 -0
  25. kolega_code/agent/prompts.py +192 -0
  26. kolega_code/agent/tests/__init__.py +0 -0
  27. kolega_code/agent/tests/llm/__init__.py +0 -0
  28. kolega_code/agent/tests/llm/test_anthropic_token_counting.py +633 -0
  29. kolega_code/agent/tests/llm/test_billing_openai_cache.py +74 -0
  30. kolega_code/agent/tests/llm/test_client.py +773 -0
  31. kolega_code/agent/tests/llm/test_dashscope_mapping.py +32 -0
  32. kolega_code/agent/tests/llm/test_error_boundary.py +322 -0
  33. kolega_code/agent/tests/llm/test_exceptions.py +249 -0
  34. kolega_code/agent/tests/llm/test_instrumented_client.py +536 -0
  35. kolega_code/agent/tests/llm/test_instrumented_client_integration.py +547 -0
  36. kolega_code/agent/tests/llm/test_langfuse_normalization.py +39 -0
  37. kolega_code/agent/tests/llm/test_model_specs.py +17 -0
  38. kolega_code/agent/tests/llm/test_openai_cached_tokens.py +58 -0
  39. kolega_code/agent/tests/llm/test_openai_cached_tokens_stream.py +74 -0
  40. kolega_code/agent/tests/llm/test_openai_message_conversion.py +30 -0
  41. kolega_code/agent/tests/llm/test_openai_token_counting.py +687 -0
  42. kolega_code/agent/tests/llm/test_tool_execution_ids.py +193 -0
  43. kolega_code/agent/tests/services/__init__.py +1 -0
  44. kolega_code/agent/tests/services/test_browser.py +447 -0
  45. kolega_code/agent/tests/services/test_browser_parity.py +353 -0
  46. kolega_code/agent/tests/services/test_file_system.py +699 -0
  47. kolega_code/agent/tests/services/test_sandbox_terminal_input.py +98 -0
  48. kolega_code/agent/tests/services/test_terminal.py +154 -0
  49. kolega_code/agent/tests/services/test_terminal_command_tracking.py +385 -0
  50. kolega_code/agent/tests/services/test_terminal_state_serializer.py +262 -0
  51. kolega_code/agent/tests/test_agent_tools_inventory.py +267 -0
  52. kolega_code/agent/tests/test_base_agent.py +1942 -0
  53. kolega_code/agent/tests/test_coder_attachments.py +330 -0
  54. kolega_code/agent/tests/test_coder_prompt_extensions.py +61 -0
  55. kolega_code/agent/tests/test_commands.py +179 -0
  56. kolega_code/agent/tests/test_duplicate_tool_results.py +556 -0
  57. kolega_code/agent/tests/test_empty_message_handling.py +48 -0
  58. kolega_code/agent/tests/test_general_agent.py +242 -0
  59. kolega_code/agent/tests/test_html.py +320 -0
  60. kolega_code/agent/tests/test_parallel_tool_calls.py +291 -0
  61. kolega_code/agent/tests/test_planning_agent.py +227 -0
  62. kolega_code/agent/tests/test_prompt_provider.py +271 -0
  63. kolega_code/agent/tests/test_tool_registry.py +102 -0
  64. kolega_code/agent/tests/test_tools.py +549 -0
  65. kolega_code/agent/tests/tool_backend/__init__.py +0 -0
  66. kolega_code/agent/tests/tool_backend/test_agent_tool.py +356 -0
  67. kolega_code/agent/tests/tool_backend/test_base_tool.py +147 -0
  68. kolega_code/agent/tests/tool_backend/test_browser_tool.py +335 -0
  69. kolega_code/agent/tests/tool_backend/test_build_tool.py +93 -0
  70. kolega_code/agent/tests/tool_backend/test_create_file_tool.py +115 -0
  71. kolega_code/agent/tests/tool_backend/test_glob_tool.py +196 -0
  72. kolega_code/agent/tests/tool_backend/test_glob_tool_sandbox_parity.py +230 -0
  73. kolega_code/agent/tests/tool_backend/test_list_directory_tool.py +292 -0
  74. kolega_code/agent/tests/tool_backend/test_read_file_tool.py +173 -0
  75. kolega_code/agent/tests/tool_backend/test_replace_entire_file_tool.py +115 -0
  76. kolega_code/agent/tests/tool_backend/test_replace_lines_tool.py +141 -0
  77. kolega_code/agent/tests/tool_backend/test_search_and_replace_tool.py +174 -0
  78. kolega_code/agent/tests/tool_backend/test_search_codebase_tool.py +228 -0
  79. kolega_code/agent/tests/tool_backend/test_terminal_tool.py +482 -0
  80. kolega_code/agent/tests/tool_backend/test_think_hard_integration.py +189 -0
  81. kolega_code/agent/tests/tool_backend/test_think_hard_streaming.py +445 -0
  82. kolega_code/agent/tests/tool_backend/test_web_fetch_tool.py +194 -0
  83. kolega_code/agent/tool_backend/agent_tool.py +414 -0
  84. kolega_code/agent/tool_backend/apply_edit_tool.py +98 -0
  85. kolega_code/agent/tool_backend/apply_patch_tool.py +514 -0
  86. kolega_code/agent/tool_backend/base_tool.py +217 -0
  87. kolega_code/agent/tool_backend/browser_tool.py +271 -0
  88. kolega_code/agent/tool_backend/build_tool.py +93 -0
  89. kolega_code/agent/tool_backend/create_file_tool.py +52 -0
  90. kolega_code/agent/tool_backend/glob_tool.py +323 -0
  91. kolega_code/agent/tool_backend/list_directory_tool.py +300 -0
  92. kolega_code/agent/tool_backend/memory_tool.py +79 -0
  93. kolega_code/agent/tool_backend/read_file_tool.py +119 -0
  94. kolega_code/agent/tool_backend/replace_entire_file_tool.py +40 -0
  95. kolega_code/agent/tool_backend/replace_lines_tool.py +97 -0
  96. kolega_code/agent/tool_backend/search_and_replace_tool.py +146 -0
  97. kolega_code/agent/tool_backend/search_codebase_tool.py +377 -0
  98. kolega_code/agent/tool_backend/streaming_tool.py +47 -0
  99. kolega_code/agent/tool_backend/terminal_tool.py +643 -0
  100. kolega_code/agent/tool_backend/think_hard_tool.py +211 -0
  101. kolega_code/agent/tool_backend/web_fetch_tool.py +205 -0
  102. kolega_code/agent/tools.py +1704 -0
  103. kolega_code/agent/utils/commands.py +94 -0
  104. kolega_code/cli/__init__.py +1 -0
  105. kolega_code/cli/app.py +2756 -0
  106. kolega_code/cli/config.py +280 -0
  107. kolega_code/cli/connection.py +49 -0
  108. kolega_code/cli/file_index.py +147 -0
  109. kolega_code/cli/main.py +564 -0
  110. kolega_code/cli/mentions.py +155 -0
  111. kolega_code/cli/messages.py +89 -0
  112. kolega_code/cli/provider_registry.py +96 -0
  113. kolega_code/cli/session_store.py +207 -0
  114. kolega_code/cli/settings.py +87 -0
  115. kolega_code/cli/skills.py +409 -0
  116. kolega_code/cli/slash_commands.py +108 -0
  117. kolega_code/cli/tests/__init__.py +1 -0
  118. kolega_code/cli/tests/test_app.py +4251 -0
  119. kolega_code/cli/tests/test_cli_config.py +171 -0
  120. kolega_code/cli/tests/test_connection.py +26 -0
  121. kolega_code/cli/tests/test_file_index.py +103 -0
  122. kolega_code/cli/tests/test_main.py +455 -0
  123. kolega_code/cli/tests/test_mentions.py +108 -0
  124. kolega_code/cli/tests/test_session_store.py +67 -0
  125. kolega_code/cli/tests/test_settings.py +62 -0
  126. kolega_code/cli/tests/test_skills.py +157 -0
  127. kolega_code/cli/tests/test_slash_commands.py +88 -0
  128. kolega_code/cli/theme.py +180 -0
  129. kolega_code/config.py +154 -0
  130. kolega_code/events.py +202 -0
  131. kolega_code/llm/client.py +300 -0
  132. kolega_code/llm/exceptions.py +285 -0
  133. kolega_code/llm/instrumented_client.py +520 -0
  134. kolega_code/llm/models.py +1368 -0
  135. kolega_code/llm/providers/__init__.py +0 -0
  136. kolega_code/llm/providers/anthropic.py +387 -0
  137. kolega_code/llm/providers/base.py +71 -0
  138. kolega_code/llm/providers/google.py +157 -0
  139. kolega_code/llm/providers/models.py +37 -0
  140. kolega_code/llm/providers/openai.py +363 -0
  141. kolega_code/llm/ratelimit.py +40 -0
  142. kolega_code/llm/specs.py +67 -0
  143. kolega_code/llm/tool_execution_ids.py +18 -0
  144. kolega_code/models/__init__.py +9 -0
  145. kolega_code/models/sandbox_terminal_state.py +47 -0
  146. kolega_code/runtime.py +50 -0
  147. kolega_code/sandbox/README.md +200 -0
  148. kolega_code/sandbox/__init__.py +21 -0
  149. kolega_code/sandbox/async_filesystem.py +475 -0
  150. kolega_code/sandbox/base.py +297 -0
  151. kolega_code/sandbox/browser.py +25 -0
  152. kolega_code/sandbox/event_loop.py +43 -0
  153. kolega_code/sandbox/filesystem.py +341 -0
  154. kolega_code/sandbox/local.py +118 -0
  155. kolega_code/sandbox/serializer.py +175 -0
  156. kolega_code/sandbox/terminal.py +868 -0
  157. kolega_code/sandbox/utils.py +216 -0
  158. kolega_code/services/base.py +255 -0
  159. kolega_code/services/browser.py +444 -0
  160. kolega_code/services/file_system.py +749 -0
  161. kolega_code/services/html.py +221 -0
  162. kolega_code/services/terminal.py +903 -0
  163. kolega_code/tools/__init__.py +22 -0
  164. kolega_code/tools/core.py +33 -0
  165. kolega_code/tools/definitions.py +81 -0
  166. kolega_code/tools/registry.py +73 -0
  167. kolega_code-0.1.0.dist-info/METADATA +157 -0
  168. kolega_code-0.1.0.dist-info/RECORD +171 -0
  169. kolega_code-0.1.0.dist-info/WHEEL +4 -0
  170. kolega_code-0.1.0.dist-info/entry_points.txt +2 -0
  171. kolega_code-0.1.0.dist-info/licenses/LICENSE +21 -0
File without changes
@@ -0,0 +1,387 @@
1
+ import json
2
+ import os
3
+ from typing import Any, AsyncContextManager, Dict, List, Optional, Union
4
+
5
+ import tiktoken
6
+ from anthropic import Anthropic, AsyncAnthropic
7
+
8
+ from ..models import Message, MessageChunk, MessageHistory, ToolDefinition
9
+ from ..specs import get_model_specs
10
+ from ..tool_execution_ids import ToolExecutionIdRegistry
11
+ from .base import BaseLLMProvider
12
+ from .models import GenerationParams, ReasoningEffort, ThinkingConfig, TokenCount
13
+
14
+
15
class AnthropicStreamWrapper:
    """Async context manager and iterator over an Anthropic message stream.

    Wraps the stream manager passed in as ``anthropic_stream`` and converts each
    raw event it yields into a ``MessageChunk``. While iterating, it records
    per-tool-call bookkeeping (name, accumulated JSON input, content-block index
    and an execution id) so ``tool_use_start`` chunks can carry an
    ``execution_id``.

    Usage: enter with ``async with``, iterate with ``async for``, and call
    ``get_final_message()`` to obtain the assembled ``Message``.
    """

    def __init__(self, anthropic_stream, provider_name: str = "anthropic"):
        """Store the underlying stream manager; nothing is opened until ``__aenter__``.

        Args:
            anthropic_stream: Object supporting ``__aenter__``/``__aexit__`` whose
                entered value is an async iterator exposing ``get_final_message()``.
            provider_name: Label written into the final message's usage metadata.
        """
        self.anthropic_stream = anthropic_stream
        self.provider_name = provider_name
        self.generator = None  # Set by __aenter__; None means "not entered yet"
        self._closed = False

        # Track tool calls being streamed
        self.tool_execution_ids = ToolExecutionIdRegistry()
        self.current_tool_calls = {}  # Maps tool_call_id to accumulated data
        self.tool_call_order = []  # Track order of tool calls
        self.current_block_index = None  # Track which content block we're processing

    async def __aenter__(self):
        """Enter the underlying stream and return this wrapper."""
        self.generator = await self.anthropic_stream.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the underlying stream, forwarding its exception-suppression result."""
        try:
            return await self.anthropic_stream.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            # Mirror GoogleStreamWrapper: record that the stream has been released.
            self._closed = True

    def __aiter__(self):
        """Return self as the async iterator.

        Raises:
            RuntimeError: If the wrapper has not been entered with ``async with``.
        """
        if self.generator is None:
            raise RuntimeError("Must use 'async with' before iterating")
        return self

    async def __anext__(self):
        """Return the next event converted to a ``MessageChunk``.

        Raises:
            RuntimeError: If the wrapper has not been entered with ``async with``.
            StopAsyncIteration: Propagated unchanged when the stream is exhausted.
        """
        if self.generator is None:
            raise RuntimeError("Must use 'async with' before iterating")

        chunk = await self.generator.__anext__()
        self._track_tool_call(chunk)

        message_chunk = MessageChunk.from_anthropic(chunk)
        if message_chunk.type == "tool_use_start" and message_chunk.tool_call_delta:
            # Attach the execution id registered when this tool call started.
            tool_id = message_chunk.tool_call_delta.get("id")
            tool_data = self.current_tool_calls.get(tool_id)
            if tool_data and tool_data.get("execution_id"):
                message_chunk.tool_call_delta["execution_id"] = tool_data["execution_id"]

        return message_chunk

    def _track_tool_call(self, chunk) -> None:
        """Update tool-call bookkeeping from one raw stream event."""
        # Handle content_block_start events for tool use
        if chunk.type == "content_block_start" and hasattr(chunk, "content_block"):
            if chunk.content_block.type == "tool_use":
                has_index = hasattr(chunk, "index")
                tool_id = chunk.content_block.id
                self.current_tool_calls[tool_id] = {
                    "id": tool_id,
                    "name": chunk.content_block.name,
                    "input_json": "",
                    # Without an explicit index, fall back to arrival order.
                    "block_index": chunk.index if has_index else len(self.tool_call_order),
                    "execution_id": self.tool_execution_ids.get_or_create(tool_id),
                }
                self.tool_call_order.append(tool_id)
                self.current_block_index = chunk.index if has_index else None

        # Handle content_block_delta events for tool use input
        elif chunk.type == "content_block_delta" and hasattr(chunk, "delta"):
            if chunk.delta.type == "input_json_delta" and hasattr(chunk, "index"):
                # Find the tool call by block index and accumulate its JSON input
                for tool_data in self.current_tool_calls.values():
                    if tool_data["block_index"] == chunk.index:
                        tool_data["input_json"] += chunk.delta.partial_json
                        break

    async def get_final_message(self):
        """Return the final ``Message`` built from the underlying stream.

        The wrapper's execution-id registry is passed to ``Message.from_anthropic``,
        and when the message has usage metadata its ``"provider"`` entry is set to
        ``provider_name``.

        Raises:
            RuntimeError: If the wrapper has not been entered with ``async with``.
        """
        if self.generator is None:
            raise RuntimeError("Must use 'async with' before calling get_final_message")

        message = Message.from_anthropic(
            await self.generator.get_final_message(),
            tool_execution_ids=self.tool_execution_ids,
        )
        if message.usage_metadata:
            message.usage_metadata["provider"] = self.provider_name
        return message
92
+
93
+
94
class AnthropicProvider(BaseLLMProvider):
    """Provider that talks to the Anthropic Messages API through the ``anthropic`` SDK.

    The same class also serves other endpoints via ``base_url`` and
    ``provider_name``; for ``"moonshot"`` and ``"deepseek"`` (or when the
    ``ANTHROPIC_USE_LOCAL_TOKEN_COUNTING`` environment variable is ``"true"``)
    token counting is done locally with tiktoken instead of an API call.
    """

    # Flat allowances added by the local token estimator per system prompt,
    # per message and per tool definition.
    SYSTEM_OVERHEAD = 4
    MESSAGE_OVERHEAD = 3
    TOOL_DEFINITION_OVERHEAD = 65

    def __init__(
        self,
        api_key: str,
        max_retries: int = 3,
        requests_per_minute: Optional[int] = None,
        tokens_per_minute: Optional[int] = None,
        base_url: Optional[str] = None,
        provider_name: str = "anthropic",
    ):
        """Create async and sync SDK clients and decide the token-counting mode.

        Args:
            api_key: API key handed to both SDK clients.
            max_retries: Stored by the base class for ``get_retry_decorator``.
            requests_per_minute: Rate-limiter setting forwarded to the base class.
            tokens_per_minute: Rate-limiter setting forwarded to the base class.
            base_url: Optional endpoint override handed to both SDK clients.
            provider_name: Label used for model-spec lookup, usage metadata and
                the local-token-counting decision.
        """
        super().__init__(api_key, max_retries, requests_per_minute, tokens_per_minute, base_url)
        self.provider_name = provider_name
        self.async_client = AsyncAnthropic(api_key=api_key, base_url=base_url)
        self.sync_client = Anthropic(api_key=api_key, base_url=base_url)

        # OpenAI-compatible Anthropic-shaped APIs do not expose messages/count_tokens,
        # so local counting is only a preflight context-size estimate for those models.
        # Billing/accounting must use provider response usage metadata instead.
        self.use_local_token_counting = (
            provider_name in {"moonshot", "deepseek"}
            or os.getenv("ANTHROPIC_USE_LOCAL_TOKEN_COUNTING", "false").lower() == "true"
        )

    @property
    def retry_decorator(self):
        """Get retry decorator with configured max retries"""
        return self.get_retry_decorator()

    def _prepare_thinking_params(self, thinking: Optional[Union[ThinkingConfig, ReasoningEffort]]) -> Dict[str, Any]:
        """Convert thinking parameters to provider-specific format"""
        # NOTE(review): only objects with a ``budget_tokens`` attribute work here;
        # a ReasoningEffort member (allowed by the annotation) raises AttributeError.
        return {"type": "enabled", "budget_tokens": thinking.budget_tokens}

    def _prepare_generation_params(self, params: Optional[GenerationParams] = None) -> Dict[str, Any]:
        """Convert common parameters to provider-specific format

        Starts from a default model and ``max_tokens`` and overlays whichever
        fields of ``params`` are set. Tools, when present, are sent with
        ``tool_choice`` of type ``"auto"``.
        """
        generation_params = {
            "model": "claude-opus-4-7",  # Default model
            "max_tokens": 1024,  # Default max tokens
        }

        if params:
            if params.temperature is not None:
                generation_params["temperature"] = params.temperature
            if params.max_completion_tokens is not None:
                generation_params["max_tokens"] = params.max_completion_tokens
            if params.tools:
                generation_params["tools"] = [t.to_anthropic() for t in params.tools]
                generation_params["tool_choice"] = {"type": "auto"}
            if params.thinking:
                generation_params["thinking"] = self._prepare_thinking_params(params.thinking)

        return generation_params

    def _sanitize_generation_params(self, generation_params: Dict[str, Any]) -> Dict[str, Any]:
        """Remove parameters unsupported by the selected Anthropic model.

        Only applies when ``provider_name`` is ``"anthropic"`` and the model is a
        string with known specs; otherwise the dict is returned untouched. The
        dict is modified in place and returned.
        """
        model = generation_params.get("model")
        if self.provider_name != "anthropic" or not isinstance(model, str):
            return generation_params

        try:
            model_specs = get_model_specs(self.provider_name, model)
        except ValueError:
            # get_model_specs raised for this model: leave parameters as they are.
            return generation_params

        if model_specs.get("supports_temperature", True) is False:
            generation_params.pop("temperature", None)

        return generation_params

    def _build_request_params(self, params: Optional[GenerationParams], overrides: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the keyword arguments shared by ``stream`` and ``generate``."""
        generation_params = self._prepare_generation_params(params)
        generation_params.update(overrides)
        generation_params = self._sanitize_generation_params(generation_params)

        if generation_params["model"].startswith("claude-3-7"):
            generation_params["extra_headers"] = {"anthropic-beta": "output-128k-2025-02-19"}

        return generation_params

    @staticmethod
    def _system_kwargs(system: Optional[Message]) -> Dict[str, Any]:
        """Return ``{"system": [...]}`` for a system message, or ``{}`` when there is none.

        ``system`` defaults to ``None`` in the public methods, so the keyword is
        omitted entirely in that case instead of dereferencing ``None``.
        """
        if system is None:
            return {}
        return {"system": [c.to_anthropic() for c in system.content]}

    async def count_tokens(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        model: Optional[str] = None,
        tools: Optional[List[ToolDefinition]] = None,
        **kwargs,
    ) -> TokenCount:
        """Count input tokens for a prospective request.

        Uses the local tiktoken estimate when ``use_local_token_counting`` is set,
        otherwise calls the API's ``messages.count_tokens`` after acquiring the
        rate limiter.

        Args:
            messages: Conversation history.
            system: Optional system message; omitted from the API call when ``None``.
            model: Model name forwarded to the API (unused by the local estimate).
            tools: Optional tool definitions.
            **kwargs: Extra keyword arguments forwarded to the API call only.

        Returns:
            TokenCount with ``input_tokens`` set and ``output_tokens`` of ``None``.
        """
        tools = tools or []
        if self.use_local_token_counting:
            # Use local tiktoken-based counting (no API call). This is an
            # estimate for context management, not authoritative billing usage.
            return self._count_tokens_local(messages, system, model, tools)

        # Use Anthropic API for token counting
        await self.rate_limiter.acquire()
        count = await self.async_client.messages.count_tokens(
            messages=messages.to_anthropic(),
            model=model,
            tools=[t.to_anthropic() for t in tools],
            **self._system_kwargs(system),
            **kwargs,
        )

        # Only input_tokens is read from the API result
        return TokenCount(input_tokens=count.input_tokens, output_tokens=None)

    def _count_tokens_local(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        model: Optional[str] = None,
        tools: Optional[List[ToolDefinition]] = None,
    ) -> TokenCount:
        """Count tokens locally using tiktoken with p50k_base encoding.

        This provides a fast approximation without making an API call.
        Handles images by estimating token cost based on data size.

        Args:
            messages: Message history to count tokens for
            system: Optional system message
            model: Optional model name (not used for local counting)
            tools: Optional tool definitions

        Returns:
            TokenCount object with estimated input token count
        """
        encoding = tiktoken.get_encoding("p50k_base")
        num_tokens = 0
        tools = tools or []

        if system:
            num_tokens += self.SYSTEM_OVERHEAD
            num_tokens += self._count_message_content_tokens(encoding, system.content)

        for message in messages:
            num_tokens += self.MESSAGE_OVERHEAD
            num_tokens += self._count_message_content_tokens(encoding, message.content)

        for tool in tools:
            num_tokens += self._count_value_tokens(encoding, tool.to_anthropic())
            num_tokens += self.TOOL_DEFINITION_OVERHEAD

        return TokenCount(input_tokens=num_tokens, output_tokens=None)

    def _count_message_content_tokens(self, encoding, content: Any) -> int:
        """Estimate tokens for message content given as a string, a list of blocks, or any other value."""
        if isinstance(content, str):
            return len(encoding.encode(content))

        if isinstance(content, list):
            return sum(self._count_content_block_tokens(encoding, block) for block in content)

        return self._count_value_tokens(encoding, content)

    def _count_content_block_tokens(self, encoding, block: Any) -> int:
        """Estimate tokens for one content block.

        Checks run in a fixed order and the first match wins: ``text`` attribute,
        ``image_url`` type with string ``data``, ``tool_result`` type,
        ``thinking`` attribute, ``data`` attribute, ``to_anthropic()`` method,
        then a generic value count.
        """
        if hasattr(block, "text"):
            return len(encoding.encode(block.text))

        if getattr(block, "type", None) == "image_url":
            data = getattr(block, "data", None)
            if isinstance(data, str):
                return self._estimate_image_tokens(len(data))

        if getattr(block, "type", None) == "tool_result":
            content = getattr(block, "content", "")
            return self._count_message_content_tokens(encoding, content)

        if hasattr(block, "thinking"):
            return len(encoding.encode(block.thinking))

        if hasattr(block, "data"):
            return len(encoding.encode(str(block.data)))

        if hasattr(block, "to_anthropic"):
            return self._count_value_tokens(encoding, block.to_anthropic())

        return self._count_value_tokens(encoding, block)

    def _count_value_tokens(self, encoding, value: Any) -> int:
        """Recursively estimate tokens for a JSON-like value.

        Lists and dicts each add a flat 2 on top of their contents; a dict whose
        ``"type"`` is ``"image"`` with string ``source.data`` is costed via
        ``_estimate_image_tokens`` instead of being encoded.
        """
        if value is None:
            return 0

        if isinstance(value, str):
            return len(encoding.encode(value))

        if isinstance(value, (int, float, bool)):
            return len(encoding.encode(str(value)))

        if isinstance(value, list):
            return 2 + sum(self._count_value_tokens(encoding, item) for item in value)

        if isinstance(value, dict):
            if value.get("type") == "image":
                source = value.get("source") or {}
                data = source.get("data")
                if isinstance(data, str):
                    return self._estimate_image_tokens(len(data))

            total = 2
            for key, item in value.items():
                total += len(encoding.encode(str(key)))
                total += self._count_value_tokens(encoding, item)
            return total

        # Anything else is serialized to JSON (falling back to str()) and encoded.
        return len(encoding.encode(json.dumps(value, ensure_ascii=False, default=str)))

    def _estimate_image_tokens(self, base64_data_length: int) -> int:
        """Estimate image token cost based on base64 data length.

        The image is not decoded; the estimate depends only on the length of its
        base64 payload, using square-root scaling:

            tokens = 20 + int(sqrt(base64_length * 6))

        For example, 96 characters gives 20 + sqrt(576) = 44 tokens.

        Args:
            base64_data_length: Length of base64 encoded image data

        Returns:
            Estimated token count for the image
        """
        import math

        # Base cost of 20 tokens + sqrt scaling
        return 20 + int(math.sqrt(base64_data_length * 6))

    async def stream(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> AsyncContextManager:
        """Generate a streaming response from Anthropic

        Returns a context manager that provides an async iterator when entered.
        The context manager also provides get_final_message() to retrieve the
        complete message after streaming.

        Args:
            messages: Conversation history.
            system: Optional system message; omitted from the request when ``None``.
            params: Optional common generation parameters.
            **kwargs: Overrides merged on top of the prepared parameters
                (for example ``model``).
        """
        generation_params = self._build_request_params(params, kwargs)

        await self.rate_limiter.acquire()

        # Return the stream context manager
        return AnthropicStreamWrapper(
            self.async_client.messages.stream(
                messages=messages.to_anthropic(),
                **self._system_kwargs(system),
                **generation_params,
            ),
            provider_name=self.provider_name,
        )

    async def generate(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> Message:
        """Generate a complete (non-streaming) response.

        Args:
            messages: Conversation history.
            system: Optional system message; omitted from the request when ``None``.
            params: Optional common generation parameters.
            **kwargs: Overrides merged on top of the prepared parameters
                (for example ``model``).

        Returns:
            The response converted to a ``Message``; when it has usage metadata,
            its ``"provider"`` entry is set to ``provider_name``.
        """
        generation_params = self._build_request_params(params, kwargs)

        await self.rate_limiter.acquire()
        response = await self.async_client.messages.create(
            messages=messages.to_anthropic(),
            **self._system_kwargs(system),
            **generation_params,
        )
        message = Message.from_anthropic(response)
        if message.usage_metadata:
            message.usage_metadata["provider"] = self.provider_name
        return message
@@ -0,0 +1,71 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, AsyncContextManager, Dict, List, Optional, Union
3
+
4
+ from anthropic import APIError as AnthropicAPIError
5
+ from google.genai.errors import APIError as GeminiAPIError
6
+ from openai import APIError as OpenAIAPIError
7
+ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
8
+
9
+ from ..models import Message, MessageHistory, ToolDefinition
10
+ from ..ratelimit import RateLimiter
11
+ from .models import GenerationParams, ReasoningEffort, ThinkingConfig, TokenCount
12
+
13
+
14
class BaseLLMProvider(ABC):
    """Abstract base class defining the interface for LLM providers

    Subclasses implement ``count_tokens``, ``stream`` and ``generate``. The base
    class stores the credentials and retry budget and owns a ``RateLimiter``
    that implementations acquire before issuing a request.
    """

    def __init__(
        self,
        api_key: str,
        max_retries: int = 3,
        requests_per_minute: Optional[int] = None,
        tokens_per_minute: Optional[int] = None,
        base_url: Optional[str] = None,
    ):
        """Store connection settings and create the rate limiter.

        Args:
            api_key: Provider API key.
            max_retries: Attempt limit used by ``get_retry_decorator``.
            requests_per_minute: Passed to ``RateLimiter``.
            tokens_per_minute: Passed to ``RateLimiter``.
            base_url: Optional endpoint override, stored as given.
        """
        self.api_key = api_key
        self.max_retries = max_retries
        self.rate_limiter = RateLimiter(requests_per_minute, tokens_per_minute)
        self.base_url = base_url

    @abstractmethod
    async def count_tokens(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        model: Optional[str] = None,
        tools: Optional[List[ToolDefinition]] = None,
        **kwargs,
    ) -> TokenCount:
        """Return the token count for ``messages`` (plus optional system/tools)."""

    @abstractmethod
    def stream(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> AsyncContextManager:
        """Start a streaming generation and return an async context manager.

        NOTE(review): declared as a plain method here, but the Anthropic and
        Google implementations define it with ``async def``, so their callers
        must ``await`` the call to obtain the context manager.
        """

    @abstractmethod
    async def generate(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> Message:
        """Run a non-streaming generation and return the resulting ``Message``."""

    def _prepare_generation_params(self, params: Optional[GenerationParams] = None) -> Dict[str, Any]:
        """Convert common parameters to provider-specific format"""
        return {}

    def _prepare_thinking_params(self, thinking: Optional[Union[ThinkingConfig, ReasoningEffort]]) -> Dict[str, Any]:
        """Convert thinking parameters to provider-specific format"""
        return {}

    def get_retry_decorator(self):
        """Get retry decorator with exponential backoff

        Retries only on the Anthropic, OpenAI and Gemini ``APIError`` types and
        stops after ``max_retries`` attempts.
        """
        retryable_errors = (AnthropicAPIError, OpenAIAPIError, GeminiAPIError)
        return retry(
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential(multiplier=1, min=4, max=10),
            retry=retry_if_exception_type(retryable_errors),
        )
@@ -0,0 +1,157 @@
1
+ from typing import AsyncContextManager, List, Optional
2
+
3
+ from google.genai import Client as genai_client
4
+ from google.genai import types as genai_types
5
+
6
+ from ..models import Message, MessageChunk, MessageHistory, ToolDefinition
7
+ from ..tool_execution_ids import ToolExecutionIdRegistry
8
+ from .base import BaseLLMProvider
9
+ from .models import GenerationParams, TokenCount
10
+
11
+
12
class GoogleStreamWrapper:
    """Async context manager and iterator over a Gemini streaming response.

    Each raw chunk is converted to a ``MessageChunk`` while the wrapper
    accumulates the text, function calls and latest finish reason so that
    ``get_final_message()`` can assemble the complete assistant message.
    """

    def __init__(self, gemini_stream):
        """Wrap an already-started async iterator of Gemini response chunks."""
        self.gemini_stream = gemini_stream
        self.final_content = ""
        self.final_tool_calls = {}  # Maps running index -> function call, in arrival order
        self.stop_reason = None
        self.tool_execution_ids = ToolExecutionIdRegistry()

        self._closed = False

    async def __aenter__(self):
        """Return self; the underlying stream is already open."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the underlying stream if it supports ``aclose``; never suppresses exceptions."""
        if hasattr(self.gemini_stream, "aclose"):
            await self.gemini_stream.aclose()

        self._closed = True
        return False

    def __aiter__(self):
        """Return self as the async iterator."""
        return self

    async def __anext__(self):
        """Return the next chunk as a ``MessageChunk``.

        Raises:
            StopAsyncIteration: When the wrapper was closed or the stream is
                exhausted (propagated unchanged from the underlying iterator).
        """
        if self._closed:
            raise StopAsyncIteration

        chunk = await self.gemini_stream.__anext__()
        self._record_chunk(chunk)
        return MessageChunk.from_google(chunk)

    def _record_chunk(self, chunk) -> None:
        """Fold one raw chunk into the accumulated final-message state."""
        self.final_content += chunk.text or ""

        # Key by a running index rather than the per-chunk position, so that a
        # function call in a later chunk does not overwrite an earlier one.
        for function_call in chunk.function_calls or []:
            self.final_tool_calls[len(self.final_tool_calls)] = function_call

        # A chunk may carry no candidates; keep the last finish reason seen
        # instead of failing on candidates[0] or resetting it to None.
        candidates = chunk.candidates or []
        finish_reason = candidates[0].finish_reason if candidates else None
        if finish_reason:
            self.stop_reason = finish_reason.value

    async def get_final_message(self):
        """Build the assistant ``Message`` from everything accumulated so far."""
        return Message.from_google_stream(
            role="assistant",
            content=self.final_content,
            tool_calls=self.final_tool_calls,
            stop_reason=self.stop_reason,
            tool_execution_ids=self.tool_execution_ids,
        )
65
+
66
+
67
class GoogleProvider(BaseLLMProvider):
    """Provider that talks to Gemini models through the ``google.genai`` client."""

    def __init__(
        self,
        api_key: str,
        max_retries: int = 3,
        requests_per_minute: Optional[int] = None,
        tokens_per_minute: Optional[int] = None,
        base_url: Optional[str] = None,
    ):
        """Initialise the base provider and create the genai client.

        ``base_url`` is forwarded to the base class but is not passed to the
        genai client, which is constructed from ``api_key`` alone.
        """
        super().__init__(api_key, max_retries, requests_per_minute, tokens_per_minute, base_url)
        self.async_client = genai_client(api_key=api_key)

    @property
    def retry_decorator(self):
        """Get retry decorator with configured max retries"""
        return self.get_retry_decorator()

    @staticmethod
    def _build_config(system: Optional[Message], params: Optional[GenerationParams]):
        """Build the ``GenerateContentConfig`` shared by ``stream`` and ``generate``.

        ``system`` and ``params`` both default to ``None`` in the public methods,
        so a missing system message yields no system instruction and missing
        params fall back to ``GenerationParams()`` defaults. Only the text of the
        first system content block is used.
        """
        if params is None:
            params = GenerationParams()

        system_text = None
        if system is not None and system.content:
            system_text = system.content[0].text

        # NOTE(review): params.thinking is forwarded as-is; confirm that the
        # value callers pass is accepted by GenerateContentConfig.thinking_config.
        return genai_types.GenerateContentConfig(
            system_instruction=system_text,
            temperature=params.temperature,
            max_output_tokens=params.max_completion_tokens,
            tools=[t.to_google() for t in params.tools] if params.tools else None,
            thinking_config=params.thinking,
        )

    async def count_tokens(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        model: Optional[str] = None,
        tools: Optional[List[ToolDefinition]] = None,
        **kwargs,
    ) -> TokenCount:
        """Count tokens for a message history using the genai ``count_tokens`` API

        Only ``messages`` is sent; ``system``, ``tools`` and extra keyword
        arguments are accepted for interface compatibility but not included in
        the count.

        Args:
            messages: Message history to count tokens for
            system: Optional system message (not counted)
            model: Model name passed to the API
            tools: Optional tool definitions (not counted)

        Returns:
            TokenCount whose ``input_tokens`` is the API's ``total_tokens``
        """
        count = await self.async_client.aio.models.count_tokens(
            model=model,
            contents=messages.to_google(),
        )

        return TokenCount(input_tokens=count.total_tokens, output_tokens=None)

    async def stream(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> AsyncContextManager:
        """Generate a streaming response from Google

        Returns a ``GoogleStreamWrapper`` around the started stream. The model
        name is read from ``kwargs["model"]`` and is required.
        """
        config = self._build_config(system, params)

        await self.rate_limiter.acquire()

        return GoogleStreamWrapper(
            await self.async_client.aio.models.generate_content_stream(
                model=kwargs["model"], contents=messages.to_google(), config=config
            )
        )

    async def generate(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        params: Optional[GenerationParams] = None,
        **kwargs,
    ) -> Message:
        """Generate a complete (non-streaming) response from Google

        The model name is read from ``kwargs["model"]`` and is required.
        """
        config = self._build_config(system, params)

        await self.rate_limiter.acquire()

        response = await self.async_client.aio.models.generate_content(
            model=kwargs["model"], contents=messages.to_google(), config=config
        )

        return Message.from_google(response)
@@ -0,0 +1,37 @@
1
+ from dataclasses import dataclass
2
+ from enum import Enum
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+
6
class ReasoningEffort(Enum):
    """Three-level reasoning effort setting; each member's value is its lowercase name."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
10
+
11
+
12
@dataclass
class ThinkingConfig:
    """Configuration for model's thinking depth, expressed as a token budget"""

    # Token budget for thinking; defaults to 4096.
    budget_tokens: int = 4096
17
+
18
+
19
@dataclass
class GeminiThinkingConfig:
    """Thinking configuration variant for Gemini, carrying a single flag."""

    # Defaults to True. Presumably asks the model to return its thoughts — TODO confirm.
    include_thoughts: bool = True
22
+
23
+
24
@dataclass
class TokenCount:
    """Result of a token-counting call."""

    # Number of input tokens (required).
    input_tokens: int
    # Number of output tokens; None when not provided.
    output_tokens: Optional[int] = None
28
+
29
+
30
@dataclass
class GenerationParams:
    """Common parameters for text generation across providers"""

    # Sampling temperature; defaults to 1.0.
    temperature: float = 1.0
    # Upper bound on generated tokens; None leaves the choice to the provider.
    max_completion_tokens: Optional[int] = None
    # Tool definitions. Providers call ``to_anthropic()`` / ``to_google()`` on
    # each item, so these are tool-definition objects rather than plain dicts.
    tools: Optional[List[Any]] = None
    # Thinking / reasoning configuration; None disables it.
    thinking: Optional[Union[ThinkingConfig, ReasoningEffort, GeminiThinkingConfig]] = None