kolega-code 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kolega_code/__init__.py +151 -0
- kolega_code/agent/__init__.py +42 -0
- kolega_code/agent/baseagent.py +998 -0
- kolega_code/agent/browseragent.py +123 -0
- kolega_code/agent/coder.py +157 -0
- kolega_code/agent/common.py +41 -0
- kolega_code/agent/compression.py +81 -0
- kolega_code/agent/context.py +112 -0
- kolega_code/agent/conversation.py +408 -0
- kolega_code/agent/generalagent.py +146 -0
- kolega_code/agent/investigationagent.py +123 -0
- kolega_code/agent/planningagent.py +187 -0
- kolega_code/agent/prompt_provider.py +196 -0
- kolega_code/agent/prompt_templates/agents/browser.j2 +102 -0
- kolega_code/agent/prompt_templates/agents/coder_cli_mode.j2 +127 -0
- kolega_code/agent/prompt_templates/agents/general.j2 +68 -0
- kolega_code/agent/prompt_templates/agents/investigation.j2 +72 -0
- kolega_code/agent/prompt_templates/common/frontend_guidance.md +36 -0
- kolega_code/agent/prompt_templates/common/kolega_md_instructions.md +14 -0
- kolega_code/agent/prompt_templates/environment_variables/workspace_env_vars.md +11 -0
- kolega_code/agent/prompt_templates/template_guidance/expo-template.md +379 -0
- kolega_code/agent/prompt_templates/template_guidance/html-website-template.md +3 -0
- kolega_code/agent/prompt_templates/template_guidance/mern-stack-template.md +3 -0
- kolega_code/agent/prompt_templates/template_guidance/react-vite-shadcdn-template.md +182 -0
- kolega_code/agent/prompts.py +192 -0
- kolega_code/agent/tests/__init__.py +0 -0
- kolega_code/agent/tests/llm/__init__.py +0 -0
- kolega_code/agent/tests/llm/test_anthropic_token_counting.py +633 -0
- kolega_code/agent/tests/llm/test_billing_openai_cache.py +74 -0
- kolega_code/agent/tests/llm/test_client.py +773 -0
- kolega_code/agent/tests/llm/test_dashscope_mapping.py +32 -0
- kolega_code/agent/tests/llm/test_error_boundary.py +322 -0
- kolega_code/agent/tests/llm/test_exceptions.py +249 -0
- kolega_code/agent/tests/llm/test_instrumented_client.py +536 -0
- kolega_code/agent/tests/llm/test_instrumented_client_integration.py +547 -0
- kolega_code/agent/tests/llm/test_langfuse_normalization.py +39 -0
- kolega_code/agent/tests/llm/test_model_specs.py +17 -0
- kolega_code/agent/tests/llm/test_openai_cached_tokens.py +58 -0
- kolega_code/agent/tests/llm/test_openai_cached_tokens_stream.py +74 -0
- kolega_code/agent/tests/llm/test_openai_message_conversion.py +30 -0
- kolega_code/agent/tests/llm/test_openai_token_counting.py +687 -0
- kolega_code/agent/tests/llm/test_tool_execution_ids.py +193 -0
- kolega_code/agent/tests/services/__init__.py +1 -0
- kolega_code/agent/tests/services/test_browser.py +447 -0
- kolega_code/agent/tests/services/test_browser_parity.py +353 -0
- kolega_code/agent/tests/services/test_file_system.py +699 -0
- kolega_code/agent/tests/services/test_sandbox_terminal_input.py +98 -0
- kolega_code/agent/tests/services/test_terminal.py +154 -0
- kolega_code/agent/tests/services/test_terminal_command_tracking.py +385 -0
- kolega_code/agent/tests/services/test_terminal_state_serializer.py +262 -0
- kolega_code/agent/tests/test_agent_tools_inventory.py +267 -0
- kolega_code/agent/tests/test_base_agent.py +1942 -0
- kolega_code/agent/tests/test_coder_attachments.py +330 -0
- kolega_code/agent/tests/test_coder_prompt_extensions.py +61 -0
- kolega_code/agent/tests/test_commands.py +179 -0
- kolega_code/agent/tests/test_duplicate_tool_results.py +556 -0
- kolega_code/agent/tests/test_empty_message_handling.py +48 -0
- kolega_code/agent/tests/test_general_agent.py +242 -0
- kolega_code/agent/tests/test_html.py +320 -0
- kolega_code/agent/tests/test_parallel_tool_calls.py +291 -0
- kolega_code/agent/tests/test_planning_agent.py +227 -0
- kolega_code/agent/tests/test_prompt_provider.py +271 -0
- kolega_code/agent/tests/test_tool_registry.py +102 -0
- kolega_code/agent/tests/test_tools.py +549 -0
- kolega_code/agent/tests/tool_backend/__init__.py +0 -0
- kolega_code/agent/tests/tool_backend/test_agent_tool.py +356 -0
- kolega_code/agent/tests/tool_backend/test_base_tool.py +147 -0
- kolega_code/agent/tests/tool_backend/test_browser_tool.py +335 -0
- kolega_code/agent/tests/tool_backend/test_build_tool.py +93 -0
- kolega_code/agent/tests/tool_backend/test_create_file_tool.py +115 -0
- kolega_code/agent/tests/tool_backend/test_glob_tool.py +196 -0
- kolega_code/agent/tests/tool_backend/test_glob_tool_sandbox_parity.py +230 -0
- kolega_code/agent/tests/tool_backend/test_list_directory_tool.py +292 -0
- kolega_code/agent/tests/tool_backend/test_read_file_tool.py +173 -0
- kolega_code/agent/tests/tool_backend/test_replace_entire_file_tool.py +115 -0
- kolega_code/agent/tests/tool_backend/test_replace_lines_tool.py +141 -0
- kolega_code/agent/tests/tool_backend/test_search_and_replace_tool.py +174 -0
- kolega_code/agent/tests/tool_backend/test_search_codebase_tool.py +228 -0
- kolega_code/agent/tests/tool_backend/test_terminal_tool.py +482 -0
- kolega_code/agent/tests/tool_backend/test_think_hard_integration.py +189 -0
- kolega_code/agent/tests/tool_backend/test_think_hard_streaming.py +445 -0
- kolega_code/agent/tests/tool_backend/test_web_fetch_tool.py +194 -0
- kolega_code/agent/tool_backend/agent_tool.py +414 -0
- kolega_code/agent/tool_backend/apply_edit_tool.py +98 -0
- kolega_code/agent/tool_backend/apply_patch_tool.py +514 -0
- kolega_code/agent/tool_backend/base_tool.py +217 -0
- kolega_code/agent/tool_backend/browser_tool.py +271 -0
- kolega_code/agent/tool_backend/build_tool.py +93 -0
- kolega_code/agent/tool_backend/create_file_tool.py +52 -0
- kolega_code/agent/tool_backend/glob_tool.py +323 -0
- kolega_code/agent/tool_backend/list_directory_tool.py +300 -0
- kolega_code/agent/tool_backend/memory_tool.py +79 -0
- kolega_code/agent/tool_backend/read_file_tool.py +119 -0
- kolega_code/agent/tool_backend/replace_entire_file_tool.py +40 -0
- kolega_code/agent/tool_backend/replace_lines_tool.py +97 -0
- kolega_code/agent/tool_backend/search_and_replace_tool.py +146 -0
- kolega_code/agent/tool_backend/search_codebase_tool.py +377 -0
- kolega_code/agent/tool_backend/streaming_tool.py +47 -0
- kolega_code/agent/tool_backend/terminal_tool.py +643 -0
- kolega_code/agent/tool_backend/think_hard_tool.py +211 -0
- kolega_code/agent/tool_backend/web_fetch_tool.py +205 -0
- kolega_code/agent/tools.py +1704 -0
- kolega_code/agent/utils/commands.py +94 -0
- kolega_code/cli/__init__.py +1 -0
- kolega_code/cli/app.py +2756 -0
- kolega_code/cli/config.py +280 -0
- kolega_code/cli/connection.py +49 -0
- kolega_code/cli/file_index.py +147 -0
- kolega_code/cli/main.py +564 -0
- kolega_code/cli/mentions.py +155 -0
- kolega_code/cli/messages.py +89 -0
- kolega_code/cli/provider_registry.py +96 -0
- kolega_code/cli/session_store.py +207 -0
- kolega_code/cli/settings.py +87 -0
- kolega_code/cli/skills.py +409 -0
- kolega_code/cli/slash_commands.py +108 -0
- kolega_code/cli/tests/__init__.py +1 -0
- kolega_code/cli/tests/test_app.py +4251 -0
- kolega_code/cli/tests/test_cli_config.py +171 -0
- kolega_code/cli/tests/test_connection.py +26 -0
- kolega_code/cli/tests/test_file_index.py +103 -0
- kolega_code/cli/tests/test_main.py +455 -0
- kolega_code/cli/tests/test_mentions.py +108 -0
- kolega_code/cli/tests/test_session_store.py +67 -0
- kolega_code/cli/tests/test_settings.py +62 -0
- kolega_code/cli/tests/test_skills.py +157 -0
- kolega_code/cli/tests/test_slash_commands.py +88 -0
- kolega_code/cli/theme.py +180 -0
- kolega_code/config.py +154 -0
- kolega_code/events.py +202 -0
- kolega_code/llm/client.py +300 -0
- kolega_code/llm/exceptions.py +285 -0
- kolega_code/llm/instrumented_client.py +520 -0
- kolega_code/llm/models.py +1368 -0
- kolega_code/llm/providers/__init__.py +0 -0
- kolega_code/llm/providers/anthropic.py +387 -0
- kolega_code/llm/providers/base.py +71 -0
- kolega_code/llm/providers/google.py +157 -0
- kolega_code/llm/providers/models.py +37 -0
- kolega_code/llm/providers/openai.py +363 -0
- kolega_code/llm/ratelimit.py +40 -0
- kolega_code/llm/specs.py +67 -0
- kolega_code/llm/tool_execution_ids.py +18 -0
- kolega_code/models/__init__.py +9 -0
- kolega_code/models/sandbox_terminal_state.py +47 -0
- kolega_code/runtime.py +50 -0
- kolega_code/sandbox/README.md +200 -0
- kolega_code/sandbox/__init__.py +21 -0
- kolega_code/sandbox/async_filesystem.py +475 -0
- kolega_code/sandbox/base.py +297 -0
- kolega_code/sandbox/browser.py +25 -0
- kolega_code/sandbox/event_loop.py +43 -0
- kolega_code/sandbox/filesystem.py +341 -0
- kolega_code/sandbox/local.py +118 -0
- kolega_code/sandbox/serializer.py +175 -0
- kolega_code/sandbox/terminal.py +868 -0
- kolega_code/sandbox/utils.py +216 -0
- kolega_code/services/base.py +255 -0
- kolega_code/services/browser.py +444 -0
- kolega_code/services/file_system.py +749 -0
- kolega_code/services/html.py +221 -0
- kolega_code/services/terminal.py +903 -0
- kolega_code/tools/__init__.py +22 -0
- kolega_code/tools/core.py +33 -0
- kolega_code/tools/definitions.py +81 -0
- kolega_code/tools/registry.py +73 -0
- kolega_code-0.1.0.dist-info/METADATA +157 -0
- kolega_code-0.1.0.dist-info/RECORD +171 -0
- kolega_code-0.1.0.dist-info/WHEEL +4 -0
- kolega_code-0.1.0.dist-info/entry_points.txt +2 -0
- kolega_code-0.1.0.dist-info/licenses/LICENSE +21 -0
kolega_code/events.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
"""Agent events: the event model, connection-manager contract, and emitter.
|
|
2
|
+
|
|
3
|
+
AgentEvent is the wire format broadcast to hosts; AgentConnectionManager is
|
|
4
|
+
the abstract transport hosts implement; AgentEventEmitter is the agent-side
|
|
5
|
+
helper that constructs and broadcasts events.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import abc
|
|
9
|
+
import uuid
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import Any, Callable, Dict, Literal, Optional
|
|
13
|
+
|
|
14
|
+
from pydantic import BaseModel, Field
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class AgentEvent(BaseModel):
    """Wire-format event that an agent broadcasts to connected hosts.

    ``uuid`` and ``timestamp`` are generated automatically when the caller does
    not supply them. ``event_type`` and ``sender`` are the only required fields.
    """

    # Per-instance identity and creation time; both are auto-populated.
    uuid: str = Field(default_factory=lambda: str(uuid.uuid4()))
    timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())

    # Closed set of event kinds; any other value fails pydantic validation.
    event_type: Literal[
        "system_message",
        "chat_message",
        "log_message",
        "terminal_command",
        "terminal_output",
        "terminal_launched",
        "terminal_closed",
        "browser_launched",
        "browser_closed",
        "status_update",
        "llm_status_update",
        "credit_alert",
        "llm_context_update",
        "tool_streaming_update",
        "memory_suggestions",
    ]

    # Who produced the event and, optionally, who it is addressed to.
    sender: str
    recipient: Optional[str] = None

    # Free-form payload; a fresh dict per instance.
    content: dict = Field(default_factory=dict)

    # True while the event is one piece of an in-progress stream.
    is_streaming: bool = False

    # Optional sub-agent dispatch metadata attached by the emitter.
    sub_agent_info: Optional[dict] = None
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class AgentStatus(Enum):
    """Lifecycle state of an agent's content generation.

    Members:
        STOPPED: no generation is in progress.
        GENERATING: the agent is currently producing content.
        INTERRUPT_REQUESTED: a stop has been requested, but generation has
            not fully wound down yet.
    """

    # Idle: nothing is being generated.
    STOPPED = "stopped"
    # Actively producing output.
    GENERATING = "generating"
    # Stop requested; still transitioning toward STOPPED.
    INTERRUPT_REQUESTED = "interrupt_requested"
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class AgentConnectionManager(abc.ABC):
    """Abstract transport that hosts implement to deliver agent events.

    The contract is keyed by workspace ID, thread ID and (for connect and
    disconnect) connection type. Concrete subclasses must implement every
    abstract method below before they can be instantiated.
    """

    @abc.abstractmethod
    async def connect(
        self,
        websocket: Any,
        workspace_id: str,
        thread_id: str,
        connection_type: str,
        user_info: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Connect a client to a specific workspace, thread and connection type.

        Args:
            websocket: The WebSocket connection
            workspace_id: ID of the workspace to connect to
            thread_id: ID of the thread to connect to
            connection_type: Type of connection ('chat', 'terminal', or 'logs')
            user_info: Optional user information dictionary
        """

    @abc.abstractmethod
    def disconnect(self, websocket: Any, workspace_id: str, thread_id: str, connection_type: str) -> None:
        """
        Disconnect a client from a specific workspace, thread and connection type.

        Args:
            websocket: The WebSocket connection
            workspace_id: ID of the workspace to disconnect from
            thread_id: ID of the thread to disconnect from
            connection_type: Type of connection ('chat', 'terminal', or 'logs')
        """

    @abc.abstractmethod
    async def broadcast_event(self, event: AgentEvent, workspace_id: str, thread_id: str) -> None:
        """
        Broadcast an agent event to all connected clients for a thread.

        This is the single delivery path for every event type (chat messages,
        status updates, context-window updates, ...), not only chat messages.

        Args:
            event: The event to broadcast
            workspace_id: ID of the workspace
            thread_id: ID of the thread
        """

    @abc.abstractmethod
    def get_connection_count(self, workspace_id: str, thread_id: str) -> dict:
        """
        Get the number of connections for each type for a thread.

        Args:
            workspace_id: ID of the workspace
            thread_id: ID of the thread

        Returns:
            Dictionary with connection counts; the exact keys are defined by
            the implementation (presumably one per connection type).
        """
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class AgentEventEmitter:
    """Constructs and broadcasts AgentEvents for one agent instance.

    Every event is tagged with this emitter's ``sender`` and delivered through
    ``connection_manager.broadcast_event`` to its workspace/thread.
    """

    def __init__(
        self,
        connection_manager: AgentConnectionManager,
        workspace_id: str,
        thread_id: str,
        sender: str,
        sub_agent_info_provider: Optional[Callable[[], Optional[Dict[str, Any]]]] = None,
    ) -> None:
        """Bind the emitter to a transport and a (workspace, thread) target.

        Args:
            connection_manager: Transport used to broadcast every event.
            workspace_id: Workspace the events are addressed to.
            thread_id: Thread the events are addressed to.
            sender: Value placed in each event's ``sender`` field.
            sub_agent_info_provider: Optional zero-argument callable whose
                result is attached as ``sub_agent_info`` on chat messages.
        """
        self.connection_manager = connection_manager
        self.workspace_id = workspace_id
        self.thread_id = thread_id
        self.sender = sender
        # Callable rather than a value: sub-agent dispatch metadata is set on the
        # agent after construction and changes per tool call.
        self._sub_agent_info_provider = sub_agent_info_provider

    async def emit(self, event: AgentEvent) -> None:
        """Broadcast an already-built event to this emitter's workspace/thread."""
        await self.connection_manager.broadcast_event(event, self.workspace_id, self.thread_id)

    async def chat(
        self,
        message_type: str,
        content: str,
        *,
        is_streaming: bool = False,
        tool_description: Optional[str] = None,
        tool_call_id: Optional[str] = None,
    ) -> None:
        """Send a chat_message event (responses, tool calls/results/errors).

        Args:
            message_type: Stored as ``content["message_type"]``.
            content: Message text, stored as ``content["text"]``.
            is_streaming: Whether this is part of an in-progress stream.
            tool_description: Optional tool description for tool messages.
            tool_call_id: Optional ID linking the message to a tool call.
        """
        # Resolved at send time so the latest sub-agent metadata is attached.
        sub_agent_info = self._sub_agent_info_provider() if self._sub_agent_info_provider else None

        await self.emit(
            AgentEvent(
                sender=self.sender,
                event_type="chat_message",
                content={
                    "message_type": message_type,
                    "text": content,
                    "tool_description": tool_description,
                    "tool_call_id": tool_call_id,
                },
                timestamp=datetime.now().isoformat(),
                is_streaming=is_streaming,
                sub_agent_info=sub_agent_info,
            )
        )

    async def context_update(
        self,
        *,
        input_tokens: int,
        model_context_length: int,
        compression_threshold: float,
        alert_level: str,
        message: Optional[str],
    ) -> None:
        """Send an llm_context_update event describing context-window usage.

        Args:
            input_tokens: Tokens currently used by the prompt/context.
            model_context_length: The model's context window size in tokens.
            compression_threshold: Fraction (0-1) of the window at which
                compression triggers; reported as a percentage.
            alert_level: Severity label forwarded to hosts.
            message: Optional human-readable message.

        A non-positive ``model_context_length`` is reported as 0.0% usage
        instead of raising ZeroDivisionError.
        """
        # Guard: a missing/zero context length must not crash event emission.
        if model_context_length > 0:
            usage_percentage = (input_tokens / model_context_length) * 100
        else:
            usage_percentage = 0.0

        await self.emit(
            AgentEvent(
                event_type="llm_context_update",
                sender=self.sender,
                content={
                    "input_tokens": input_tokens,
                    "max_tokens": model_context_length,
                    "usage_percentage": round(usage_percentage, 1),
                    "alert_level": alert_level,
                    "message": message,
                    "compression_threshold": compression_threshold * 100,  # Convert to percentage
                    "will_compress_at": int(model_context_length * compression_threshold),
                },
            )
        )

    async def llm_status(self, status: str, message: str) -> None:
        """Send an llm_status_update event (e.g. provider overload notices).

        Args:
            status: Status label stored as ``content["status"]``.
            message: Human-readable text stored as ``content["message"]``.
        """
        await self.emit(
            AgentEvent(
                sender=self.sender,
                event_type="llm_status_update",
                content={"status": status, "message": message},
                timestamp=datetime.now().isoformat(),
                is_streaming=False,
            )
        )
|
|
kolega_code/llm/client.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
"""Client library for interacting with Large Language Model (LLM) providers.
|
|
2
|
+
|
|
3
|
+
This module provides a unified interface for making requests to various LLM services
|
|
4
|
+
including Anthropic, OpenAI, and Google. The main class LLMClient handles:
|
|
5
|
+
|
|
6
|
+
- Provider-specific API initialization and authentication
|
|
7
|
+
- Rate limiting and retry logic
|
|
8
|
+
- Message formatting and parsing
|
|
9
|
+
- Streaming and non-streaming completions
|
|
10
|
+
- Token counting and budget management
|
|
11
|
+
- Tool/function calling capabilities
|
|
12
|
+
|
|
13
|
+
The client abstracts away provider differences to give applications a clean, consistent
|
|
14
|
+
API for using any supported LLM service interchangeably.
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
client = LLMClient(
|
|
18
|
+
provider='openai',
|
|
19
|
+
api_key='sk-...',
|
|
20
|
+
max_retries=3,
|
|
21
|
+
requests_per_minute=60
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
response = await client.generate(
|
|
25
|
+
messages=message_history,
|
|
26
|
+
system=system_message,
|
|
27
|
+
temperature=0.7
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
The module also provides supporting classes and types for working with messages,
|
|
31
|
+
tools, and provider-specific parameters in a standardized way.
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from typing import Any, AsyncContextManager, Coroutine, Dict, List, Optional, Type, Union
|
|
35
|
+
|
|
36
|
+
from .exceptions import map_to_llm_error
|
|
37
|
+
from .models import Message, MessageHistory, ToolDefinition
|
|
38
|
+
from .providers.anthropic import AnthropicProvider
|
|
39
|
+
from .providers.google import GoogleProvider
|
|
40
|
+
from .providers.models import GeminiThinkingConfig, GenerationParams, ReasoningEffort, ThinkingConfig, TokenCount
|
|
41
|
+
from .providers.openai import OpenAIProvider
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class LLMClient:
    """A unified client for interacting with different LLM providers.

    This class provides a consistent interface for making requests to various LLM providers
    including Anthropic, OpenAI, Google, and others. It handles:

    - Provider-specific API initialization and authentication
    - Rate limiting and retry logic
    - Message formatting and parsing
    - Streaming and non-streaming completions
    - Token counting and budget management
    - Tool/function calling capabilities

    The client abstracts away provider differences to give a clean, unified API for
    applications to use any supported LLM service interchangeably.
    """

    # Provider names served through AnthropicProvider; they share its thinking config.
    _ANTHROPIC_COMPATIBLE_PROVIDERS = ("anthropic", "moonshot", "deepseek")

    def __init__(
        self,
        provider: str,
        api_key: str,
        max_retries: int = 3,
        requests_per_minute: Optional[int] = None,
        tokens_per_minute: Optional[int] = None,
    ):
        """Create a client bound to a single provider.

        Args:
            provider: Provider name (case-insensitive), e.g. 'anthropic', 'openai', 'google'.
            api_key: API key forwarded to the provider; kept in a private attribute.
            max_retries: Maximum number of retries for failed API calls.
            requests_per_minute: Optional request-rate limit forwarded to the provider.
            tokens_per_minute: Optional token-rate limit forwarded to the provider.

        Raises:
            LLMError: If the provider is unsupported or fails to initialize.
        """
        self.provider_name = provider.lower()
        self._api_key = api_key  # Store API key privately
        self.provider = self._initialize_provider(
            provider,
            max_retries=max_retries,
            requests_per_minute=requests_per_minute,
            tokens_per_minute=tokens_per_minute,
        )

    def _initialize_provider(
        self,
        provider: str,
        max_retries: int = 3,
        requests_per_minute: Optional[int] = None,
        tokens_per_minute: Optional[int] = None,
    ) -> Union[AnthropicProvider, OpenAIProvider, GoogleProvider]:
        """Initialize the appropriate LLM provider based on the provider name.

        Args:
            provider (str): Name of the LLM provider to initialize (e.g. 'anthropic', 'openai', 'google')
            max_retries (int, optional): Maximum number of retries for failed API calls. Defaults to 3.
            requests_per_minute (int, optional): Maximum number of requests allowed per minute. Defaults to None.
            tokens_per_minute (int, optional): Maximum number of tokens allowed per minute. Defaults to None.

        Returns:
            Union[AnthropicProvider, OpenAIProvider, GoogleProvider]: Initialized provider instance

        Raises:
            LLMError: If an unsupported provider name is specified or initialization fails
        """
        try:
            name = provider.lower()

            # Built per call (not hoisted to class level) so the provider classes are
            # looked up from module globals at call time and stay patchable.
            providers: Dict[str, Type[Union[AnthropicProvider, OpenAIProvider, GoogleProvider]]] = {
                "anthropic": AnthropicProvider,
                "openai": OpenAIProvider,
                "together": OpenAIProvider,
                "groq": OpenAIProvider,
                "fireworks": OpenAIProvider,
                "llama": OpenAIProvider,
                "google": GoogleProvider,
                "xai": OpenAIProvider,
                "dashscope": OpenAIProvider,
                "moonshot": AnthropicProvider,
                "deepseek": AnthropicProvider,
            }

            base_urls: Dict[str, str] = {
                "openai": "https://api.openai.com/v1/",
                "together": "https://api.together.xyz/v1",
                "groq": "https://api.groq.com/openai/v1",
                "fireworks": "https://api.fireworks.ai/inference/v1",
                "llama": "http://localhost:8000/v1",
                "google": "https://generativelanguage.googleapis.com",
                "xai": "https://api.x.ai/v1",
                "dashscope": "https://dashscope-intl.aliyuncs.com/compatible-mode/v1",
                "moonshot": "https://api.moonshot.ai/anthropic",
                "deepseek": "https://api.deepseek.com/anthropic",
            }

            provider_class = providers.get(name)
            if not provider_class:
                raise ValueError(f"Unsupported provider: {provider}")

            provider_kwargs = {}
            # AnthropicProvider serves several vendors, so it must know which one it is.
            if provider_class is AnthropicProvider:
                provider_kwargs["provider_name"] = name

            return provider_class(
                api_key=self._api_key,
                max_retries=max_retries,
                requests_per_minute=requests_per_minute,
                tokens_per_minute=tokens_per_minute,
                base_url=base_urls.get(name),
                **provider_kwargs,
            )
        except Exception as e:
            raise map_to_llm_error(e, provider) from e

    async def count_tokens(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        tools: Optional[List[ToolDefinition]] = None,
        **kwargs: Any,
    ) -> TokenCount:
        """Count tokens for a list of messages and optional system message.

        Args:
            messages (MessageHistory): The message history to count tokens for
            system (Optional[Message]): Optional system message to include in token count
            tools (Optional[List[ToolDefinition]]): Tool definitions to include in the token
                count; None is treated as an empty list
            **kwargs: Additional provider-specific arguments. An optional ``model`` entry is
                popped and passed to the provider as a string (or None when absent)

        Returns:
            TokenCount: Object containing input token count and optionally output token count
                depending on provider capabilities

        Raises:
            LLMError: Any LLM-related error that occurs during token counting
        """
        try:
            raw_model = kwargs.pop("model", None)
            # Only stringify a real value: str(None) would send the literal "None" as a model.
            model: Optional[str] = None if raw_model is None else str(raw_model)
            return await self.provider.count_tokens(
                messages=messages,
                system=system,
                model=model,
                # Fresh list per call instead of a shared mutable default.
                tools=tools if tools is not None else [],
                **kwargs,
            )
        except Exception as e:
            raise map_to_llm_error(e, self.provider_name) from e

    def _prepare_thinking_param(
        self, thinking: Optional[Union[int, str]] = None
    ) -> Optional[Union[ThinkingConfig, ReasoningEffort, GeminiThinkingConfig]]:
        """Convert thinking parameter to appropriate provider-specific thinking configuration.

        Args:
            thinking (Optional[Union[int, str]]): The thinking parameter to convert. None always
                yields None (no thinking configuration) for every provider. Otherwise:
                - Anthropic-compatible ('anthropic', 'moonshot', 'deepseek'): an integer sets
                  budget_tokens; any other value uses ThinkingConfig's default budget
                - Google: an integer enables thoughts; any other value disables them
                - OpenAI-compatible: a string matching a ReasoningEffort value
                  (case-insensitive); any other value falls back to MEDIUM

        Returns:
            Optional[Union[ThinkingConfig, ReasoningEffort, GeminiThinkingConfig]]: The provider-specific
                thinking configuration, or None if thinking parameter is None.
        """
        if thinking is None:
            return None

        if self.provider_name in self._ANTHROPIC_COMPATIBLE_PROVIDERS:
            if isinstance(thinking, int):
                return ThinkingConfig(budget_tokens=thinking)
            return ThinkingConfig()  # Use default budget_tokens
        elif self.provider_name == "google":
            if isinstance(thinking, int):
                return GeminiThinkingConfig(include_thoughts=True)
            return GeminiThinkingConfig(include_thoughts=False)
        else:  # OpenAI
            if isinstance(thinking, str) and thinking.lower() in [e.value for e in ReasoningEffort]:
                return ReasoningEffort(thinking.lower())
            return ReasoningEffort.MEDIUM  # Default to medium

    async def generate(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        temperature: float = 1.0,
        max_completion_tokens: Optional[int] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        thinking: Optional[Union[int, str]] = None,
        params: Optional[GenerationParams] = None,
        **kwargs: Any,
    ) -> Message:
        """Generate a complete response from the LLM provider.

        Args:
            messages (MessageHistory): The conversation history to generate from
            system (Optional[Message]): Optional system message to prepend
            temperature (float): Sampling temperature, higher is more random (default: 1.0)
            max_completion_tokens (Optional[int]): Maximum tokens to generate in response
            tools (Optional[List[Dict[str, Any]]]): List of tool definitions for function calling
            thinking (Optional[Union[int, str]]): Provider-specific thinking configuration
                (None sends no thinking configuration):
                - Anthropic: Integer budget_tokens; other non-None values use the default budget
                - Google: Integer to enable thoughts; other non-None values disable them
                - OpenAI: String matching ReasoningEffort; other non-None values use MEDIUM
            params (Optional[GenerationParams]): Override all parameters with a GenerationParams object
            **kwargs: Additional provider-specific parameters

        Returns:
            Message: The complete generated response message

        Raises:
            LLMError: Any LLM-related error that occurs during generation
        """
        try:
            if params is None:
                params = GenerationParams(
                    temperature=temperature,
                    max_completion_tokens=max_completion_tokens,
                    tools=tools,
                    thinking=self._prepare_thinking_param(thinking),
                )
            return await self.provider.generate(messages, system, params, **kwargs)
        except Exception as e:
            raise map_to_llm_error(e, self.provider_name) from e

    def stream(
        self,
        messages: MessageHistory,
        system: Optional[Message] = None,
        temperature: float = 1.0,
        max_completion_tokens: Optional[int] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        thinking: Optional[Union[int, str]] = None,
        params: Optional[GenerationParams] = None,
        **kwargs: Any,
    ) -> Union[AsyncContextManager[Any], Coroutine[Any, Any, AsyncContextManager[Any]]]:
        """Generate a streaming response from the LLM provider.

        Args:
            messages (MessageHistory): The conversation history to generate from
            system (Optional[Message]): Optional system message to prepend
            temperature (float): Sampling temperature, higher is more random (default: 1.0)
            max_completion_tokens (Optional[int]): Maximum tokens to generate in response
            tools (Optional[List[Dict[str, Any]]]): List of tool definitions for function calling
            thinking (Optional[Union[int, str]]): Provider-specific thinking configuration
                (None sends no thinking configuration):
                - Anthropic: Integer budget_tokens; other non-None values use the default budget
                - Google: Integer to enable thoughts; other non-None values disable them
                - OpenAI: String matching ReasoningEffort; other non-None values use MEDIUM
            params (Optional[GenerationParams]): Override all parameters with a GenerationParams object
            **kwargs: Additional provider-specific parameters

        Returns:
            Whatever ``self.provider.stream`` returns: an async context manager, or a
            coroutine resolving to one, that yields message chunks when streamed

        Raises:
            LLMError: Any LLM-related error that occurs during stream initialization
                (errors raised later while iterating the stream are not mapped here)
        """
        try:
            if params is None:
                params = GenerationParams(
                    temperature=temperature,
                    max_completion_tokens=max_completion_tokens,
                    tools=tools,
                    thinking=self._prepare_thinking_param(thinking),
                )

            # Return the appropriate stream type for the provider
            return self.provider.stream(messages, system, params, **kwargs)
        except Exception as e:
            raise map_to_llm_error(e, self.provider_name) from e
|