voxagent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- voxagent/__init__.py +143 -0
- voxagent/_version.py +5 -0
- voxagent/agent/__init__.py +32 -0
- voxagent/agent/abort.py +178 -0
- voxagent/agent/core.py +902 -0
- voxagent/code/__init__.py +9 -0
- voxagent/mcp/__init__.py +16 -0
- voxagent/mcp/manager.py +188 -0
- voxagent/mcp/tool.py +152 -0
- voxagent/providers/__init__.py +110 -0
- voxagent/providers/anthropic.py +498 -0
- voxagent/providers/augment.py +293 -0
- voxagent/providers/auth.py +116 -0
- voxagent/providers/base.py +268 -0
- voxagent/providers/chatgpt.py +415 -0
- voxagent/providers/claudecode.py +162 -0
- voxagent/providers/cli_base.py +265 -0
- voxagent/providers/codex.py +183 -0
- voxagent/providers/failover.py +90 -0
- voxagent/providers/google.py +532 -0
- voxagent/providers/groq.py +96 -0
- voxagent/providers/ollama.py +425 -0
- voxagent/providers/openai.py +435 -0
- voxagent/providers/registry.py +175 -0
- voxagent/py.typed +1 -0
- voxagent/security/__init__.py +14 -0
- voxagent/security/events.py +75 -0
- voxagent/security/filter.py +169 -0
- voxagent/security/registry.py +87 -0
- voxagent/session/__init__.py +39 -0
- voxagent/session/compaction.py +237 -0
- voxagent/session/lock.py +103 -0
- voxagent/session/model.py +109 -0
- voxagent/session/storage.py +184 -0
- voxagent/streaming/__init__.py +52 -0
- voxagent/streaming/emitter.py +286 -0
- voxagent/streaming/events.py +255 -0
- voxagent/subagent/__init__.py +20 -0
- voxagent/subagent/context.py +124 -0
- voxagent/subagent/definition.py +172 -0
- voxagent/tools/__init__.py +32 -0
- voxagent/tools/context.py +50 -0
- voxagent/tools/decorator.py +175 -0
- voxagent/tools/definition.py +131 -0
- voxagent/tools/executor.py +109 -0
- voxagent/tools/policy.py +89 -0
- voxagent/tools/registry.py +89 -0
- voxagent/types/__init__.py +46 -0
- voxagent/types/messages.py +134 -0
- voxagent/types/run.py +176 -0
- voxagent-0.1.0.dist-info/METADATA +186 -0
- voxagent-0.1.0.dist-info/RECORD +53 -0
- voxagent-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
"""Augment (Auggie) CLI provider.
|
|
2
|
+
|
|
3
|
+
This provider wraps the Auggie CLI directly using subprocess.
|
|
4
|
+
It requires:
|
|
5
|
+
1. The auggie CLI to be installed: brew install augment-cli
|
|
6
|
+
2. Authentication via: auggie login OR a vault service providing tokens
|
|
7
|
+
|
|
8
|
+
Models available:
|
|
9
|
+
- haiku4.5: Fast and efficient responses
|
|
10
|
+
- sonnet4.5: Great for everyday tasks
|
|
11
|
+
- sonnet4: Legacy model
|
|
12
|
+
- opus4.5: Best for complex tasks (Claude Opus 4.5)
|
|
13
|
+
- gpt5: OpenAI GPT-5 legacy
|
|
14
|
+
- gpt5.1: Strong reasoning and planning
|
|
15
|
+
|
|
16
|
+
Vault Integration:
|
|
17
|
+
When a vault_service is provided (any object implementing VaultProtocol with a
|
|
18
|
+
get(key) method), the provider checks for an 'augment_access_token'. If found,
|
|
19
|
+
it sets the AUGMENT_SESSION_AUTH environment variable when invoking the CLI.
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
from __future__ import annotations
|
|
23
|
+
|
|
24
|
+
import asyncio
|
|
25
|
+
import logging
|
|
26
|
+
import os
|
|
27
|
+
from collections.abc import AsyncIterator
|
|
28
|
+
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
|
29
|
+
|
|
30
|
+
from voxagent.providers.base import (
|
|
31
|
+
AbortSignal,
|
|
32
|
+
ErrorChunk,
|
|
33
|
+
MessageEndChunk,
|
|
34
|
+
StreamChunk,
|
|
35
|
+
TextDeltaChunk,
|
|
36
|
+
)
|
|
37
|
+
from voxagent.providers.cli_base import CLINotFoundError, CLIProvider
|
|
38
|
+
from voxagent.types import Message
|
|
39
|
+
|
|
40
|
+
if TYPE_CHECKING:
|
|
41
|
+
pass # No external type imports needed
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@runtime_checkable
class VaultProtocol(Protocol):
    """Structural interface for a credential vault.

    Any object exposing ``get`` and ``exists`` with the signatures below
    satisfies this protocol, so voxagent can accept a vault implementation
    without depending on voxdomus directly.
    """

    def get(self, key: str) -> str:
        """Return the credential stored under ``key``.

        Args:
            key: Name of the credential to retrieve.

        Returns:
            The credential value.

        Raises:
            KeyError: If the credential doesn't exist.
        """

    def exists(self, key: str) -> bool:
        """Report whether a credential is stored under ``key``.

        Args:
            key: Name of the credential to check.

        Returns:
            True if the credential exists, False otherwise.
        """
|
|
76
|
+
|
|
77
|
+
logger = logging.getLogger(__name__)
|
|
78
|
+
|
|
79
|
+
# Vault key for the augment access token
|
|
80
|
+
VAULT_KEY_ACCESS_TOKEN = "augment_access_token"
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class AugmentProvider(CLIProvider):
    """Provider for Augment using the auggie CLI directly.

    This provider spawns the auggie CLI as a subprocess to avoid
    issues with the auggie-sdk's async task management.

    Optionally integrates with voxDomus vault for token storage.
    When a vault_service is provided and contains an access token,
    the provider sets AUGMENT_SESSION_AUTH for the CLI subprocess.
    """

    CLI_NAME = "auggie"
    ENV_KEY = "AUGMENT_API_TOKEN"

    SUPPORTED_MODELS = [
        "haiku4.5",
        "sonnet4.5",
        "sonnet4",
        "opus4.5",
        "gpt5",
        "gpt5.1",
    ]

    def __init__(
        self,
        model: str = "sonnet4.5",
        api_key: str | None = None,
        base_url: str | None = None,
        vault_service: VaultProtocol | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize Augment provider.

        Args:
            model: Model name (haiku4.5, sonnet4.5, opus4.5, etc.).
            api_key: Optional API token (usually from auggie login).
            base_url: Optional API URL override.
            vault_service: Optional vault service. Both ``exists(key)`` and
                ``get(key)`` are called on it when looking up the token.
            **kwargs: Additional arguments.
        """
        super().__init__(model=model, api_key=api_key, base_url=base_url, **kwargs)
        self._vault_service: VaultProtocol | None = vault_service
        # Cached result of the vault lookup; only meaningful once
        # _vault_checked is True.
        self._vault_token: str | None = None
        self._vault_checked = False

    @property
    def name(self) -> str:
        """Get the provider name."""
        return "augment"

    @property
    def models(self) -> list[str]:
        """Get supported models."""
        return self.SUPPORTED_MODELS

    @property
    def supports_tools(self) -> bool:
        """Auggie supports tools but we don't expose them."""
        return False

    @property
    def context_limit(self) -> int:
        """Approximate context limit."""
        return 200000

    def _build_cli_args(
        self,
        prompt: str,
        system: str | None = None,
    ) -> list[str]:
        """Build auggie CLI arguments.

        Uses --print mode for non-interactive one-shot execution.

        Args:
            prompt: Text passed to the CLI via ``--instruction``.
            system: Accepted for signature compatibility; it is not
                forwarded to the CLI by this method.

        Returns:
            Argument list, excluding the executable itself.
        """
        args = ["--print", "--quiet", "--instruction", prompt]

        if self._model:
            args.extend(["--model", self._model])

        return args

    def _parse_output(self, stdout: str, stderr: str) -> str:
        """Parse auggie CLI output.

        Args:
            stdout: Decoded standard output of the CLI.
            stderr: Decoded standard error of the CLI (unused here).

        Returns:
            ``stdout`` with surrounding whitespace removed.
        """
        # auggie chat outputs the response directly
        return stdout.strip()

    def _get_vault_token(self) -> str | None:
        """Get access token from vault if available.

        Returns:
            Access token string if found in vault, None otherwise.

        Note:
            Results are cached after first lookup, including a failed or
            empty lookup; the vault is consulted at most once per instance.
        """
        if self._vault_checked:
            return self._vault_token

        self._vault_checked = True

        if self._vault_service is None:
            return None

        try:
            if self._vault_service.exists(VAULT_KEY_ACCESS_TOKEN):
                self._vault_token = self._vault_service.get(VAULT_KEY_ACCESS_TOKEN)
                logger.debug("Retrieved Augment token from vault")
        except Exception as e:
            # Best effort: a broken vault must not prevent falling back to
            # whatever authentication the CLI already has.
            logger.debug("Could not retrieve Augment token from vault: %s", e)

        return self._vault_token

    async def _run_cli(
        self,
        prompt: str,
        system: str | None = None,
    ) -> str:
        """Run CLI command and return output.

        Overrides base class to inject AUGMENT_SESSION_AUTH env var
        if a vault token is available.

        Args:
            prompt: User prompt.
            system: Optional system prompt.

        Returns:
            Parsed response text.

        Raises:
            RuntimeError: If the CLI exits with a non-zero status and
                produced no usable output.
            asyncio.CancelledError: Re-raised after the subprocess has been
                killed when the awaiting task is cancelled.
        """
        cli_path = self._get_cli_path()
        args = [cli_path] + self._build_cli_args(prompt, system)

        logger.debug("Running CLI: %s", " ".join(args))

        # Build environment with vault token if available
        env = os.environ.copy()
        vault_token = self._get_vault_token()
        if vault_token:
            env["AUGMENT_SESSION_AUTH"] = vault_token
            logger.debug("Using vault token for AUGMENT_SESSION_AUTH")

        proc = await asyncio.create_subprocess_exec(
            *args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=env,
        )

        try:
            stdout_bytes, stderr_bytes = await proc.communicate()
        except asyncio.CancelledError:
            # Do not leave an orphaned auggie process running when the
            # caller cancels the request.
            try:
                proc.kill()
            except ProcessLookupError:
                # The process already exited on its own.
                pass
            await proc.wait()
            raise

        stdout = stdout_bytes.decode("utf-8", errors="replace")
        stderr = stderr_bytes.decode("utf-8", errors="replace")
        output = self._parse_output(stdout, stderr)

        if proc.returncode != 0:
            logger.warning("CLI exited with code %d: %s", proc.returncode, stderr)
            if not output:
                # A failed run with nothing on stdout would otherwise be
                # reported to the caller as an empty, successful reply.
                detail = stderr.strip() or "no error output"
                raise RuntimeError(
                    f"auggie exited with code {proc.returncode}: {detail}"
                )

        return output

    async def stream(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
        abort_signal: AbortSignal | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream a response from Auggie CLI.

        The CLI is run to completion, so at most one text chunk is yielded,
        followed by a ``MessageEndChunk``. Failures are reported as an
        ``ErrorChunk`` rather than raised.

        Note: The auggie CLI uses its own MCP tool configuration.
        Tools passed from voxDomus are not used - auggie will use
        tools from its own ~/.augment/settings.json config.

        Args:
            messages: The conversation messages.
            system: Optional system prompt, passed through to ``_run_cli``.
            tools: Ignored apart from a debug log entry.
            abort_signal: Accepted for interface compatibility; it is not
                consulted by this implementation.

        Yields:
            Stream chunks describing the response.
        """
        if tools:
            logger.debug(
                "Auggie CLI uses its own MCP tools - ignoring %d passed tools",
                len(tools),
            )

        try:
            prompt = self._messages_to_prompt(messages)
            response = await self._run_cli(prompt, system)
            if response:
                yield TextDeltaChunk(delta=response)
        except CLINotFoundError as e:
            yield ErrorChunk(error=str(e))
        except Exception as e:
            yield ErrorChunk(error=f"Auggie CLI error: {e}")

        yield MessageEndChunk()

    async def complete(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
    ) -> Message:
        """Get a complete response from Auggie CLI.

        Args:
            messages: The conversation messages.
            system: Optional system prompt.
            tools: Forwarded to ``stream``, which ignores them.

        Returns:
            An assistant message containing the concatenated text chunks.

        Raises:
            RuntimeError: If the stream reports an error chunk.
        """
        text_parts: list[str] = []

        async for chunk in self.stream(messages, system, tools):
            if isinstance(chunk, TextDeltaChunk):
                text_parts.append(chunk.delta)
            elif isinstance(chunk, ErrorChunk):
                raise RuntimeError(chunk.error)

        return Message(role="assistant", content="".join(text_parts))
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
__all__ = ["AugmentProvider"]
|
|
293
|
+
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
"""Authentication profile management with failover support.
|
|
2
|
+
|
|
3
|
+
This module provides API credential profile management with cooldown
|
|
4
|
+
and failure tracking for multi-provider LLM failover.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from datetime import datetime, timedelta
|
|
8
|
+
from enum import Enum
|
|
9
|
+
|
|
10
|
+
from pydantic import BaseModel, ConfigDict
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class FailoverReason(str, Enum):
    """Enumerates why a request is moved to a different API profile.

    Members are also ``str`` instances, so each compares equal to its
    lowercase value.
    """

    # Credential problems
    AUTH_ERROR = "auth_error"

    # Throttling and latency
    RATE_LIMIT = "rate_limit"
    TIMEOUT = "timeout"

    # Request or model problems
    CONTEXT_OVERFLOW = "context_overflow"
    MODEL_ERROR = "model_error"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FailoverError(Exception):
    """Exception signalling that another profile should be tried.

    Attributes:
        reason: The failover reason supplied at construction.
    """

    def __init__(self, reason: FailoverReason, message: str = "") -> None:
        """Create the error.

        Args:
            reason: Why the failover is happening.
            message: Optional text for the exception; when empty, the
                reason's ``value`` is used instead.
        """
        text = message if message else reason.value
        super().__init__(text)
        self.reason = reason
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AuthProfile(BaseModel):
    """API credential profile with cooldown and failure tracking."""

    # Re-validate fields whenever they are assigned after construction.
    model_config = ConfigDict(validate_assignment=True)

    id: str
    # NOTE(review): as a plain ``str`` field this value is presumably
    # included in the model's repr/serialised output - confirm that profiles
    # are never logged or dumped verbatim.
    api_key: str
    provider: str
    # Moment until which the profile must not be used; None means no cooldown.
    cooldown_until: datetime | None = None
    failure_count: int = 0
    disabled: bool = False

    def is_available(self) -> bool:
        """Check if profile is available (not in cooldown, not disabled).

        Returns:
            False when the profile is disabled or its cooldown deadline is
            still in the future, True otherwise.
        """
        if self.disabled:
            return False

        deadline = self.cooldown_until
        if deadline is None:
            return True

        # Take "now" in the deadline's own timezone. A naive deadline has
        # tzinfo None and yields a naive local time, matching set_cooldown();
        # an aware deadline assigned by a caller yields an aware "now", so
        # the comparison cannot fail with a naive/aware TypeError.
        return datetime.now(tz=deadline.tzinfo) >= deadline

    def set_cooldown(self, duration_seconds: int) -> None:
        """Set cooldown until now + duration.

        Args:
            duration_seconds: Length of the cooldown in seconds, counted from
                the current naive local time.
        """
        self.cooldown_until = datetime.now() + timedelta(seconds=duration_seconds)

    def record_failure(self) -> None:
        """Increment failure count."""
        self.failure_count += 1

    def reset_failures(self) -> None:
        """Reset failure count on success."""
        self.failure_count = 0

    def disable(self) -> None:
        """Disable this profile."""
        self.disabled = True
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class AuthProfileManager:
    """Coordinates a set of API profiles and their failover state."""

    DEFAULT_MAX_FAILURES = 3
    DEFAULT_RATE_LIMIT_COOLDOWN = 60  # seconds

    def __init__(
        self,
        profiles: list[AuthProfile],
        max_failures: int = DEFAULT_MAX_FAILURES,
        rate_limit_cooldown: int = DEFAULT_RATE_LIMIT_COOLDOWN,
    ) -> None:
        """Store the profiles and the failover thresholds.

        Args:
            profiles: Profiles to manage; copied into a new list.
            max_failures: Failure count at which an auth error disables
                a profile.
            rate_limit_cooldown: Cooldown, in seconds, applied after a
                rate-limit error.
        """
        self._max_failures = max_failures
        self._rate_limit_cooldown = rate_limit_cooldown
        self._profiles = list(profiles)

    @property
    def profiles(self) -> list[AuthProfile]:
        """The managed profiles (the manager's own list, not a copy)."""
        return self._profiles

    @property
    def max_failures(self) -> int:
        """Failure count at which an auth error disables a profile."""
        return self._max_failures

    @property
    def rate_limit_cooldown(self) -> int:
        """Cooldown in seconds applied after a rate-limit error."""
        return self._rate_limit_cooldown

    def get_available_profiles(self, provider: str | None = None) -> list[AuthProfile]:
        """Return the profiles that report themselves available.

        Args:
            provider: When given, keep only profiles whose ``provider``
                equals this value.

        Returns:
            A new list of matching profiles, in their stored order.
        """
        return [
            candidate
            for candidate in self._profiles
            if candidate.is_available()
            and (provider is None or candidate.provider == provider)
        ]

    def handle_failover(self, profile: AuthProfile, error: FailoverError) -> None:
        """Adjust ``profile`` according to the reason carried by ``error``.

        A rate-limit error starts a cooldown. An auth error records a
        failure and disables the profile once the failure count reaches
        the configured maximum. Any other reason leaves the profile as is.

        Args:
            profile: The profile that was in use when the error occurred.
            error: The failover error describing what went wrong.
        """
        reason = error.reason

        if reason == FailoverReason.RATE_LIMIT:
            profile.set_cooldown(self._rate_limit_cooldown)
            return

        if reason != FailoverReason.AUTH_ERROR:
            return

        profile.record_failure()
        if profile.failure_count >= self._max_failures:
            profile.disable()

    def record_success(self, profile: AuthProfile) -> None:
        """Clear the failure count of ``profile`` after a successful use."""
        profile.reset_failures()
|
|
116
|
+
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
"""Base provider abstract class and stream chunk types.
|
|
2
|
+
|
|
3
|
+
This module defines the abstract interface for LLM providers:
|
|
4
|
+
- StreamChunk types for streaming responses
|
|
5
|
+
- AbortSignal for cancellation
|
|
6
|
+
- BaseProvider ABC that all providers must implement
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from collections.abc import AsyncIterator
|
|
12
|
+
from typing import Any, Literal
|
|
13
|
+
|
|
14
|
+
from pydantic import BaseModel
|
|
15
|
+
|
|
16
|
+
from voxagent.types import Message, ToolCall
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# =============================================================================
|
|
20
|
+
# Stream Chunk Types
|
|
21
|
+
# =============================================================================
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TextDeltaChunk(BaseModel):
    """Stream chunk carrying a piece of response text."""

    # Discriminator field; fixed to "text_delta" for this chunk type.
    type: Literal["text_delta"] = "text_delta"
    # The text content delta.
    delta: str
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ToolUseChunk(BaseModel):
    """Stream chunk carrying a tool call request."""

    # Discriminator field; fixed to "tool_use" for this chunk type.
    type: Literal["tool_use"] = "tool_use"
    # The tool call request.
    tool_call: ToolCall
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class MessageEndChunk(BaseModel):
    """Stream chunk signalling that the message is complete."""

    # Discriminator field; fixed to "message_end" for this chunk type.
    type: Literal["message_end"] = "message_end"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ErrorChunk(BaseModel):
    """Stream chunk reporting an error."""

    # Discriminator field; fixed to "error" for this chunk type.
    type: Literal["error"] = "error"
    # The error message.
    error: str
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Union type for all stream chunks
|
|
71
|
+
StreamChunk = TextDeltaChunk | ToolUseChunk | MessageEndChunk | ErrorChunk
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# =============================================================================
|
|
75
|
+
# Abort Signal
|
|
76
|
+
# =============================================================================
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class AbortSignal:
    """Flag object used to request that an async operation stop.

    Attributes:
        _aborted: True once ``abort`` has been called.
        _reason: Text given to the most recent ``abort`` call; empty
            until then.
    """

    def __init__(self) -> None:
        """Create a signal in the not-aborted state."""
        self._reason: str = ""
        self._aborted = False

    def abort(self, reason: str = "Aborted") -> None:
        """Mark the signal as aborted.

        Args:
            reason: Explanation recorded alongside the abort request.
        """
        self._reason = reason
        self._aborted = True

    @property
    def aborted(self) -> bool:
        """Whether an abort has been requested."""
        return self._aborted

    @property
    def reason(self) -> str:
        """The recorded abort reason."""
        return self._reason
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# =============================================================================
|
|
113
|
+
# Base Provider Abstract Class
|
|
114
|
+
# =============================================================================
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class BaseProvider(ABC):
    """Common interface that every LLM provider implements.

    Concrete providers override each abstract property and method.
    ``ENV_KEY`` names the environment variable consulted for the API key
    when none is passed to the constructor; leave it empty to disable
    that fallback.
    """

    ENV_KEY: str = ""

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the connection settings.

        Args:
            api_key: API key for authentication. Falls back to ENV_KEY env var.
            base_url: Optional base URL for API requests.
            **kwargs: Additional provider-specific arguments, kept as given.
        """
        self._kwargs = kwargs
        self._base_url = base_url
        self._api_key = api_key

    @property
    def api_key(self) -> str | None:
        """The constructor-supplied key, else the ``ENV_KEY`` variable."""
        if self._api_key is not None:
            return self._api_key
        if not self.ENV_KEY:
            return None
        return os.environ.get(self.ENV_KEY)

    @property
    def base_url(self) -> str | None:
        """The base URL given at construction, if any."""
        return self._base_url

    @property
    @abstractmethod
    def name(self) -> str:
        """Get the provider name."""

    @property
    @abstractmethod
    def models(self) -> list[str]:
        """Get the list of supported model names."""

    @property
    @abstractmethod
    def supports_tools(self) -> bool:
        """Check if the provider supports tool/function calling."""

    @property
    @abstractmethod
    def supports_streaming(self) -> bool:
        """Check if the provider supports streaming responses."""

    @property
    @abstractmethod
    def context_limit(self) -> int:
        """Get the maximum context length in tokens."""

    @abstractmethod
    async def stream(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
        abort_signal: AbortSignal | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream a response from the provider.

        Args:
            messages: The conversation messages.
            system: Optional system prompt.
            tools: Optional list of tool definitions.
            abort_signal: Optional signal to abort the stream.

        Yields:
            StreamChunk objects containing response data.
        """
        # The bare yield makes this declaration an async generator.
        yield  # type: ignore[misc]

    @abstractmethod
    async def complete(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
    ) -> Message:
        """Get a complete response from the provider.

        Args:
            messages: The conversation messages.
            system: Optional system prompt.
            tools: Optional list of tool definitions.

        Returns:
            The assistant's response message.
        """

    @abstractmethod
    def count_tokens(
        self,
        messages: list[Message],
        system: str | None = None,
    ) -> int:
        """Count tokens in the messages.

        Args:
            messages: The conversation messages.
            system: Optional system prompt.

        Returns:
            The token count.
        """

    def get_api_key(self, env_var_name: str) -> str | None:
        """Return the constructor key, else the named environment variable.

        Args:
            env_var_name: The environment variable name to check.

        Returns:
            The API key or None if not found.
        """
        explicit = self._api_key
        return explicit if explicit is not None else os.environ.get(env_var_name)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
__all__ = [
|
|
260
|
+
"AbortSignal",
|
|
261
|
+
"BaseProvider",
|
|
262
|
+
"ErrorChunk",
|
|
263
|
+
"MessageEndChunk",
|
|
264
|
+
"StreamChunk",
|
|
265
|
+
"TextDeltaChunk",
|
|
266
|
+
"ToolUseChunk",
|
|
267
|
+
]
|
|
268
|
+
|