cade-cli 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cade_cli-0.3.3.dist-info/METADATA +151 -0
- cade_cli-0.3.3.dist-info/RECORD +44 -0
- cade_cli-0.3.3.dist-info/WHEEL +4 -0
- cade_cli-0.3.3.dist-info/entry_points.txt +2 -0
- cadecoder/__init__.py +1 -0
- cadecoder/ai/__init__.py +6 -0
- cadecoder/ai/prompts.py +572 -0
- cadecoder/cli/__init__.py +0 -0
- cadecoder/cli/app.py +147 -0
- cadecoder/cli/auth.py +483 -0
- cadecoder/cli/commands/__init__.py +5 -0
- cadecoder/cli/commands/auth.py +143 -0
- cadecoder/cli/commands/chat.py +264 -0
- cadecoder/cli/commands/mcp.py +477 -0
- cadecoder/cli/commands/tools.py +226 -0
- cadecoder/core/__init__.py +12 -0
- cadecoder/core/config.py +380 -0
- cadecoder/core/constants.py +281 -0
- cadecoder/core/errors.py +145 -0
- cadecoder/core/logging.py +148 -0
- cadecoder/core/types.py +235 -0
- cadecoder/core/utils.py +279 -0
- cadecoder/execution/__init__.py +46 -0
- cadecoder/execution/context_window.py +521 -0
- cadecoder/execution/orchestrator.py +562 -0
- cadecoder/execution/parallel.py +287 -0
- cadecoder/providers/__init__.py +60 -0
- cadecoder/providers/base.py +294 -0
- cadecoder/providers/openai.py +251 -0
- cadecoder/storage/__init__.py +0 -0
- cadecoder/storage/threads.py +489 -0
- cadecoder/templates/login_failed.html +21 -0
- cadecoder/templates/login_success.html +21 -0
- cadecoder/templates/styles.css +87 -0
- cadecoder/tools/__init__.py +19 -0
- cadecoder/tools/builtin.py +644 -0
- cadecoder/tools/filesystem.py +315 -0
- cadecoder/tools/git.py +221 -0
- cadecoder/tools/manager.py +1635 -0
- cadecoder/ui/__init__.py +7 -0
- cadecoder/ui/display.py +338 -0
- cadecoder/ui/input.py +145 -0
- cadecoder/ui/session.py +455 -0
- cadecoder/ui/state.py +20 -0
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
"""Parallel tool executor.
|
|
2
|
+
|
|
3
|
+
This module provides intelligent parallel execution of tools.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import asyncio
|
|
7
|
+
import json
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
from cadecoder.core.logging import log
|
|
11
|
+
from cadecoder.core.types import (
|
|
12
|
+
ResourceSet,
|
|
13
|
+
ToolCallList,
|
|
14
|
+
ToolExecutionResult,
|
|
15
|
+
ToolGroup,
|
|
16
|
+
ToolGroups,
|
|
17
|
+
extract_tool_output_content,
|
|
18
|
+
)
|
|
19
|
+
from cadecoder.tools.manager import ToolAuthorizationRequired
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from cadecoder.tools.manager import ToolManager
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ParallelToolExecutor:
    """Executes tool calls concurrently while avoiding resource conflicts.

    Calls are split into sequential groups by ``_analyze_tool_dependencies``.
    Calls inside a group run concurrently (at most ``max_concurrent`` at a
    time); the next group starts only after the previous one has finished.
    Calls that the tool manager flags as interactive are additionally
    serialized behind a single lock.
    """

    def __init__(self, tool_manager: "ToolManager", max_concurrent: int = 10):
        """Initialize parallel executor.

        Args:
            tool_manager: Tool manager whose ``execute(name, args)`` coroutine
                runs a tool. If it also exposes ``is_interactive_tool(name)``,
                that is used to serialize interactive tools.
            max_concurrent: Maximum number of tool calls running at once.
        """
        self.tool_manager = tool_manager
        self.max_concurrent = max_concurrent
        self._semaphore = asyncio.Semaphore(max_concurrent)
        # Interactive tools need exclusive terminal access, so only one of
        # them may run at a time even when the semaphore would allow more.
        self._interactive_lock = asyncio.Lock()

    async def execute_tools(
        self, tool_calls: ToolCallList, preserve_order: bool = True
    ) -> list[ToolExecutionResult]:
        """Execute tool calls with conflict-aware parallelization.

        Args:
            tool_calls: Tool call dicts with an ``id`` and a ``function``
                mapping holding ``name`` and ``arguments``.
            preserve_order: When True, results follow the order of
                ``tool_calls``; otherwise they follow execution-group order.

        Returns:
            Exactly one ``ToolExecutionResult`` per tool call. Failures,
            timeouts and cancellations of individual tools are reported as
            results rather than raised.

        Raises:
            asyncio.CancelledError: If this coroutine itself is cancelled; any
                still-running tool tasks of the current group are cancelled
                and awaited first.
        """
        if not tool_calls:
            return []

        independent_groups = self._analyze_tool_dependencies(tool_calls)
        self._log_execution_plan(tool_calls, independent_groups)

        all_results: list[ToolExecutionResult] = []
        for group_idx, group in enumerate(independent_groups):
            log.debug(f"Executing group {group_idx + 1} with {len(group)} tools")
            group_outcomes = await self._run_group(group)
            # gather() returns outcomes in the same order as the group.
            for tool_call, outcome in zip(group, group_outcomes):
                all_results.append(self._to_result(tool_call, outcome))

        if preserve_order:
            return self._restore_order(tool_calls, all_results)
        return all_results

    def _log_execution_plan(self, tool_calls: ToolCallList, groups: ToolGroups) -> None:
        """Log how the tool calls were split into execution groups."""
        total = len(tool_calls)
        if total == 1:
            # A lone call has no dependencies; don't report it as "sequential".
            log.debug("Executing single tool call")
        elif len(groups) == 1:
            log.info(f"Parallel execution: {total} tools in 1 group (all parallel)")
        elif len(groups) == total:
            log.info(f"Sequential execution: {total} tools (dependencies detected)")
        else:
            group_sizes = [len(g) for g in groups]
            log.info(
                f"Mixed execution: {total} tools in "
                f"{len(groups)} groups {group_sizes}"
            )

    async def _run_group(self, group: ToolGroup) -> list[Any]:
        """Run one group's tool calls concurrently and collect their outcomes.

        Returns one entry per call, in group order: either the value the
        call's task returned or the exception it ended with.
        """
        tasks = [
            asyncio.create_task(self._execute_single_tool_with_semaphore(tool_call))
            for tool_call in group
        ]
        try:
            return await asyncio.gather(*tasks, return_exceptions=True)
        except asyncio.CancelledError:
            # We were cancelled: make sure no tool task outlives this call.
            for task in tasks:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    def _to_result(self, tool_call: dict[str, Any], outcome: Any) -> ToolExecutionResult:
        """Convert one gathered outcome into a ToolExecutionResult.

        Every outcome, including ``BaseException`` subclasses such as
        ``asyncio.CancelledError``, maps to exactly one result, so no tool
        call is silently dropped.
        """
        if isinstance(outcome, ToolExecutionResult):
            return outcome

        # Use the same "unknown" id default as _execute_single_tool and
        # _restore_order so the result can be matched back to its call.
        tool_call_id = tool_call.get("id", "unknown")
        name = tool_call.get("function", {}).get("name", "unknown")

        if isinstance(outcome, asyncio.CancelledError):
            # The task was cancelled before _execute_single_tool could handle
            # it (e.g. while waiting for the semaphore or interactive lock).
            log.info(f"Tool {name} was cancelled")
            return ToolExecutionResult(
                tool_call_id=tool_call_id,
                name=name,
                content="Tool execution was cancelled",
                status="cancelled",
                error="cancelled",
                authorization_url=None,
            )

        auth_url = None
        if isinstance(outcome, ToolAuthorizationRequired):
            auth_url = outcome.authorization_url
            log.error(f"Authorization required: {outcome}")
        elif isinstance(outcome, BaseException):
            log.error(f"Tool execution failed: {outcome}")
        else:
            # Defensive: the per-call task is expected to return results only.
            log.error(f"Tool execution returned unexpected value: {outcome!r}")

        message = str(outcome)
        return ToolExecutionResult(
            tool_call_id=tool_call_id,
            name=name,
            content=message,
            status="error",
            error=message,
            authorization_url=auth_url,
        )

    async def _execute_single_tool_with_semaphore(
        self, tool_call: dict[str, Any]
    ) -> ToolExecutionResult:
        """Execute a single tool with semaphore for rate limiting."""
        async with self._semaphore:
            if await self._is_interactive_call(tool_call):
                async with self._interactive_lock:
                    return await self._execute_single_tool(tool_call)
            return await self._execute_single_tool(tool_call)

    async def _is_interactive_call(self, tool_call: dict[str, Any]) -> bool:
        """Determine if the tool call requires exclusive terminal access.

        Delegates to ``tool_manager.is_interactive_tool`` when available;
        any error from that check is treated as "not interactive".
        """
        function = tool_call.get("function", {})
        name = function.get("name", "")

        if hasattr(self.tool_manager, "is_interactive_tool"):
            try:
                return bool(self.tool_manager.is_interactive_tool(name))
            except Exception:
                return False
        return False

    @staticmethod
    def _parse_arguments(raw: Any) -> Any:
        """Decode a tool call's ``arguments`` payload.

        A missing, ``None`` or blank payload means "no arguments" and decodes
        to ``{}``. An already-decoded ``dict`` is returned unchanged. Anything
        else is parsed as JSON, which raises ``json.JSONDecodeError`` for
        malformed text or ``TypeError`` for values that are not str/bytes.
        """
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            return {}
        if isinstance(raw, dict):
            return raw
        return json.loads(raw)

    async def _execute_single_tool(
        self,
        tool_call: dict[str, Any],
        timeout: float = 120.0,
    ) -> ToolExecutionResult:
        """Execute a single tool with timeout and cancellation support.

        Args:
            tool_call: Tool call dict with ``id`` and ``function`` keys.
            timeout: Seconds to wait for the tool before giving up.

        Returns:
            A result with status ``"success"``, ``"error"`` (timeouts,
            malformed arguments, authorization and other tool failures) or
            ``"cancelled"``. Exceptions from the tool are reported, not raised.
        """
        function = tool_call.get("function", {})
        name = function.get("name", "unknown")
        tool_call_id = tool_call.get("id", "unknown")

        try:
            args = self._parse_arguments(function.get("arguments"))
            log.debug(f"Executing tool: {name} with args: {args}")

            try:
                result_content = await asyncio.wait_for(
                    self.tool_manager.execute(name, args),
                    timeout=timeout,
                )
            except (TimeoutError, asyncio.TimeoutError):
                # asyncio.TimeoutError only became an alias of the builtin
                # TimeoutError in Python 3.11; catch both so timeouts are
                # reported as such on older interpreters too.
                log.warning(f"Tool {name} timed out after {timeout}s")
                return ToolExecutionResult(
                    tool_call_id=tool_call_id,
                    name=name,
                    content=f"Tool execution timed out after {timeout} seconds",
                    status="error",
                    error="timeout",
                    authorization_url=None,
                )
            except asyncio.CancelledError:
                # Cancellation is reported as a result instead of re-raised,
                # so this task itself completes normally.
                log.info(f"Tool {name} was cancelled")
                return ToolExecutionResult(
                    tool_call_id=tool_call_id,
                    name=name,
                    content="Tool execution was cancelled",
                    status="cancelled",
                    error="cancelled",
                    authorization_url=None,
                )

            actual_content = extract_tool_output_content(result_content)

            return ToolExecutionResult(
                tool_call_id=tool_call_id,
                name=name,
                content=actual_content,
                status="success",
                error=None,
                authorization_url=None,
            )

        except Exception as e:
            auth_url = None
            if isinstance(e, ToolAuthorizationRequired):
                auth_url = e.authorization_url
                log.error(f"Authorization required for tool {name}")
            else:
                log.error(f"Tool execution failed for {name}: {e}")

            return ToolExecutionResult(
                tool_call_id=tool_call_id,
                name=name,
                content=str(e),
                status="error",
                error=str(e),
                authorization_url=auth_url,
            )

    def _analyze_tool_dependencies(self, tool_calls: ToolCallList) -> ToolGroups:
        """Analyze tool dependencies to determine execution groups.

        Calls are scanned in order and appended to the current group until one
        shares a resource (see ``_extract_resources``) with that group; such a
        call starts a new group. Groups are therefore contiguous,
        order-preserving slices of ``tool_calls`` and run one after another.
        """
        if len(tool_calls) <= 1:
            return [tool_calls] if tool_calls else []

        groups: ToolGroups = []
        current_group: ToolGroup = []
        current_group_resources: ResourceSet = set()

        for tool_call in tool_calls:
            function = tool_call.get("function", {})
            try:
                args = self._parse_arguments(function.get("arguments"))
            except (ValueError, TypeError):
                # Unparseable arguments (JSONDecodeError is a ValueError):
                # assume the call touches no detectable resources.
                args = {}
            resources = self._extract_resources(args)

            if resources & current_group_resources:
                # Conflict with the current group: close it, start a new one.
                if current_group:
                    groups.append(current_group)
                current_group = [tool_call]
                current_group_resources = set(resources)
            else:
                current_group.append(tool_call)
                current_group_resources.update(resources)

        if current_group:
            groups.append(current_group)

        return groups if groups else [[tc] for tc in tool_calls]

    def _extract_resources(self, obj: Any) -> ResourceSet:
        """Extract resource identifiers from tool arguments.

        Recursively collects every string value (inside dicts, lists and
        tuples) that contains ``/``, ``\\`` or ``.``, treating it as a path or
        identifier. This is a heuristic: unrelated strings may match, and
        paths without those characters are missed.
        """
        resources: set[str] = set()

        if isinstance(obj, str):
            if "/" in obj or "\\" in obj or "." in obj:
                resources.add(obj)
        elif isinstance(obj, dict):
            for value in obj.values():
                resources.update(self._extract_resources(value))
        elif isinstance(obj, (list, tuple)):
            for item in obj:
                resources.update(self._extract_resources(item))

        return resources

    def _restore_order(
        self, original_calls: list[dict[str, Any]], results: list[ToolExecutionResult]
    ) -> list[ToolExecutionResult]:
        """Return ``results`` re-ordered to match ``original_calls``.

        Results are matched by ``tool_call_id``. Repeated ids (including calls
        without an ``id``, which all map to ``"unknown"``) are matched in
        arrival order, so each call receives its own result instead of every
        duplicate receiving the last one. Calls with no matching result get a
        placeholder error result.
        """
        pending: dict[str, list[ToolExecutionResult]] = {}
        for result in results:
            pending.setdefault(result.tool_call_id, []).append(result)

        ordered: list[ToolExecutionResult] = []
        for call in original_calls:
            call_id = call.get("id", "unknown")
            matches = pending.get(call_id)
            if matches:
                ordered.append(matches.pop(0))
                continue

            log.warning(f"No result found for tool call {call_id}")
            ordered.append(
                ToolExecutionResult(
                    tool_call_id=call_id,
                    name=call.get("function", {}).get("name", ""),
                    content="No result",
                    status="error",
                    error="Result not found",
                    authorization_url=None,
                )
            )

        return ordered
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
"""Provider package for LLM integrations.
|
|
2
|
+
|
|
3
|
+
This package provides a clean, unified interface for different LLM providers
|
|
4
|
+
with automatic registration and easy access.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from cadecoder.core.logging import log
|
|
8
|
+
from cadecoder.providers.base import (
|
|
9
|
+
Provider,
|
|
10
|
+
ProviderError,
|
|
11
|
+
ProviderRegistry,
|
|
12
|
+
ProviderRequest,
|
|
13
|
+
ProviderResponse,
|
|
14
|
+
ProviderType,
|
|
15
|
+
StreamEvent,
|
|
16
|
+
provider_registry,
|
|
17
|
+
)
|
|
18
|
+
from cadecoder.providers.openai import OpenAIProvider
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def initialize_providers() -> None:
    """Register every provider whose credentials are available.

    Only OpenAI is wired up today. It is registered, and made the default,
    when ``OPENAI_API_KEY`` is set to a non-empty value in the environment.
    Any failure while constructing or registering it is logged as a warning
    instead of being raised.
    """
    import os

    openai_configured = bool(os.environ.get("OPENAI_API_KEY"))
    if openai_configured:
        try:
            provider_registry.register(OpenAIProvider())
            provider_registry.set_default(ProviderType.OPENAI)
            log.info("Registered OpenAI provider")
        except Exception as exc:  # best-effort: never let a provider break import
            log.warning(f"Failed to register OpenAI provider: {exc}")

    # Further providers (e.g. an Anthropic provider) are registered here
    # once they are implemented.
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Auto-initialize on import: importing this package runs initialize_providers(),
# which registers OpenAI (as the default) when OPENAI_API_KEY is set.
initialize_providers()


# Public API of the providers package.
__all__ = [
    # Base classes
    "Provider",
    "ProviderError",
    "ProviderRegistry",
    "ProviderRequest",
    "ProviderResponse",
    "ProviderType",
    "StreamEvent",
    "provider_registry",
    # Providers
    "OpenAIProvider",
    # Functions
    "initialize_providers",
]
|
|
@@ -0,0 +1,294 @@
|
|
|
1
|
+
"""Base provider interface for LLM integrations.
|
|
2
|
+
|
|
3
|
+
This module defines a clean, simple interface that all LLM providers
|
|
4
|
+
must implement, ensuring consistent behavior across different AI services.
|
|
5
|
+
Also includes provider-related utility functions.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
from collections.abc import AsyncIterator
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from pydantic import BaseModel, Field
|
|
14
|
+
|
|
15
|
+
from cadecoder.core.types import ConversationMessageDict
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ProviderType(str, Enum):
    """Supported LLM providers.

    Members also subclass ``str``, so each compares equal to its string value
    (e.g. ``ProviderType.OPENAI == "openai"``). Listing a member here does not
    imply a concrete ``Provider`` implementation is registered for it.
    """

    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    ARCADE = "arcade"
    LOCAL = "local"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class ProviderRequest(BaseModel):
    """Provider-agnostic completion request.

    Every provider adapter accepts this model and translates it into a call
    against its own native API.
    """

    # --- Core request ---
    messages: list[ConversationMessageDict] = Field(
        default=..., description="Conversation messages"
    )
    model: str = Field(default=..., description="Model to use")
    tools: list[dict[str, Any]] | None = Field(
        default=None, description="Available tools in OpenAI format"
    )
    temperature: float = Field(default=0.7, description="Response randomness")
    max_tokens: int | None = Field(default=None, description="Maximum response tokens")
    stream: bool = Field(default=False, description="Whether to stream responses")

    # --- Optional tuning ---
    tool_choice: str | dict[str, Any] | None = Field(
        default=None, description="Tool selection preference"
    )
    system_prompt: str | None = Field(default=None, description="System prompt override")
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Provider-specific metadata"
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ProviderResponse(BaseModel):
    """Provider-agnostic result of a non-streaming completion.

    Either ``content`` or ``tool_calls`` (or both) may be populated.
    ``usage`` holds raw token counters as reported by the provider.
    """

    content: str | None = Field(default=None, description="Response content")
    tool_calls: list[dict[str, Any]] | None = Field(
        default=None, description="Tool calls requested"
    )
    finish_reason: str = Field(default="stop", description="Why generation stopped")
    usage: dict[str, int] | None = Field(default=None, description="Token usage statistics")
    model: str | None = Field(default=None, description="Model that was used")
    provider: ProviderType | None = Field(default=None, description="Provider type")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class StreamEvent(BaseModel):
    """A single incremental event emitted while streaming a completion."""

    type: str = Field(
        default=..., description="Event type (see ExecutionEventType for allowed values)"
    )
    content: str | None = Field(default=None, description="Content delta")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Event metadata")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class Provider(ABC):
    """Interface every LLM provider adapter implements.

    An adapter converts a ``ProviderRequest`` into a call against its
    provider's native API. It then maps the reply back into a
    ``ProviderResponse`` (for ``complete``) or a series of ``StreamEvent``
    objects (for ``stream``).
    """

    @property
    @abstractmethod
    def provider_type(self) -> ProviderType:
        """Identify which backend this adapter talks to."""
        ...

    @property
    @abstractmethod
    def supported_models(self) -> list[str]:
        """Names of the models this adapter accepts."""
        ...

    @abstractmethod
    async def complete(self, request: ProviderRequest) -> ProviderResponse:
        """Run a non-streaming completion.

        Args:
            request: Provider-agnostic request to send.

        Returns:
            The provider's reply in unified form.

        Raises:
            ProviderError: If the request fails.
        """
        ...

    @abstractmethod
    def stream(self, request: ProviderRequest) -> AsyncIterator[StreamEvent]:
        """Run a streaming completion.

        Args:
            request: Provider-agnostic request to send.

        Yields:
            Stream events in the order they arrive.

        Raises:
            ProviderError: If the request fails.
        """
        ...

    def supports_feature(self, feature: str) -> bool:
        """Report whether this provider offers ``feature``.

        The base implementation advertises only ``"streaming"`` and
        ``"tools"``. Subclasses override it to declare other capabilities.

        Args:
            feature: Feature name such as ``"tools"``, ``"streaming"`` or ``"vision"``.

        Returns:
            True if the feature is supported.
        """
        default_features = ("streaming", "tools")
        return feature in default_features
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class ProviderError(Exception):
    """Raised when a provider request fails.

    Attributes:
        provider: Provider that raised the error, if known.
        status_code: Status code reported for the failure, if any.
        details: Extra context; an empty dict when none (or a falsy value)
            was given.
    """

    def __init__(
        self,
        message: str,
        provider: ProviderType | None = None,
        status_code: int | None = None,
        details: dict[str, Any] | None = None,
    ):
        super().__init__(message)
        # Falsy ``details`` are normalized to a fresh empty dict.
        extra_details = details or {}
        self.provider = provider
        self.status_code = status_code
        self.details = extra_details
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class ProviderRegistry:
    """Keeps provider instances keyed by their ``provider_type``."""

    def __init__(self):
        self._providers: dict[ProviderType, Provider] = {}
        self._default_provider: ProviderType | None = None

    def register(self, provider: Provider) -> None:
        """Add a provider under its own ``provider_type``.

        A provider already registered under the same type is replaced.

        Args:
            provider: Provider instance to register
        """
        key = provider.provider_type
        self._providers[key] = provider

    def get(self, provider_type: ProviderType) -> Provider | None:
        """Look up the provider registered for ``provider_type``.

        Args:
            provider_type: Type of provider to get

        Returns:
            The registered provider, or None when that type is unknown
        """
        return self._providers.get(provider_type)

    def get_default(self) -> Provider | None:
        """Return the default provider.

        Uses the explicitly chosen default when one was set. Otherwise it
        falls back to the provider of the earliest-registered type, or None
        when the registry is empty.

        Returns:
            Default provider or None
        """
        if self._default_provider:
            return self.get(self._default_provider)
        for provider in self._providers.values():
            return provider
        return None

    def set_default(self, provider_type: ProviderType) -> None:
        """Choose which registered provider ``get_default`` returns.

        Args:
            provider_type: Provider type to use as default

        Raises:
            ValueError: If no provider is registered for ``provider_type``
        """
        if provider_type in self._providers:
            self._default_provider = provider_type
            return
        raise ValueError(f"Provider {provider_type} not registered")

    @property
    def available_providers(self) -> list[ProviderType]:
        """Provider types currently registered, in registration order."""
        return [*self._providers]
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# Global registry instance. initialize_providers() in cadecoder.providers
# registers providers into this shared instance when that package is imported.
provider_registry = ProviderRegistry()


# ============================================================================
# Provider Utility Functions
# ============================================================================
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def get_total_tokens(usage_source: Any) -> int:
    """Return total tokens from a provider response or usage mapping.

    Accepts either a provider response object (with an optional ``usage``
    field) or a raw usage dictionary/object. Prefers an explicit, non-``None``
    ``total_tokens``; otherwise computes a total from known token fields.

    Known field sets, checked in this order:
    - Generic: ``total_tokens``
    - OpenAI: ``prompt_tokens``, ``completion_tokens``, optional ``reasoning_tokens``
    - Anthropic: ``input_tokens``, ``output_tokens``

    Dictionaries and attribute-style objects are handled identically. A field
    only counts as present when its value is not ``None``. For example, an
    object with ``total_tokens=None`` falls through to the per-field sums
    instead of yielding 0.

    Args:
        usage_source: Provider response or usage mapping/object

    Returns:
        Total token count as an integer (0 if unavailable or non-numeric)
    """
    if usage_source is None:
        return 0

    # If a ProviderResponse-like object, extract its usage field first
    usage = getattr(usage_source, "usage", None)
    candidate = usage if usage is not None else usage_source

    def read(name: str) -> Any:
        # Uniform access for dict-style and attribute-style usage payloads.
        if isinstance(candidate, dict):
            return candidate.get(name)
        return getattr(candidate, name, None)

    def to_int(*values: Any) -> int:
        # Sum the values, treating None/falsy as 0. A value that cannot be
        # converted makes the count unavailable (0), as a best-effort fallback.
        try:
            return sum(int(value or 0) for value in values)
        except (TypeError, ValueError, OverflowError):
            return 0

    total = read("total_tokens")
    if total is not None:
        return to_int(total)

    # OpenAI-style fields
    prompt = read("prompt_tokens")
    completion = read("completion_tokens")
    reasoning = read("reasoning_tokens")
    if prompt is not None or completion is not None or reasoning is not None:
        # Some APIs may already include reasoning in completion; when in
        # doubt, prefer inclusion to avoid undercounting.
        return to_int(prompt, completion, reasoning)

    # Anthropic-style fields
    input_tokens = read("input_tokens")
    output_tokens = read("output_tokens")
    if input_tokens is not None or output_tokens is not None:
        return to_int(input_tokens, output_tokens)

    return 0
|