voxagent 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- voxagent/__init__.py +143 -0
- voxagent/_version.py +5 -0
- voxagent/agent/__init__.py +32 -0
- voxagent/agent/abort.py +178 -0
- voxagent/agent/core.py +902 -0
- voxagent/code/__init__.py +9 -0
- voxagent/mcp/__init__.py +16 -0
- voxagent/mcp/manager.py +188 -0
- voxagent/mcp/tool.py +152 -0
- voxagent/providers/__init__.py +110 -0
- voxagent/providers/anthropic.py +498 -0
- voxagent/providers/augment.py +293 -0
- voxagent/providers/auth.py +116 -0
- voxagent/providers/base.py +268 -0
- voxagent/providers/chatgpt.py +415 -0
- voxagent/providers/claudecode.py +162 -0
- voxagent/providers/cli_base.py +265 -0
- voxagent/providers/codex.py +183 -0
- voxagent/providers/failover.py +90 -0
- voxagent/providers/google.py +532 -0
- voxagent/providers/groq.py +96 -0
- voxagent/providers/ollama.py +425 -0
- voxagent/providers/openai.py +435 -0
- voxagent/providers/registry.py +175 -0
- voxagent/py.typed +1 -0
- voxagent/security/__init__.py +14 -0
- voxagent/security/events.py +75 -0
- voxagent/security/filter.py +169 -0
- voxagent/security/registry.py +87 -0
- voxagent/session/__init__.py +39 -0
- voxagent/session/compaction.py +237 -0
- voxagent/session/lock.py +103 -0
- voxagent/session/model.py +109 -0
- voxagent/session/storage.py +184 -0
- voxagent/streaming/__init__.py +52 -0
- voxagent/streaming/emitter.py +286 -0
- voxagent/streaming/events.py +255 -0
- voxagent/subagent/__init__.py +20 -0
- voxagent/subagent/context.py +124 -0
- voxagent/subagent/definition.py +172 -0
- voxagent/tools/__init__.py +32 -0
- voxagent/tools/context.py +50 -0
- voxagent/tools/decorator.py +175 -0
- voxagent/tools/definition.py +131 -0
- voxagent/tools/executor.py +109 -0
- voxagent/tools/policy.py +89 -0
- voxagent/tools/registry.py +89 -0
- voxagent/types/__init__.py +46 -0
- voxagent/types/messages.py +134 -0
- voxagent/types/run.py +176 -0
- voxagent-0.1.0.dist-info/METADATA +186 -0
- voxagent-0.1.0.dist-info/RECORD +53 -0
- voxagent-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,435 @@
|
|
|
1
|
+
"""OpenAI provider implementation.
|
|
2
|
+
|
|
3
|
+
This module implements the OpenAI provider for chat completions,
|
|
4
|
+
supporting both streaming and non-streaming responses with tool calling.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
from collections.abc import AsyncIterator
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
import httpx
|
|
12
|
+
|
|
13
|
+
from voxagent.providers.base import (
|
|
14
|
+
AbortSignal,
|
|
15
|
+
BaseProvider,
|
|
16
|
+
ErrorChunk,
|
|
17
|
+
MessageEndChunk,
|
|
18
|
+
StreamChunk,
|
|
19
|
+
TextDeltaChunk,
|
|
20
|
+
ToolUseChunk,
|
|
21
|
+
)
|
|
22
|
+
from voxagent.types import Message, ToolCall
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OpenAIProvider(BaseProvider):
    """OpenAI chat completions provider.

    Talks to the OpenAI ``/chat/completions`` endpoint and supports GPT-4o,
    GPT-4-turbo, GPT-3.5-turbo, and O1 models with streaming, tool calling,
    and token counting.
    """

    ENV_KEY = "OPENAI_API_KEY"
    DEFAULT_BASE_URL = "https://api.openai.com/v1"
    # Single knob for the HTTP timeout (seconds) used by both streaming and
    # non-streaming requests; previously duplicated as a literal in two places.
    REQUEST_TIMEOUT = 60.0

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str = "gpt-4o",
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI provider.

        Args:
            api_key: OpenAI API key. Falls back to OPENAI_API_KEY env var.
            base_url: Custom base URL for API requests (e.g., for Azure or proxies).
            model: Model name to use. Defaults to "gpt-4o".
            **kwargs: Additional provider-specific arguments.
        """
        super().__init__(api_key=api_key, base_url=base_url, **kwargs)
        self._model = model

    @property
    def name(self) -> str:
        """Get the provider name."""
        return "openai"

    @property
    def model(self) -> str:
        """Get the current model name."""
        return self._model

    @property
    def models(self) -> list[str]:
        """Get the list of supported model names."""
        return ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo", "o1", "o1-mini", "o1-preview"]

    @property
    def supports_tools(self) -> bool:
        """Check if the provider supports tool/function calling."""
        return True

    @property
    def supports_streaming(self) -> bool:
        """Check if the provider supports streaming responses."""
        return True

    @property
    def context_limit(self) -> int:
        """Get the maximum context length in tokens.

        NOTE(review): returns 128000 regardless of the selected model, but
        e.g. gpt-3.5-turbo has a smaller window — confirm whether this is a
        deliberate upper bound or should be per-model.
        """
        return 128000

    def _get_api_url(self) -> str:
        """Get the API URL for chat completions."""
        base = self._base_url or self.DEFAULT_BASE_URL
        return f"{base}/chat/completions"

    def _get_headers(self) -> dict[str, str]:
        """Get HTTP headers for API requests."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _convert_messages_to_openai(
        self, messages: list[Message], system: str | None = None
    ) -> list[dict[str, Any]]:
        """Convert voxagent Messages to OpenAI message format.

        Args:
            messages: List of voxagent Message objects.
            system: Optional system prompt to prepend.

        Returns:
            List of OpenAI-format message dictionaries.
        """
        result: list[dict[str, Any]] = []

        if system:
            result.append({"role": "system", "content": system})

        for msg in messages:
            openai_msg: dict[str, Any] = {"role": msg.role}

            # Handle content. Non-text content blocks are dropped here —
            # only blocks exposing a `.text` attribute survive the flatten.
            if isinstance(msg.content, str):
                openai_msg["content"] = msg.content
            else:
                # Handle content blocks - convert to string for simplicity
                text_parts = [b.text for b in msg.content if hasattr(b, "text")]
                openai_msg["content"] = " ".join(text_parts) if text_parts else ""

            # Handle tool calls: serialize params back to the JSON string
            # form the OpenAI API expects for function arguments.
            if msg.tool_calls:
                openai_msg["tool_calls"] = [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {
                            "name": tc.name,
                            "arguments": json.dumps(tc.params),
                        },
                    }
                    for tc in msg.tool_calls
                ]

            result.append(openai_msg)

        return result

    def _convert_openai_response_to_message(self, response: dict[str, Any]) -> Message:
        """Convert OpenAI response message to voxagent Message.

        Args:
            response: OpenAI message dictionary from API response.

        Returns:
            A voxagent Message object.
        """
        role = response.get("role", "assistant")
        content = response.get("content") or ""

        tool_calls: list[ToolCall] | None = None
        if "tool_calls" in response and response["tool_calls"]:
            tool_calls = []
            for tc in response["tool_calls"]:
                params = {}
                if tc.get("function", {}).get("arguments"):
                    try:
                        params = json.loads(tc["function"]["arguments"])
                    except json.JSONDecodeError:
                        # Malformed arguments degrade to an empty param dict
                        # rather than failing the whole response.
                        params = {}
                tool_calls.append(
                    ToolCall(
                        id=tc["id"],
                        name=tc["function"]["name"],
                        params=params,
                    )
                )

        return Message(role=role, content=content, tool_calls=tool_calls)

    def _build_request_body(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
        stream: bool = False,
    ) -> dict[str, Any]:
        """Build the request body for OpenAI API.

        Args:
            messages: List of voxagent Message objects.
            system: Optional system prompt.
            tools: Optional list of tool definitions.
            stream: Whether to enable streaming.

        Returns:
            Request body dictionary.
        """
        body: dict[str, Any] = {
            "model": self._model,
            "messages": self._convert_messages_to_openai(messages, system=system),
            "stream": stream,
        }

        if tools:
            body["tools"] = tools

        return body

    async def _make_request(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
    ) -> dict[str, Any]:
        """Make a non-streaming request to the OpenAI API.

        Args:
            messages: List of voxagent Message objects.
            system: Optional system prompt.
            tools: Optional list of tool definitions.

        Returns:
            The JSON response from the API.

        Raises:
            Exception: If the API request fails (httpx raises on non-2xx).
        """
        body = self._build_request_body(messages, system=system, tools=tools, stream=False)

        async with httpx.AsyncClient() as client:
            response = await client.post(
                self._get_api_url(),
                headers=self._get_headers(),
                json=body,
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            return response.json()

    async def _make_streaming_request(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
    ) -> AsyncIterator[dict[str, Any]]:
        """Make a streaming request to the OpenAI API.

        Args:
            messages: List of voxagent Message objects.
            system: Optional system prompt.
            tools: Optional list of tool definitions.

        Yields:
            Parsed JSON chunks from the SSE stream.
        """
        body = self._build_request_body(messages, system=system, tools=tools, stream=True)

        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                self._get_api_url(),
                headers=self._get_headers(),
                json=body,
                timeout=self.REQUEST_TIMEOUT,
            ) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    if not line:
                        continue
                    if line.startswith("data: "):
                        data = line[6:]
                        # OpenAI terminates the SSE stream with a sentinel.
                        if data == "[DONE]":
                            return
                        try:
                            yield json.loads(data)
                        except json.JSONDecodeError:
                            # Skip partial/garbled SSE payloads.
                            continue

    @staticmethod
    def _pending_to_tool_call(tc_data: dict[str, Any]) -> ToolCall:
        """Convert accumulated streaming tool-call fragments into a ToolCall.

        Args:
            tc_data: Dict with "id", "name", and the concatenated
                "arguments" JSON string built up across deltas.

        Returns:
            A ToolCall; malformed argument JSON degrades to empty params.
        """
        params: dict[str, Any] = {}
        if tc_data["arguments"]:
            try:
                params = json.loads(tc_data["arguments"])
            except json.JSONDecodeError:
                params = {}
        return ToolCall(id=tc_data["id"], name=tc_data["name"], params=params)

    async def stream(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
        abort_signal: AbortSignal | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream a response from the OpenAI API.

        Args:
            messages: The conversation messages.
            system: Optional system prompt.
            tools: Optional list of tool definitions.
            abort_signal: Optional signal to abort the stream.

        Yields:
            StreamChunk objects containing response data. On any normal
            (non-aborted) completion a final MessageEndChunk is emitted;
            errors surface as an ErrorChunk.
        """
        # Tool calls arrive as incremental deltas keyed by index; accumulate
        # their argument JSON here until the stream finishes.
        pending_tool_calls: dict[int, dict[str, Any]] = {}

        try:
            async for chunk in self._make_streaming_request(messages, system=system, tools=tools):
                # Abort is checked between chunks; an aborted stream ends
                # silently without a MessageEndChunk (intentional).
                if abort_signal and abort_signal.aborted:
                    return

                choices = chunk.get("choices", [])
                if not choices:
                    continue

                choice = choices[0]
                delta = choice.get("delta", {})
                finish_reason = choice.get("finish_reason")

                # Handle content delta
                if "content" in delta and delta["content"]:
                    yield TextDeltaChunk(delta=delta["content"])

                # Handle tool call deltas
                if "tool_calls" in delta:
                    for tc_delta in delta["tool_calls"]:
                        idx = tc_delta.get("index", 0)
                        if idx not in pending_tool_calls:
                            pending_tool_calls[idx] = {
                                "id": tc_delta.get("id", ""),
                                "name": tc_delta.get("function", {}).get("name", ""),
                                "arguments": "",
                            }
                        else:
                            # Later deltas may carry the id/name the first
                            # chunk omitted.
                            if tc_delta.get("id"):
                                pending_tool_calls[idx]["id"] = tc_delta["id"]
                            if tc_delta.get("function", {}).get("name"):
                                pending_tool_calls[idx]["name"] = tc_delta["function"]["name"]

                        # Accumulate arguments (applies to first and later deltas)
                        if tc_delta.get("function", {}).get("arguments"):
                            pending_tool_calls[idx]["arguments"] += tc_delta["function"]["arguments"]

                # Handle finish
                if finish_reason:
                    # Emit any pending tool calls
                    for tc_data in pending_tool_calls.values():
                        yield ToolUseChunk(tool_call=self._pending_to_tool_call(tc_data))
                    yield MessageEndChunk()
                    return

            # Robustness fix: if the SSE stream ended without ever sending a
            # finish_reason (e.g. "[DONE]" arrived early or the server closed
            # the connection), previously no terminal chunk was emitted and
            # any accumulated tool calls were silently dropped. Flush them
            # and signal end-of-message so consumers always see closure.
            for tc_data in pending_tool_calls.values():
                yield ToolUseChunk(tool_call=self._pending_to_tool_call(tc_data))
            yield MessageEndChunk()

        except Exception as e:
            # Surface any transport/parse failure as a stream-level error chunk.
            yield ErrorChunk(error=str(e))

    async def complete(
        self,
        messages: list[Message],
        system: str | None = None,
        tools: list[Any] | None = None,
    ) -> Message:
        """Get a complete response from the OpenAI API.

        Args:
            messages: The conversation messages.
            system: Optional system prompt.
            tools: Optional list of tool definitions.

        Returns:
            The assistant's response message (empty message if the API
            returned no choices).

        Raises:
            Exception: If the API request fails.
        """
        response = await self._make_request(messages, system=system, tools=tools)
        choices = response.get("choices", [])
        if not choices:
            return Message(role="assistant", content="")

        return self._convert_openai_response_to_message(choices[0]["message"])

    def count_tokens(
        self,
        messages: list[Message],
        system: str | None = None,
    ) -> int:
        """Count tokens in the messages.

        Uses tiktoken if available, otherwise falls back to character-based estimation.

        Args:
            messages: The conversation messages.
            system: Optional system prompt.

        Returns:
            The estimated token count.
        """
        try:
            import tiktoken

            # KeyError is raised by tiktoken for models it does not know.
            encoding = tiktoken.encoding_for_model(self._model)
        except (ImportError, KeyError):
            # Fall back to character-based estimation (roughly 4 chars per token)
            total_chars = 0
            if system:
                total_chars += len(system)
            for msg in messages:
                if isinstance(msg.content, str):
                    total_chars += len(msg.content)
                else:
                    for block in msg.content:
                        if hasattr(block, "text"):
                            total_chars += len(block.text)
            return max(1, total_chars // 4)

        # Use tiktoken for accurate counting
        total_tokens = 0

        if system:
            total_tokens += len(encoding.encode(system))
            total_tokens += 4  # Overhead for system message

        for msg in messages:
            total_tokens += 4  # Message overhead
            if isinstance(msg.content, str):
                total_tokens += len(encoding.encode(msg.content))
            else:
                for block in msg.content:
                    if hasattr(block, "text"):
                        total_tokens += len(encoding.encode(block.text))

        total_tokens += 2  # Response priming
        return total_tokens
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
# Public API of this module.
__all__ = ["OpenAIProvider"]
|
|
435
|
+
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Provider registry for managing and instantiating LLM providers.
|
|
2
|
+
|
|
3
|
+
This module provides:
|
|
4
|
+
- ProviderRegistry for registering and looking up provider classes
|
|
5
|
+
- get_default_registry() for accessing a global singleton registry
|
|
6
|
+
- Exceptions for registry-related errors
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from voxagent.providers.base import BaseProvider
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# =============================================================================
|
|
15
|
+
# Exceptions
|
|
16
|
+
# =============================================================================
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ProviderNotFoundError(Exception):
    """Raised when a requested provider name is not present in the registry."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InvalidModelStringError(Exception):
    """Raised when a model string does not follow the "provider:model" form."""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# =============================================================================
|
|
32
|
+
# Provider Registry
|
|
33
|
+
# =============================================================================
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ProviderRegistry:
    """Registry mapping provider names to their BaseProvider classes.

    Supports registering, unregistering, and looking up provider classes by
    name, and instantiating a provider from a "provider:model" string.
    """

    def __init__(self) -> None:
        """Create a registry with no providers."""
        self._providers: dict[str, type[BaseProvider]] = {}

    def register(self, name: str, provider_class: type[BaseProvider]) -> None:
        """Register a provider class under *name*.

        Args:
            name: The provider name (must be non-empty, non-whitespace).
            provider_class: The provider class (must be a BaseProvider subclass).

        Raises:
            ValueError: If name is empty or whitespace-only.
            TypeError: If provider_class is not a BaseProvider subclass.
        """
        # Validate the name first so a bad name reports ValueError even when
        # the class argument is also wrong.
        if not (name and name.strip()):
            raise ValueError("Provider name cannot be empty or whitespace-only")

        is_provider_type = isinstance(provider_class, type) and issubclass(
            provider_class, BaseProvider
        )
        if not is_provider_type:
            raise TypeError("provider_class must be a subclass of BaseProvider")

        self._providers[name] = provider_class

    def unregister(self, name: str) -> None:
        """Remove a provider from the registry.

        Args:
            name: The provider name to unregister.

        Raises:
            ProviderNotFoundError: If the provider is not registered.
        """
        if name not in self._providers:
            raise ProviderNotFoundError(name)
        self._providers.pop(name)

    def is_registered(self, name: str) -> bool:
        """Return True when *name* has a registered provider class.

        Args:
            name: The provider name to check.

        Returns:
            True if the provider is registered, False otherwise.
        """
        return name in self._providers

    def get_provider_class(self, name: str) -> type[BaseProvider]:
        """Look up the provider class registered under *name*.

        Args:
            name: The provider name.

        Returns:
            The registered provider class.

        Raises:
            ProviderNotFoundError: If the provider is not found.
        """
        provider_cls = self._providers.get(name)
        if provider_cls is None:
            raise ProviderNotFoundError(name)
        return provider_cls

    def list_providers(self) -> list[str]:
        """List all registered provider names.

        Returns:
            A new list of the registered provider names.
        """
        return [*self._providers]

    def get_provider(self, model_string: str, **kwargs: Any) -> BaseProvider:
        """Parse a "provider:model" string and instantiate the provider.

        Args:
            model_string: Format "provider:model" (e.g., "openai:gpt-4o").
                Only the first colon splits: "ollama:model:latest" yields
                provider "ollama" and model "model:latest".
            **kwargs: Additional arguments to pass to provider constructor.

        Returns:
            Instantiated provider.

        Raises:
            InvalidModelStringError: If model_string is malformed.
            ProviderNotFoundError: If provider is not registered.
        """
        if not model_string or ":" not in model_string:
            raise InvalidModelStringError(model_string)

        # Split on the first colon; the remainder is the model name.
        provider_name, _, model_name = model_string.partition(":")

        # Both halves must be non-empty ("x:" and ":y" are malformed).
        if not provider_name or not model_name:
            raise InvalidModelStringError(model_string)

        provider_cls = self.get_provider_class(provider_name)
        return provider_cls(model=model_name, **kwargs)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# =============================================================================
|
|
151
|
+
# Default Registry Singleton
|
|
152
|
+
# =============================================================================
|
|
153
|
+
|
|
154
|
+
# Process-wide singleton registry, created lazily on first access.
_default_registry: ProviderRegistry | None = None


def get_default_registry() -> ProviderRegistry:
    """Return the global provider registry, creating it on first use.

    Returns:
        The singleton ProviderRegistry instance.
    """
    global _default_registry
    registry = _default_registry
    if registry is None:
        registry = ProviderRegistry()
        _default_registry = registry
    return registry
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
# Public API of this module (alphabetical).
__all__ = [
    "InvalidModelStringError",
    "ProviderNotFoundError",
    "ProviderRegistry",
    "get_default_registry",
]
|
|
175
|
+
|
voxagent/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
"""Security module for voxagent."""
|
|
2
|
+
|
|
3
|
+
from voxagent.security.events import SecurityEvent, SecurityEventEmitter
|
|
4
|
+
from voxagent.security.filter import RedactionFilter, StreamFilter
|
|
5
|
+
from voxagent.security.registry import SecretRegistry
|
|
6
|
+
|
|
7
|
+
# Security primitives re-exported at the package level (alphabetical).
__all__ = [
    "RedactionFilter",
    "SecretRegistry",
    "SecurityEvent",
    "SecurityEventEmitter",
    "StreamFilter",
]
|
|
14
|
+
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""Security events for voxagent."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any, Callable
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SecurityEvent(Enum):
    """Security-related events.

    Enum values are stable string identifiers; listeners subscribe per
    member via SecurityEventEmitter.
    """

    # Presumably emitted when a secret value is redacted from output —
    # confirm against the filter module's emit sites.
    SECRET_REDACTED = "secret_redacted"
    # Presumably emitted when a stored credential is read — confirm at emit sites.
    CREDENTIAL_ACCESSED = "credential_accessed"
    # Presumably emitted when a configured pattern matches — confirm at emit sites.
    PATTERN_MATCHED = "pattern_matched"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SecurityEventEmitter:
    """Emits security events synchronously to registered listeners.

    Listeners are plain callables taking an event-data dict; they are
    invoked in registration order.
    """

    def __init__(self) -> None:
        """Initialize an emitter with no listeners."""
        # Maps each event to its callbacks, in registration order.
        self._listeners: dict[
            SecurityEvent, list[Callable[[dict[str, Any]], None]]
        ] = {}

    def on(
        self,
        event: SecurityEvent,
        callback: Callable[[dict[str, Any]], None],
    ) -> Callable[[], None]:
        """Register a listener for an event.

        Args:
            event: The security event type to listen for.
            callback: Function to call when the event is emitted.

        Returns:
            An unsubscribe function that removes the listener.
        """
        self._listeners.setdefault(event, []).append(callback)

        def unsubscribe() -> None:
            # Delegate to off() so removal semantics live in one place
            # instead of being duplicated here.
            self.off(event, callback)

        return unsubscribe

    def emit(self, event: SecurityEvent, data: dict[str, Any]) -> None:
        """Emit an event to all registered listeners.

        Args:
            event: The security event type to emit.
            data: Event data to pass to listeners.
        """
        # Bug fix: iterate over a snapshot of the listener list. A callback
        # may (un)subscribe during emission; mutating the live list while
        # iterating it would silently skip the next listener.
        for callback in list(self._listeners.get(event, [])):
            callback(data)

    def off(
        self,
        event: SecurityEvent,
        callback: Callable[[dict[str, Any]], None],
    ) -> None:
        """Remove a specific listener; a no-op if it is not registered.

        Args:
            event: The security event type.
            callback: The callback to remove.
        """
        listeners = self._listeners.get(event)
        if listeners is not None and callback in listeners:
            listeners.remove(callback)
|
|
75
|
+
|