foundry_mcp-0.8.22-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of foundry-mcp might be problematic.
- foundry_mcp/__init__.py +13 -0
- foundry_mcp/cli/__init__.py +67 -0
- foundry_mcp/cli/__main__.py +9 -0
- foundry_mcp/cli/agent.py +96 -0
- foundry_mcp/cli/commands/__init__.py +37 -0
- foundry_mcp/cli/commands/cache.py +137 -0
- foundry_mcp/cli/commands/dashboard.py +148 -0
- foundry_mcp/cli/commands/dev.py +446 -0
- foundry_mcp/cli/commands/journal.py +377 -0
- foundry_mcp/cli/commands/lifecycle.py +274 -0
- foundry_mcp/cli/commands/modify.py +824 -0
- foundry_mcp/cli/commands/plan.py +640 -0
- foundry_mcp/cli/commands/pr.py +393 -0
- foundry_mcp/cli/commands/review.py +667 -0
- foundry_mcp/cli/commands/session.py +472 -0
- foundry_mcp/cli/commands/specs.py +686 -0
- foundry_mcp/cli/commands/tasks.py +807 -0
- foundry_mcp/cli/commands/testing.py +676 -0
- foundry_mcp/cli/commands/validate.py +982 -0
- foundry_mcp/cli/config.py +98 -0
- foundry_mcp/cli/context.py +298 -0
- foundry_mcp/cli/logging.py +212 -0
- foundry_mcp/cli/main.py +44 -0
- foundry_mcp/cli/output.py +122 -0
- foundry_mcp/cli/registry.py +110 -0
- foundry_mcp/cli/resilience.py +178 -0
- foundry_mcp/cli/transcript.py +217 -0
- foundry_mcp/config.py +1454 -0
- foundry_mcp/core/__init__.py +144 -0
- foundry_mcp/core/ai_consultation.py +1773 -0
- foundry_mcp/core/batch_operations.py +1202 -0
- foundry_mcp/core/cache.py +195 -0
- foundry_mcp/core/capabilities.py +446 -0
- foundry_mcp/core/concurrency.py +898 -0
- foundry_mcp/core/context.py +540 -0
- foundry_mcp/core/discovery.py +1603 -0
- foundry_mcp/core/error_collection.py +728 -0
- foundry_mcp/core/error_store.py +592 -0
- foundry_mcp/core/health.py +749 -0
- foundry_mcp/core/intake.py +933 -0
- foundry_mcp/core/journal.py +700 -0
- foundry_mcp/core/lifecycle.py +412 -0
- foundry_mcp/core/llm_config.py +1376 -0
- foundry_mcp/core/llm_patterns.py +510 -0
- foundry_mcp/core/llm_provider.py +1569 -0
- foundry_mcp/core/logging_config.py +374 -0
- foundry_mcp/core/metrics_persistence.py +584 -0
- foundry_mcp/core/metrics_registry.py +327 -0
- foundry_mcp/core/metrics_store.py +641 -0
- foundry_mcp/core/modifications.py +224 -0
- foundry_mcp/core/naming.py +146 -0
- foundry_mcp/core/observability.py +1216 -0
- foundry_mcp/core/otel.py +452 -0
- foundry_mcp/core/otel_stubs.py +264 -0
- foundry_mcp/core/pagination.py +255 -0
- foundry_mcp/core/progress.py +387 -0
- foundry_mcp/core/prometheus.py +564 -0
- foundry_mcp/core/prompts/__init__.py +464 -0
- foundry_mcp/core/prompts/fidelity_review.py +691 -0
- foundry_mcp/core/prompts/markdown_plan_review.py +515 -0
- foundry_mcp/core/prompts/plan_review.py +627 -0
- foundry_mcp/core/providers/__init__.py +237 -0
- foundry_mcp/core/providers/base.py +515 -0
- foundry_mcp/core/providers/claude.py +472 -0
- foundry_mcp/core/providers/codex.py +637 -0
- foundry_mcp/core/providers/cursor_agent.py +630 -0
- foundry_mcp/core/providers/detectors.py +515 -0
- foundry_mcp/core/providers/gemini.py +426 -0
- foundry_mcp/core/providers/opencode.py +718 -0
- foundry_mcp/core/providers/opencode_wrapper.js +308 -0
- foundry_mcp/core/providers/package-lock.json +24 -0
- foundry_mcp/core/providers/package.json +25 -0
- foundry_mcp/core/providers/registry.py +607 -0
- foundry_mcp/core/providers/test_provider.py +171 -0
- foundry_mcp/core/providers/validation.py +857 -0
- foundry_mcp/core/rate_limit.py +427 -0
- foundry_mcp/core/research/__init__.py +68 -0
- foundry_mcp/core/research/memory.py +528 -0
- foundry_mcp/core/research/models.py +1234 -0
- foundry_mcp/core/research/providers/__init__.py +40 -0
- foundry_mcp/core/research/providers/base.py +242 -0
- foundry_mcp/core/research/providers/google.py +507 -0
- foundry_mcp/core/research/providers/perplexity.py +442 -0
- foundry_mcp/core/research/providers/semantic_scholar.py +544 -0
- foundry_mcp/core/research/providers/tavily.py +383 -0
- foundry_mcp/core/research/workflows/__init__.py +25 -0
- foundry_mcp/core/research/workflows/base.py +298 -0
- foundry_mcp/core/research/workflows/chat.py +271 -0
- foundry_mcp/core/research/workflows/consensus.py +539 -0
- foundry_mcp/core/research/workflows/deep_research.py +4142 -0
- foundry_mcp/core/research/workflows/ideate.py +682 -0
- foundry_mcp/core/research/workflows/thinkdeep.py +405 -0
- foundry_mcp/core/resilience.py +600 -0
- foundry_mcp/core/responses.py +1624 -0
- foundry_mcp/core/review.py +366 -0
- foundry_mcp/core/security.py +438 -0
- foundry_mcp/core/spec.py +4119 -0
- foundry_mcp/core/task.py +2463 -0
- foundry_mcp/core/testing.py +839 -0
- foundry_mcp/core/validation.py +2357 -0
- foundry_mcp/dashboard/__init__.py +32 -0
- foundry_mcp/dashboard/app.py +119 -0
- foundry_mcp/dashboard/components/__init__.py +17 -0
- foundry_mcp/dashboard/components/cards.py +88 -0
- foundry_mcp/dashboard/components/charts.py +177 -0
- foundry_mcp/dashboard/components/filters.py +136 -0
- foundry_mcp/dashboard/components/tables.py +195 -0
- foundry_mcp/dashboard/data/__init__.py +11 -0
- foundry_mcp/dashboard/data/stores.py +433 -0
- foundry_mcp/dashboard/launcher.py +300 -0
- foundry_mcp/dashboard/views/__init__.py +12 -0
- foundry_mcp/dashboard/views/errors.py +217 -0
- foundry_mcp/dashboard/views/metrics.py +164 -0
- foundry_mcp/dashboard/views/overview.py +96 -0
- foundry_mcp/dashboard/views/providers.py +83 -0
- foundry_mcp/dashboard/views/sdd_workflow.py +255 -0
- foundry_mcp/dashboard/views/tool_usage.py +139 -0
- foundry_mcp/prompts/__init__.py +9 -0
- foundry_mcp/prompts/workflows.py +525 -0
- foundry_mcp/resources/__init__.py +9 -0
- foundry_mcp/resources/specs.py +591 -0
- foundry_mcp/schemas/__init__.py +38 -0
- foundry_mcp/schemas/intake-schema.json +89 -0
- foundry_mcp/schemas/sdd-spec-schema.json +414 -0
- foundry_mcp/server.py +150 -0
- foundry_mcp/tools/__init__.py +10 -0
- foundry_mcp/tools/unified/__init__.py +92 -0
- foundry_mcp/tools/unified/authoring.py +3620 -0
- foundry_mcp/tools/unified/context_helpers.py +98 -0
- foundry_mcp/tools/unified/documentation_helpers.py +268 -0
- foundry_mcp/tools/unified/environment.py +1341 -0
- foundry_mcp/tools/unified/error.py +479 -0
- foundry_mcp/tools/unified/health.py +225 -0
- foundry_mcp/tools/unified/journal.py +841 -0
- foundry_mcp/tools/unified/lifecycle.py +640 -0
- foundry_mcp/tools/unified/metrics.py +777 -0
- foundry_mcp/tools/unified/plan.py +876 -0
- foundry_mcp/tools/unified/pr.py +294 -0
- foundry_mcp/tools/unified/provider.py +589 -0
- foundry_mcp/tools/unified/research.py +1283 -0
- foundry_mcp/tools/unified/review.py +1042 -0
- foundry_mcp/tools/unified/review_helpers.py +314 -0
- foundry_mcp/tools/unified/router.py +102 -0
- foundry_mcp/tools/unified/server.py +565 -0
- foundry_mcp/tools/unified/spec.py +1283 -0
- foundry_mcp/tools/unified/task.py +3846 -0
- foundry_mcp/tools/unified/test.py +431 -0
- foundry_mcp/tools/unified/verification.py +520 -0
- foundry_mcp-0.8.22.dist-info/METADATA +344 -0
- foundry_mcp-0.8.22.dist-info/RECORD +153 -0
- foundry_mcp-0.8.22.dist-info/WHEEL +4 -0
- foundry_mcp-0.8.22.dist-info/entry_points.txt +3 -0
- foundry_mcp-0.8.22.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1569 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLM Provider abstraction for foundry-mcp.
|
|
3
|
+
|
|
4
|
+
Provides a unified interface for interacting with different LLM providers
|
|
5
|
+
(OpenAI, Anthropic, local models) with consistent error handling,
|
|
6
|
+
rate limiting, and observability.
|
|
7
|
+
|
|
8
|
+
Example:
|
|
9
|
+
from foundry_mcp.core.llm_provider import (
|
|
10
|
+
LLMProvider, ChatMessage, ChatRole, CompletionRequest
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
class MyProvider(LLMProvider):
|
|
14
|
+
async def complete(self, request: CompletionRequest) -> CompletionResponse:
|
|
15
|
+
# Implementation
|
|
16
|
+
pass
|
|
17
|
+
|
|
18
|
+
async def chat(self, messages: List[ChatMessage], **kwargs) -> ChatResponse:
|
|
19
|
+
# Implementation
|
|
20
|
+
pass
|
|
21
|
+
|
|
22
|
+
async def embed(self, texts: List[str], **kwargs) -> EmbeddingResponse:
|
|
23
|
+
# Implementation
|
|
24
|
+
pass
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
import logging
|
|
28
|
+
from abc import ABC, abstractmethod
|
|
29
|
+
from dataclasses import dataclass, field
|
|
30
|
+
from enum import Enum
|
|
31
|
+
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# =============================================================================
|
|
37
|
+
# Enums
|
|
38
|
+
# =============================================================================
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class ChatRole(str, Enum):
|
|
42
|
+
"""Role of a message in a chat conversation.
|
|
43
|
+
|
|
44
|
+
SYSTEM: System instructions/context
|
|
45
|
+
USER: User input
|
|
46
|
+
ASSISTANT: Model response
|
|
47
|
+
TOOL: Tool/function call result
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
SYSTEM = "system"
|
|
51
|
+
USER = "user"
|
|
52
|
+
ASSISTANT = "assistant"
|
|
53
|
+
TOOL = "tool"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class FinishReason(str, Enum):
|
|
57
|
+
"""Reason why the model stopped generating.
|
|
58
|
+
|
|
59
|
+
STOP: Natural completion (hit stop sequence or end)
|
|
60
|
+
LENGTH: Hit max_tokens limit
|
|
61
|
+
TOOL_CALL: Model wants to call a tool/function
|
|
62
|
+
CONTENT_FILTER: Filtered due to content policy
|
|
63
|
+
ERROR: Generation error occurred
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
STOP = "stop"
|
|
67
|
+
LENGTH = "length"
|
|
68
|
+
TOOL_CALL = "tool_call"
|
|
69
|
+
CONTENT_FILTER = "content_filter"
|
|
70
|
+
ERROR = "error"
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# =============================================================================
|
|
74
|
+
# Data Classes - Messages
|
|
75
|
+
# =============================================================================
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
|
|
79
|
+
class ToolCall:
|
|
80
|
+
"""A tool/function call requested by the model.
|
|
81
|
+
|
|
82
|
+
Attributes:
|
|
83
|
+
id: Unique identifier for this tool call
|
|
84
|
+
name: Name of the tool/function to call
|
|
85
|
+
arguments: JSON-encoded arguments for the call
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
id: str
|
|
89
|
+
name: str
|
|
90
|
+
arguments: str # JSON string
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
|
|
94
|
+
class ChatMessage:
|
|
95
|
+
"""A message in a chat conversation.
|
|
96
|
+
|
|
97
|
+
Attributes:
|
|
98
|
+
role: The role of the message sender
|
|
99
|
+
content: The text content of the message
|
|
100
|
+
name: Optional name for the sender (for multi-user chats)
|
|
101
|
+
tool_calls: List of tool calls if role is ASSISTANT
|
|
102
|
+
tool_call_id: ID of the tool call this responds to (if role is TOOL)
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
role: ChatRole
|
|
106
|
+
content: Optional[str] = None
|
|
107
|
+
name: Optional[str] = None
|
|
108
|
+
tool_calls: Optional[List[ToolCall]] = None
|
|
109
|
+
tool_call_id: Optional[str] = None
|
|
110
|
+
|
|
111
|
+
def to_dict(self) -> Dict[str, Any]:
|
|
112
|
+
"""Convert to dictionary for API calls."""
|
|
113
|
+
result: Dict[str, Any] = {"role": self.role.value}
|
|
114
|
+
if self.content is not None:
|
|
115
|
+
result["content"] = self.content
|
|
116
|
+
if self.name:
|
|
117
|
+
result["name"] = self.name
|
|
118
|
+
if self.tool_calls:
|
|
119
|
+
result["tool_calls"] = [
|
|
120
|
+
{"id": tc.id, "type": "function", "function": {"name": tc.name, "arguments": tc.arguments}}
|
|
121
|
+
for tc in self.tool_calls
|
|
122
|
+
]
|
|
123
|
+
if self.tool_call_id:
|
|
124
|
+
result["tool_call_id"] = self.tool_call_id
|
|
125
|
+
return result
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# =============================================================================
|
|
129
|
+
# Data Classes - Requests
|
|
130
|
+
# =============================================================================
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
@dataclass
|
|
134
|
+
class CompletionRequest:
|
|
135
|
+
"""Request for text completion (non-chat).
|
|
136
|
+
|
|
137
|
+
Attributes:
|
|
138
|
+
prompt: The prompt to complete
|
|
139
|
+
max_tokens: Maximum tokens to generate
|
|
140
|
+
temperature: Sampling temperature (0-2)
|
|
141
|
+
top_p: Nucleus sampling parameter
|
|
142
|
+
stop: Stop sequences
|
|
143
|
+
model: Model identifier (optional, uses provider default)
|
|
144
|
+
"""
|
|
145
|
+
|
|
146
|
+
prompt: str
|
|
147
|
+
max_tokens: int = 256
|
|
148
|
+
temperature: float = 0.7
|
|
149
|
+
top_p: float = 1.0
|
|
150
|
+
stop: Optional[List[str]] = None
|
|
151
|
+
model: Optional[str] = None
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
@dataclass
|
|
155
|
+
class ChatRequest:
|
|
156
|
+
"""Request for chat completion.
|
|
157
|
+
|
|
158
|
+
Attributes:
|
|
159
|
+
messages: The conversation messages
|
|
160
|
+
max_tokens: Maximum tokens to generate
|
|
161
|
+
temperature: Sampling temperature (0-2)
|
|
162
|
+
top_p: Nucleus sampling parameter
|
|
163
|
+
stop: Stop sequences
|
|
164
|
+
model: Model identifier (optional, uses provider default)
|
|
165
|
+
tools: Tool/function definitions for function calling
|
|
166
|
+
tool_choice: How to handle tool selection ('auto', 'none', or specific)
|
|
167
|
+
"""
|
|
168
|
+
|
|
169
|
+
messages: List[ChatMessage]
|
|
170
|
+
max_tokens: int = 1024
|
|
171
|
+
temperature: float = 0.7
|
|
172
|
+
top_p: float = 1.0
|
|
173
|
+
stop: Optional[List[str]] = None
|
|
174
|
+
model: Optional[str] = None
|
|
175
|
+
tools: Optional[List[Dict[str, Any]]] = None
|
|
176
|
+
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
@dataclass
|
|
180
|
+
class EmbeddingRequest:
|
|
181
|
+
"""Request for text embeddings.
|
|
182
|
+
|
|
183
|
+
Attributes:
|
|
184
|
+
texts: List of texts to embed
|
|
185
|
+
model: Model identifier (optional, uses provider default)
|
|
186
|
+
dimensions: Output dimension size (if supported by model)
|
|
187
|
+
"""
|
|
188
|
+
|
|
189
|
+
texts: List[str]
|
|
190
|
+
model: Optional[str] = None
|
|
191
|
+
dimensions: Optional[int] = None
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# =============================================================================
|
|
195
|
+
# Data Classes - Responses
|
|
196
|
+
# =============================================================================
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
@dataclass
|
|
200
|
+
class TokenUsage:
|
|
201
|
+
"""Token usage statistics.
|
|
202
|
+
|
|
203
|
+
Attributes:
|
|
204
|
+
prompt_tokens: Tokens in the input
|
|
205
|
+
completion_tokens: Tokens in the output
|
|
206
|
+
total_tokens: Total tokens used
|
|
207
|
+
"""
|
|
208
|
+
|
|
209
|
+
prompt_tokens: int = 0
|
|
210
|
+
completion_tokens: int = 0
|
|
211
|
+
total_tokens: int = 0
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@dataclass
|
|
215
|
+
class CompletionResponse:
|
|
216
|
+
"""Response from text completion.
|
|
217
|
+
|
|
218
|
+
Attributes:
|
|
219
|
+
text: The generated text
|
|
220
|
+
finish_reason: Why generation stopped
|
|
221
|
+
usage: Token usage statistics
|
|
222
|
+
model: Model that generated the response
|
|
223
|
+
raw_response: Original API response (for debugging)
|
|
224
|
+
"""
|
|
225
|
+
|
|
226
|
+
text: str
|
|
227
|
+
finish_reason: FinishReason = FinishReason.STOP
|
|
228
|
+
usage: TokenUsage = field(default_factory=TokenUsage)
|
|
229
|
+
model: Optional[str] = None
|
|
230
|
+
raw_response: Optional[Dict[str, Any]] = None
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
@dataclass
|
|
234
|
+
class ChatResponse:
|
|
235
|
+
"""Response from chat completion.
|
|
236
|
+
|
|
237
|
+
Attributes:
|
|
238
|
+
message: The assistant's response message
|
|
239
|
+
finish_reason: Why generation stopped
|
|
240
|
+
usage: Token usage statistics
|
|
241
|
+
model: Model that generated the response
|
|
242
|
+
raw_response: Original API response (for debugging)
|
|
243
|
+
"""
|
|
244
|
+
|
|
245
|
+
message: ChatMessage
|
|
246
|
+
finish_reason: FinishReason = FinishReason.STOP
|
|
247
|
+
usage: TokenUsage = field(default_factory=TokenUsage)
|
|
248
|
+
model: Optional[str] = None
|
|
249
|
+
raw_response: Optional[Dict[str, Any]] = None
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@dataclass
|
|
253
|
+
class EmbeddingResponse:
|
|
254
|
+
"""Response from embedding request.
|
|
255
|
+
|
|
256
|
+
Attributes:
|
|
257
|
+
embeddings: List of embedding vectors
|
|
258
|
+
usage: Token usage statistics
|
|
259
|
+
model: Model that generated the embeddings
|
|
260
|
+
dimensions: Dimension size of embeddings
|
|
261
|
+
"""
|
|
262
|
+
|
|
263
|
+
embeddings: List[List[float]]
|
|
264
|
+
usage: TokenUsage = field(default_factory=TokenUsage)
|
|
265
|
+
model: Optional[str] = None
|
|
266
|
+
dimensions: Optional[int] = None
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
# =============================================================================
|
|
270
|
+
# Exceptions
|
|
271
|
+
# =============================================================================
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class LLMError(Exception):
|
|
275
|
+
"""Base exception for LLM operations.
|
|
276
|
+
|
|
277
|
+
Attributes:
|
|
278
|
+
message: Human-readable error description
|
|
279
|
+
provider: Name of the provider that raised the error
|
|
280
|
+
retryable: Whether the operation can be retried
|
|
281
|
+
status_code: HTTP status code if applicable
|
|
282
|
+
"""
|
|
283
|
+
|
|
284
|
+
def __init__(
|
|
285
|
+
self,
|
|
286
|
+
message: str,
|
|
287
|
+
*,
|
|
288
|
+
provider: Optional[str] = None,
|
|
289
|
+
retryable: bool = False,
|
|
290
|
+
status_code: Optional[int] = None,
|
|
291
|
+
):
|
|
292
|
+
super().__init__(message)
|
|
293
|
+
self.provider = provider
|
|
294
|
+
self.retryable = retryable
|
|
295
|
+
self.status_code = status_code
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
class RateLimitError(LLMError):
|
|
299
|
+
"""Rate limit exceeded error.
|
|
300
|
+
|
|
301
|
+
Attributes:
|
|
302
|
+
retry_after: Seconds to wait before retrying
|
|
303
|
+
"""
|
|
304
|
+
|
|
305
|
+
def __init__(
|
|
306
|
+
self,
|
|
307
|
+
message: str = "Rate limit exceeded",
|
|
308
|
+
*,
|
|
309
|
+
provider: Optional[str] = None,
|
|
310
|
+
retry_after: Optional[float] = None,
|
|
311
|
+
):
|
|
312
|
+
super().__init__(message, provider=provider, retryable=True, status_code=429)
|
|
313
|
+
self.retry_after = retry_after
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class AuthenticationError(LLMError):
|
|
317
|
+
"""Authentication failed error."""
|
|
318
|
+
|
|
319
|
+
def __init__(
|
|
320
|
+
self,
|
|
321
|
+
message: str = "Authentication failed",
|
|
322
|
+
*,
|
|
323
|
+
provider: Optional[str] = None,
|
|
324
|
+
):
|
|
325
|
+
super().__init__(message, provider=provider, retryable=False, status_code=401)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
class InvalidRequestError(LLMError):
|
|
329
|
+
"""Invalid request error (bad parameters, etc.)."""
|
|
330
|
+
|
|
331
|
+
def __init__(
|
|
332
|
+
self,
|
|
333
|
+
message: str,
|
|
334
|
+
*,
|
|
335
|
+
provider: Optional[str] = None,
|
|
336
|
+
param: Optional[str] = None,
|
|
337
|
+
):
|
|
338
|
+
super().__init__(message, provider=provider, retryable=False, status_code=400)
|
|
339
|
+
self.param = param
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
class ModelNotFoundError(LLMError):
|
|
343
|
+
"""Requested model not found or not accessible."""
|
|
344
|
+
|
|
345
|
+
def __init__(
|
|
346
|
+
self,
|
|
347
|
+
message: str,
|
|
348
|
+
*,
|
|
349
|
+
provider: Optional[str] = None,
|
|
350
|
+
model: Optional[str] = None,
|
|
351
|
+
):
|
|
352
|
+
super().__init__(message, provider=provider, retryable=False, status_code=404)
|
|
353
|
+
self.model = model
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
class ContentFilterError(LLMError):
|
|
357
|
+
"""Content was filtered due to policy violation."""
|
|
358
|
+
|
|
359
|
+
def __init__(
|
|
360
|
+
self,
|
|
361
|
+
message: str = "Content filtered",
|
|
362
|
+
*,
|
|
363
|
+
provider: Optional[str] = None,
|
|
364
|
+
):
|
|
365
|
+
super().__init__(message, provider=provider, retryable=False, status_code=400)
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
# =============================================================================
|
|
369
|
+
# Abstract Base Class
|
|
370
|
+
# =============================================================================
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
class LLMProvider(ABC):
|
|
374
|
+
"""Abstract base class for LLM providers.
|
|
375
|
+
|
|
376
|
+
Defines the interface that all LLM provider implementations must follow.
|
|
377
|
+
Provides consistent methods for completion, chat, and embedding operations.
|
|
378
|
+
|
|
379
|
+
Attributes:
|
|
380
|
+
name: Provider name (e.g., 'openai', 'anthropic', 'local')
|
|
381
|
+
default_model: Default model to use if not specified in requests
|
|
382
|
+
|
|
383
|
+
Example:
|
|
384
|
+
class OpenAIProvider(LLMProvider):
|
|
385
|
+
name = "openai"
|
|
386
|
+
default_model = "gpt-4"
|
|
387
|
+
|
|
388
|
+
async def complete(self, request: CompletionRequest) -> CompletionResponse:
|
|
389
|
+
# Call OpenAI API
|
|
390
|
+
pass
|
|
391
|
+
"""
|
|
392
|
+
|
|
393
|
+
name: str = "base"
|
|
394
|
+
default_model: str = ""
|
|
395
|
+
|
|
396
|
+
@abstractmethod
|
|
397
|
+
async def complete(self, request: CompletionRequest) -> CompletionResponse:
|
|
398
|
+
"""Generate a text completion.
|
|
399
|
+
|
|
400
|
+
Args:
|
|
401
|
+
request: Completion request with prompt and parameters
|
|
402
|
+
|
|
403
|
+
Returns:
|
|
404
|
+
CompletionResponse with generated text
|
|
405
|
+
|
|
406
|
+
Raises:
|
|
407
|
+
LLMError: On API or generation errors
|
|
408
|
+
RateLimitError: If rate limited
|
|
409
|
+
AuthenticationError: If authentication fails
|
|
410
|
+
"""
|
|
411
|
+
pass
|
|
412
|
+
|
|
413
|
+
@abstractmethod
|
|
414
|
+
async def chat(self, request: ChatRequest) -> ChatResponse:
|
|
415
|
+
"""Generate a chat completion.
|
|
416
|
+
|
|
417
|
+
Args:
|
|
418
|
+
request: Chat request with messages and parameters
|
|
419
|
+
|
|
420
|
+
Returns:
|
|
421
|
+
ChatResponse with assistant message
|
|
422
|
+
|
|
423
|
+
Raises:
|
|
424
|
+
LLMError: On API or generation errors
|
|
425
|
+
RateLimitError: If rate limited
|
|
426
|
+
AuthenticationError: If authentication fails
|
|
427
|
+
"""
|
|
428
|
+
pass
|
|
429
|
+
|
|
430
|
+
@abstractmethod
|
|
431
|
+
async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
|
|
432
|
+
"""Generate embeddings for texts.
|
|
433
|
+
|
|
434
|
+
Args:
|
|
435
|
+
request: Embedding request with texts
|
|
436
|
+
|
|
437
|
+
Returns:
|
|
438
|
+
EmbeddingResponse with embedding vectors
|
|
439
|
+
|
|
440
|
+
Raises:
|
|
441
|
+
LLMError: On API or generation errors
|
|
442
|
+
RateLimitError: If rate limited
|
|
443
|
+
AuthenticationError: If authentication fails
|
|
444
|
+
"""
|
|
445
|
+
pass
|
|
446
|
+
|
|
447
|
+
async def stream_chat(
|
|
448
|
+
self, request: ChatRequest
|
|
449
|
+
) -> AsyncIterator[ChatResponse]:
|
|
450
|
+
"""Stream chat completion tokens.
|
|
451
|
+
|
|
452
|
+
Default implementation yields a single complete response.
|
|
453
|
+
Providers can override for true streaming support.
|
|
454
|
+
|
|
455
|
+
Args:
|
|
456
|
+
request: Chat request with messages and parameters
|
|
457
|
+
|
|
458
|
+
Yields:
|
|
459
|
+
ChatResponse chunks as they are generated
|
|
460
|
+
|
|
461
|
+
Raises:
|
|
462
|
+
LLMError: On API or generation errors
|
|
463
|
+
"""
|
|
464
|
+
response = await self.chat(request)
|
|
465
|
+
yield response
|
|
466
|
+
|
|
467
|
+
async def stream_complete(
|
|
468
|
+
self, request: CompletionRequest
|
|
469
|
+
) -> AsyncIterator[CompletionResponse]:
|
|
470
|
+
"""Stream completion tokens.
|
|
471
|
+
|
|
472
|
+
Default implementation yields a single complete response.
|
|
473
|
+
Providers can override for true streaming support.
|
|
474
|
+
|
|
475
|
+
Args:
|
|
476
|
+
request: Completion request with prompt and parameters
|
|
477
|
+
|
|
478
|
+
Yields:
|
|
479
|
+
CompletionResponse chunks as they are generated
|
|
480
|
+
|
|
481
|
+
Raises:
|
|
482
|
+
LLMError: On API or generation errors
|
|
483
|
+
"""
|
|
484
|
+
response = await self.complete(request)
|
|
485
|
+
yield response
|
|
486
|
+
|
|
487
|
+
def count_tokens(self, text: str, model: Optional[str] = None) -> int:
|
|
488
|
+
"""Count tokens in text.
|
|
489
|
+
|
|
490
|
+
Default implementation provides a rough estimate.
|
|
491
|
+
Providers should override with accurate tokenization.
|
|
492
|
+
|
|
493
|
+
Args:
|
|
494
|
+
text: Text to count tokens for
|
|
495
|
+
model: Model to use for tokenization (optional)
|
|
496
|
+
|
|
497
|
+
Returns:
|
|
498
|
+
Estimated token count
|
|
499
|
+
"""
|
|
500
|
+
# Rough estimate: ~4 characters per token for English
|
|
501
|
+
return len(text) // 4
|
|
502
|
+
|
|
503
|
+
def validate_request(self, request: Union[CompletionRequest, ChatRequest, EmbeddingRequest]) -> None:
|
|
504
|
+
"""Validate a request before sending.
|
|
505
|
+
|
|
506
|
+
Override to add provider-specific validation.
|
|
507
|
+
|
|
508
|
+
Args:
|
|
509
|
+
request: Request to validate
|
|
510
|
+
|
|
511
|
+
Raises:
|
|
512
|
+
InvalidRequestError: If request is invalid
|
|
513
|
+
"""
|
|
514
|
+
if isinstance(request, CompletionRequest):
|
|
515
|
+
if not request.prompt:
|
|
516
|
+
raise InvalidRequestError("Prompt cannot be empty", provider=self.name)
|
|
517
|
+
if request.max_tokens < 1:
|
|
518
|
+
raise InvalidRequestError("max_tokens must be positive", provider=self.name, param="max_tokens")
|
|
519
|
+
|
|
520
|
+
elif isinstance(request, ChatRequest):
|
|
521
|
+
if not request.messages:
|
|
522
|
+
raise InvalidRequestError("Messages cannot be empty", provider=self.name)
|
|
523
|
+
if request.max_tokens < 1:
|
|
524
|
+
raise InvalidRequestError("max_tokens must be positive", provider=self.name, param="max_tokens")
|
|
525
|
+
|
|
526
|
+
elif isinstance(request, EmbeddingRequest):
|
|
527
|
+
if not request.texts:
|
|
528
|
+
raise InvalidRequestError("Texts cannot be empty", provider=self.name)
|
|
529
|
+
|
|
530
|
+
def get_model(self, requested: Optional[str] = None) -> str:
|
|
531
|
+
"""Get the model to use for a request.
|
|
532
|
+
|
|
533
|
+
Args:
|
|
534
|
+
requested: Explicitly requested model (optional)
|
|
535
|
+
|
|
536
|
+
Returns:
|
|
537
|
+
Model identifier to use
|
|
538
|
+
"""
|
|
539
|
+
return requested or self.default_model
|
|
540
|
+
|
|
541
|
+
async def health_check(self) -> bool:
|
|
542
|
+
"""Check if the provider is healthy and accessible.
|
|
543
|
+
|
|
544
|
+
Default implementation tries a minimal chat request.
|
|
545
|
+
Providers can override with more efficient checks.
|
|
546
|
+
|
|
547
|
+
Returns:
|
|
548
|
+
True if provider is healthy, False otherwise
|
|
549
|
+
"""
|
|
550
|
+
try:
|
|
551
|
+
request = ChatRequest(
|
|
552
|
+
messages=[ChatMessage(role=ChatRole.USER, content="ping")],
|
|
553
|
+
max_tokens=1,
|
|
554
|
+
)
|
|
555
|
+
await self.chat(request)
|
|
556
|
+
return True
|
|
557
|
+
except Exception as e:
|
|
558
|
+
logger.warning(f"Health check failed for {self.name}: {e}")
|
|
559
|
+
return False
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
# =============================================================================
|
|
563
|
+
# OpenAI Provider Implementation
|
|
564
|
+
# =============================================================================
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
class OpenAIProvider(LLMProvider):
|
|
568
|
+
"""OpenAI API provider implementation.
|
|
569
|
+
|
|
570
|
+
Supports GPT-4, GPT-3.5-turbo, and embedding models via the OpenAI API.
|
|
571
|
+
|
|
572
|
+
Attributes:
|
|
573
|
+
name: Provider identifier ('openai')
|
|
574
|
+
default_model: Default chat model ('gpt-4')
|
|
575
|
+
default_embedding_model: Default embedding model
|
|
576
|
+
api_key: OpenAI API key
|
|
577
|
+
organization: Optional organization ID
|
|
578
|
+
base_url: API base URL (for proxies/Azure)
|
|
579
|
+
|
|
580
|
+
Example:
|
|
581
|
+
provider = OpenAIProvider(api_key="sk-...")
|
|
582
|
+
response = await provider.chat(ChatRequest(
|
|
583
|
+
messages=[ChatMessage(role=ChatRole.USER, content="Hello!")]
|
|
584
|
+
))
|
|
585
|
+
"""
|
|
586
|
+
|
|
587
|
+
name: str = "openai"
|
|
588
|
+
default_model: str = "gpt-4"
|
|
589
|
+
default_embedding_model: str = "text-embedding-3-small"
|
|
590
|
+
|
|
591
|
+
def __init__(
|
|
592
|
+
self,
|
|
593
|
+
api_key: Optional[str] = None,
|
|
594
|
+
organization: Optional[str] = None,
|
|
595
|
+
base_url: Optional[str] = None,
|
|
596
|
+
default_model: Optional[str] = None,
|
|
597
|
+
default_embedding_model: Optional[str] = None,
|
|
598
|
+
):
|
|
599
|
+
"""Initialize the OpenAI provider.
|
|
600
|
+
|
|
601
|
+
Args:
|
|
602
|
+
api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
|
|
603
|
+
organization: Optional organization ID
|
|
604
|
+
base_url: API base URL (defaults to OpenAI's API)
|
|
605
|
+
default_model: Override default chat model
|
|
606
|
+
default_embedding_model: Override default embedding model
|
|
607
|
+
"""
|
|
608
|
+
import os
|
|
609
|
+
|
|
610
|
+
self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
|
|
611
|
+
self.organization = organization or os.environ.get("OPENAI_ORGANIZATION")
|
|
612
|
+
self.base_url = base_url or "https://api.openai.com/v1"
|
|
613
|
+
|
|
614
|
+
if default_model:
|
|
615
|
+
self.default_model = default_model
|
|
616
|
+
if default_embedding_model:
|
|
617
|
+
self.default_embedding_model = default_embedding_model
|
|
618
|
+
|
|
619
|
+
self._client: Optional[Any] = None
|
|
620
|
+
|
|
621
|
+
def _get_client(self) -> Any:
|
|
622
|
+
"""Get or create the OpenAI client (lazy initialization)."""
|
|
623
|
+
if self._client is None:
|
|
624
|
+
try:
|
|
625
|
+
from openai import AsyncOpenAI
|
|
626
|
+
except ImportError:
|
|
627
|
+
raise LLMError(
|
|
628
|
+
"openai package not installed. Install with: pip install openai",
|
|
629
|
+
provider=self.name,
|
|
630
|
+
)
|
|
631
|
+
|
|
632
|
+
if not self.api_key:
|
|
633
|
+
raise AuthenticationError(
|
|
634
|
+
"OpenAI API key not provided. Set OPENAI_API_KEY or pass api_key.",
|
|
635
|
+
provider=self.name,
|
|
636
|
+
)
|
|
637
|
+
|
|
638
|
+
self._client = AsyncOpenAI(
|
|
639
|
+
api_key=self.api_key,
|
|
640
|
+
organization=self.organization,
|
|
641
|
+
base_url=self.base_url,
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
return self._client
|
|
645
|
+
|
|
646
|
+
def _handle_api_error(self, error: Exception) -> None:
|
|
647
|
+
"""Convert OpenAI errors to LLMError types."""
|
|
648
|
+
error_str = str(error)
|
|
649
|
+
error_type = type(error).__name__
|
|
650
|
+
|
|
651
|
+
if "rate_limit" in error_str.lower() or error_type == "RateLimitError":
|
|
652
|
+
# Try to extract retry-after
|
|
653
|
+
retry_after = None
|
|
654
|
+
if hasattr(error, "response"):
|
|
655
|
+
retry_after = getattr(error.response.headers, "get", lambda x: None)(
|
|
656
|
+
"retry-after"
|
|
657
|
+
)
|
|
658
|
+
if retry_after:
|
|
659
|
+
try:
|
|
660
|
+
retry_after = float(retry_after)
|
|
661
|
+
except ValueError:
|
|
662
|
+
retry_after = None
|
|
663
|
+
raise RateLimitError(error_str, provider=self.name, retry_after=retry_after)
|
|
664
|
+
|
|
665
|
+
if "authentication" in error_str.lower() or error_type == "AuthenticationError":
|
|
666
|
+
raise AuthenticationError(error_str, provider=self.name)
|
|
667
|
+
|
|
668
|
+
if "invalid" in error_str.lower() or error_type == "BadRequestError":
|
|
669
|
+
raise InvalidRequestError(error_str, provider=self.name)
|
|
670
|
+
|
|
671
|
+
if "not found" in error_str.lower() or error_type == "NotFoundError":
|
|
672
|
+
raise ModelNotFoundError(error_str, provider=self.name)
|
|
673
|
+
|
|
674
|
+
if "content_filter" in error_str.lower():
|
|
675
|
+
raise ContentFilterError(error_str, provider=self.name)
|
|
676
|
+
|
|
677
|
+
# Generic error
|
|
678
|
+
raise LLMError(error_str, provider=self.name, retryable=True)
|
|
679
|
+
|
|
680
|
+
async def complete(self, request: CompletionRequest) -> CompletionResponse:
|
|
681
|
+
"""Generate a text completion using OpenAI's API.
|
|
682
|
+
|
|
683
|
+
Note: Uses chat completions API internally as legacy completions
|
|
684
|
+
are deprecated for most models.
|
|
685
|
+
"""
|
|
686
|
+
self.validate_request(request)
|
|
687
|
+
client = self._get_client()
|
|
688
|
+
model = self.get_model(request.model)
|
|
689
|
+
|
|
690
|
+
try:
|
|
691
|
+
# Use chat completions API (legacy completions deprecated)
|
|
692
|
+
response = await client.chat.completions.create(
|
|
693
|
+
model=model,
|
|
694
|
+
messages=[{"role": "user", "content": request.prompt}],
|
|
695
|
+
max_tokens=request.max_tokens,
|
|
696
|
+
temperature=request.temperature,
|
|
697
|
+
top_p=request.top_p,
|
|
698
|
+
stop=request.stop,
|
|
699
|
+
)
|
|
700
|
+
|
|
701
|
+
choice = response.choices[0]
|
|
702
|
+
usage = TokenUsage(
|
|
703
|
+
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
|
704
|
+
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
|
705
|
+
total_tokens=response.usage.total_tokens if response.usage else 0,
|
|
706
|
+
)
|
|
707
|
+
|
|
708
|
+
return CompletionResponse(
|
|
709
|
+
text=choice.message.content or "",
|
|
710
|
+
finish_reason=self._map_finish_reason(choice.finish_reason),
|
|
711
|
+
usage=usage,
|
|
712
|
+
model=response.model,
|
|
713
|
+
raw_response=response.model_dump() if hasattr(response, "model_dump") else None,
|
|
714
|
+
)
|
|
715
|
+
|
|
716
|
+
except Exception as e:
|
|
717
|
+
self._handle_api_error(e)
|
|
718
|
+
raise # Unreachable but keeps type checker happy
|
|
719
|
+
|
|
720
|
+
async def chat(self, request: ChatRequest) -> ChatResponse:
|
|
721
|
+
"""Generate a chat completion using OpenAI's API."""
|
|
722
|
+
self.validate_request(request)
|
|
723
|
+
client = self._get_client()
|
|
724
|
+
model = self.get_model(request.model)
|
|
725
|
+
|
|
726
|
+
try:
|
|
727
|
+
# Convert messages to OpenAI format
|
|
728
|
+
messages = [msg.to_dict() for msg in request.messages]
|
|
729
|
+
|
|
730
|
+
kwargs: Dict[str, Any] = {
|
|
731
|
+
"model": model,
|
|
732
|
+
"messages": messages,
|
|
733
|
+
"max_tokens": request.max_tokens,
|
|
734
|
+
"temperature": request.temperature,
|
|
735
|
+
"top_p": request.top_p,
|
|
736
|
+
}
|
|
737
|
+
|
|
738
|
+
if request.stop:
|
|
739
|
+
kwargs["stop"] = request.stop
|
|
740
|
+
if request.tools:
|
|
741
|
+
kwargs["tools"] = request.tools
|
|
742
|
+
if request.tool_choice:
|
|
743
|
+
kwargs["tool_choice"] = request.tool_choice
|
|
744
|
+
|
|
745
|
+
response = await client.chat.completions.create(**kwargs)
|
|
746
|
+
|
|
747
|
+
choice = response.choices[0]
|
|
748
|
+
message = choice.message
|
|
749
|
+
|
|
750
|
+
# Parse tool calls if present
|
|
751
|
+
tool_calls = None
|
|
752
|
+
if message.tool_calls:
|
|
753
|
+
tool_calls = [
|
|
754
|
+
ToolCall(
|
|
755
|
+
id=tc.id,
|
|
756
|
+
name=tc.function.name,
|
|
757
|
+
arguments=tc.function.arguments,
|
|
758
|
+
)
|
|
759
|
+
for tc in message.tool_calls
|
|
760
|
+
]
|
|
761
|
+
|
|
762
|
+
usage = TokenUsage(
|
|
763
|
+
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
|
764
|
+
completion_tokens=response.usage.completion_tokens if response.usage else 0,
|
|
765
|
+
total_tokens=response.usage.total_tokens if response.usage else 0,
|
|
766
|
+
)
|
|
767
|
+
|
|
768
|
+
return ChatResponse(
|
|
769
|
+
message=ChatMessage(
|
|
770
|
+
role=ChatRole.ASSISTANT,
|
|
771
|
+
content=message.content,
|
|
772
|
+
tool_calls=tool_calls,
|
|
773
|
+
),
|
|
774
|
+
finish_reason=self._map_finish_reason(choice.finish_reason),
|
|
775
|
+
usage=usage,
|
|
776
|
+
model=response.model,
|
|
777
|
+
raw_response=response.model_dump() if hasattr(response, "model_dump") else None,
|
|
778
|
+
)
|
|
779
|
+
|
|
780
|
+
except Exception as e:
|
|
781
|
+
self._handle_api_error(e)
|
|
782
|
+
raise
|
|
783
|
+
|
|
784
|
+
async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
|
|
785
|
+
"""Generate embeddings using OpenAI's API."""
|
|
786
|
+
self.validate_request(request)
|
|
787
|
+
client = self._get_client()
|
|
788
|
+
model = request.model or self.default_embedding_model
|
|
789
|
+
|
|
790
|
+
try:
|
|
791
|
+
kwargs: Dict[str, Any] = {
|
|
792
|
+
"model": model,
|
|
793
|
+
"input": request.texts,
|
|
794
|
+
}
|
|
795
|
+
|
|
796
|
+
if request.dimensions:
|
|
797
|
+
kwargs["dimensions"] = request.dimensions
|
|
798
|
+
|
|
799
|
+
response = await client.embeddings.create(**kwargs)
|
|
800
|
+
|
|
801
|
+
embeddings = [item.embedding for item in response.data]
|
|
802
|
+
|
|
803
|
+
usage = TokenUsage(
|
|
804
|
+
prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
|
|
805
|
+
total_tokens=response.usage.total_tokens if response.usage else 0,
|
|
806
|
+
)
|
|
807
|
+
|
|
808
|
+
return EmbeddingResponse(
|
|
809
|
+
embeddings=embeddings,
|
|
810
|
+
usage=usage,
|
|
811
|
+
model=response.model,
|
|
812
|
+
dimensions=len(embeddings[0]) if embeddings else None,
|
|
813
|
+
)
|
|
814
|
+
|
|
815
|
+
except Exception as e:
|
|
816
|
+
self._handle_api_error(e)
|
|
817
|
+
raise
|
|
818
|
+
|
|
819
|
+
async def stream_chat(self, request: ChatRequest) -> AsyncIterator[ChatResponse]:
|
|
820
|
+
"""Stream chat completion tokens from OpenAI."""
|
|
821
|
+
self.validate_request(request)
|
|
822
|
+
client = self._get_client()
|
|
823
|
+
model = self.get_model(request.model)
|
|
824
|
+
|
|
825
|
+
try:
|
|
826
|
+
messages = [msg.to_dict() for msg in request.messages]
|
|
827
|
+
|
|
828
|
+
kwargs: Dict[str, Any] = {
|
|
829
|
+
"model": model,
|
|
830
|
+
"messages": messages,
|
|
831
|
+
"max_tokens": request.max_tokens,
|
|
832
|
+
"temperature": request.temperature,
|
|
833
|
+
"top_p": request.top_p,
|
|
834
|
+
"stream": True,
|
|
835
|
+
}
|
|
836
|
+
|
|
837
|
+
if request.stop:
|
|
838
|
+
kwargs["stop"] = request.stop
|
|
839
|
+
|
|
840
|
+
stream = await client.chat.completions.create(**kwargs)
|
|
841
|
+
|
|
842
|
+
async for chunk in stream:
|
|
843
|
+
if not chunk.choices:
|
|
844
|
+
continue
|
|
845
|
+
|
|
846
|
+
choice = chunk.choices[0]
|
|
847
|
+
delta = choice.delta
|
|
848
|
+
|
|
849
|
+
if delta.content:
|
|
850
|
+
yield ChatResponse(
|
|
851
|
+
message=ChatMessage(
|
|
852
|
+
role=ChatRole.ASSISTANT,
|
|
853
|
+
content=delta.content,
|
|
854
|
+
),
|
|
855
|
+
finish_reason=self._map_finish_reason(choice.finish_reason) if choice.finish_reason else FinishReason.STOP,
|
|
856
|
+
model=chunk.model,
|
|
857
|
+
)
|
|
858
|
+
|
|
859
|
+
except Exception as e:
|
|
860
|
+
self._handle_api_error(e)
|
|
861
|
+
raise
|
|
862
|
+
|
|
863
|
+
def _map_finish_reason(self, reason: Optional[str]) -> FinishReason:
|
|
864
|
+
"""Map OpenAI finish reason to FinishReason enum."""
|
|
865
|
+
if reason is None:
|
|
866
|
+
return FinishReason.STOP
|
|
867
|
+
|
|
868
|
+
mapping = {
|
|
869
|
+
"stop": FinishReason.STOP,
|
|
870
|
+
"length": FinishReason.LENGTH,
|
|
871
|
+
"tool_calls": FinishReason.TOOL_CALL,
|
|
872
|
+
"function_call": FinishReason.TOOL_CALL,
|
|
873
|
+
"content_filter": FinishReason.CONTENT_FILTER,
|
|
874
|
+
}
|
|
875
|
+
return mapping.get(reason, FinishReason.STOP)
|
|
876
|
+
|
|
877
|
+
def count_tokens(self, text: str, model: Optional[str] = None) -> int:
|
|
878
|
+
"""Count tokens using tiktoken (if available)."""
|
|
879
|
+
try:
|
|
880
|
+
import tiktoken
|
|
881
|
+
|
|
882
|
+
model_name = model or self.default_model
|
|
883
|
+
try:
|
|
884
|
+
encoding = tiktoken.encoding_for_model(model_name)
|
|
885
|
+
except KeyError:
|
|
886
|
+
encoding = tiktoken.get_encoding("cl100k_base")
|
|
887
|
+
|
|
888
|
+
return len(encoding.encode(text))
|
|
889
|
+
except ImportError:
|
|
890
|
+
# Fall back to rough estimate
|
|
891
|
+
return super().count_tokens(text, model)
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
# =============================================================================
|
|
895
|
+
# Anthropic Provider Implementation
|
|
896
|
+
# =============================================================================
|
|
897
|
+
|
|
898
|
+
|
|
899
|
+
class AnthropicProvider(LLMProvider):
|
|
900
|
+
"""Anthropic API provider implementation.
|
|
901
|
+
|
|
902
|
+
Supports Claude 3 models (opus, sonnet, haiku) via the Anthropic API.
|
|
903
|
+
|
|
904
|
+
Attributes:
|
|
905
|
+
name: Provider identifier ('anthropic')
|
|
906
|
+
default_model: Default chat model ('claude-3-sonnet-20240229')
|
|
907
|
+
api_key: Anthropic API key
|
|
908
|
+
base_url: API base URL (for proxies)
|
|
909
|
+
max_tokens_default: Default max tokens for responses
|
|
910
|
+
|
|
911
|
+
Example:
|
|
912
|
+
provider = AnthropicProvider(api_key="sk-ant-...")
|
|
913
|
+
response = await provider.chat(ChatRequest(
|
|
914
|
+
messages=[ChatMessage(role=ChatRole.USER, content="Hello!")]
|
|
915
|
+
))
|
|
916
|
+
|
|
917
|
+
Note:
|
|
918
|
+
Anthropic does not support embeddings. The embed() method will raise
|
|
919
|
+
an error if called.
|
|
920
|
+
"""
|
|
921
|
+
|
|
922
|
+
name: str = "anthropic"
|
|
923
|
+
default_model: str = "claude-sonnet-4-20250514"
|
|
924
|
+
max_tokens_default: int = 4096
|
|
925
|
+
|
|
926
|
+
def __init__(
|
|
927
|
+
self,
|
|
928
|
+
api_key: Optional[str] = None,
|
|
929
|
+
base_url: Optional[str] = None,
|
|
930
|
+
default_model: Optional[str] = None,
|
|
931
|
+
max_tokens_default: Optional[int] = None,
|
|
932
|
+
):
|
|
933
|
+
"""Initialize the Anthropic provider.
|
|
934
|
+
|
|
935
|
+
Args:
|
|
936
|
+
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
|
|
937
|
+
base_url: API base URL (defaults to Anthropic's API)
|
|
938
|
+
default_model: Override default chat model
|
|
939
|
+
max_tokens_default: Override default max tokens
|
|
940
|
+
"""
|
|
941
|
+
import os
|
|
942
|
+
|
|
943
|
+
self.api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
|
|
944
|
+
self.base_url = base_url
|
|
945
|
+
|
|
946
|
+
if default_model:
|
|
947
|
+
self.default_model = default_model
|
|
948
|
+
if max_tokens_default:
|
|
949
|
+
self.max_tokens_default = max_tokens_default
|
|
950
|
+
|
|
951
|
+
self._client: Optional[Any] = None
|
|
952
|
+
|
|
953
|
+
def _get_client(self) -> Any:
|
|
954
|
+
"""Get or create the Anthropic client (lazy initialization)."""
|
|
955
|
+
if self._client is None:
|
|
956
|
+
try:
|
|
957
|
+
from anthropic import AsyncAnthropic
|
|
958
|
+
except ImportError:
|
|
959
|
+
raise LLMError(
|
|
960
|
+
"anthropic package not installed. Install with: pip install anthropic",
|
|
961
|
+
provider=self.name,
|
|
962
|
+
)
|
|
963
|
+
|
|
964
|
+
if not self.api_key:
|
|
965
|
+
raise AuthenticationError(
|
|
966
|
+
"Anthropic API key not provided. Set ANTHROPIC_API_KEY or pass api_key.",
|
|
967
|
+
provider=self.name,
|
|
968
|
+
)
|
|
969
|
+
|
|
970
|
+
kwargs: Dict[str, Any] = {"api_key": self.api_key}
|
|
971
|
+
if self.base_url:
|
|
972
|
+
kwargs["base_url"] = self.base_url
|
|
973
|
+
|
|
974
|
+
self._client = AsyncAnthropic(**kwargs)
|
|
975
|
+
|
|
976
|
+
return self._client
|
|
977
|
+
|
|
978
|
+
def _handle_api_error(self, error: Exception) -> None:
|
|
979
|
+
"""Convert Anthropic errors to LLMError types."""
|
|
980
|
+
error_str = str(error)
|
|
981
|
+
error_type = type(error).__name__
|
|
982
|
+
|
|
983
|
+
if "rate_limit" in error_str.lower() or error_type == "RateLimitError":
|
|
984
|
+
retry_after = None
|
|
985
|
+
if hasattr(error, "response") and error.response:
|
|
986
|
+
retry_after_str = error.response.headers.get("retry-after")
|
|
987
|
+
if retry_after_str:
|
|
988
|
+
try:
|
|
989
|
+
retry_after = float(retry_after_str)
|
|
990
|
+
except ValueError:
|
|
991
|
+
pass
|
|
992
|
+
raise RateLimitError(error_str, provider=self.name, retry_after=retry_after)
|
|
993
|
+
|
|
994
|
+
if "authentication" in error_str.lower() or error_type == "AuthenticationError":
|
|
995
|
+
raise AuthenticationError(error_str, provider=self.name)
|
|
996
|
+
|
|
997
|
+
if "invalid" in error_str.lower() or error_type == "BadRequestError":
|
|
998
|
+
raise InvalidRequestError(error_str, provider=self.name)
|
|
999
|
+
|
|
1000
|
+
if "not found" in error_str.lower() or error_type == "NotFoundError":
|
|
1001
|
+
raise ModelNotFoundError(error_str, provider=self.name)
|
|
1002
|
+
|
|
1003
|
+
# Generic error
|
|
1004
|
+
raise LLMError(error_str, provider=self.name, retryable=True)
|
|
1005
|
+
|
|
1006
|
+
def _convert_messages(self, messages: List[ChatMessage]) -> tuple[Optional[str], List[Dict[str, Any]]]:
|
|
1007
|
+
"""Convert ChatMessages to Anthropic format, extracting system message.
|
|
1008
|
+
|
|
1009
|
+
Anthropic requires system message as a separate parameter.
|
|
1010
|
+
|
|
1011
|
+
Returns:
|
|
1012
|
+
Tuple of (system_message, messages_list)
|
|
1013
|
+
"""
|
|
1014
|
+
system_message = None
|
|
1015
|
+
converted = []
|
|
1016
|
+
|
|
1017
|
+
for msg in messages:
|
|
1018
|
+
if msg.role == ChatRole.SYSTEM:
|
|
1019
|
+
# Anthropic takes system as separate param
|
|
1020
|
+
system_message = msg.content
|
|
1021
|
+
elif msg.role == ChatRole.TOOL:
|
|
1022
|
+
# Tool results in Anthropic format
|
|
1023
|
+
converted.append({
|
|
1024
|
+
"role": "user",
|
|
1025
|
+
"content": [{
|
|
1026
|
+
"type": "tool_result",
|
|
1027
|
+
"tool_use_id": msg.tool_call_id,
|
|
1028
|
+
"content": msg.content,
|
|
1029
|
+
}],
|
|
1030
|
+
})
|
|
1031
|
+
elif msg.role == ChatRole.ASSISTANT and msg.tool_calls:
|
|
1032
|
+
# Assistant message with tool calls
|
|
1033
|
+
content: List[Dict[str, Any]] = []
|
|
1034
|
+
if msg.content:
|
|
1035
|
+
content.append({"type": "text", "text": msg.content})
|
|
1036
|
+
for tc in msg.tool_calls:
|
|
1037
|
+
import json
|
|
1038
|
+
content.append({
|
|
1039
|
+
"type": "tool_use",
|
|
1040
|
+
"id": tc.id,
|
|
1041
|
+
"name": tc.name,
|
|
1042
|
+
"input": json.loads(tc.arguments) if tc.arguments else {},
|
|
1043
|
+
})
|
|
1044
|
+
converted.append({"role": "assistant", "content": content})
|
|
1045
|
+
else:
|
|
1046
|
+
# Regular user/assistant message
|
|
1047
|
+
converted.append({
|
|
1048
|
+
"role": msg.role.value,
|
|
1049
|
+
"content": msg.content or "",
|
|
1050
|
+
})
|
|
1051
|
+
|
|
1052
|
+
return system_message, converted
|
|
1053
|
+
|
|
1054
|
+
async def complete(self, request: CompletionRequest) -> CompletionResponse:
|
|
1055
|
+
"""Generate a text completion using Anthropic's API.
|
|
1056
|
+
|
|
1057
|
+
Uses the messages API internally.
|
|
1058
|
+
"""
|
|
1059
|
+
self.validate_request(request)
|
|
1060
|
+
client = self._get_client()
|
|
1061
|
+
model = self.get_model(request.model)
|
|
1062
|
+
|
|
1063
|
+
try:
|
|
1064
|
+
response = await client.messages.create(
|
|
1065
|
+
model=model,
|
|
1066
|
+
messages=[{"role": "user", "content": request.prompt}],
|
|
1067
|
+
max_tokens=request.max_tokens or self.max_tokens_default,
|
|
1068
|
+
temperature=request.temperature,
|
|
1069
|
+
top_p=request.top_p,
|
|
1070
|
+
stop_sequences=request.stop or [],
|
|
1071
|
+
)
|
|
1072
|
+
|
|
1073
|
+
text = ""
|
|
1074
|
+
for block in response.content:
|
|
1075
|
+
if block.type == "text":
|
|
1076
|
+
text += block.text
|
|
1077
|
+
|
|
1078
|
+
usage = TokenUsage(
|
|
1079
|
+
prompt_tokens=response.usage.input_tokens,
|
|
1080
|
+
completion_tokens=response.usage.output_tokens,
|
|
1081
|
+
total_tokens=response.usage.input_tokens + response.usage.output_tokens,
|
|
1082
|
+
)
|
|
1083
|
+
|
|
1084
|
+
return CompletionResponse(
|
|
1085
|
+
text=text,
|
|
1086
|
+
finish_reason=self._map_stop_reason(response.stop_reason),
|
|
1087
|
+
usage=usage,
|
|
1088
|
+
model=response.model,
|
|
1089
|
+
raw_response=response.model_dump() if hasattr(response, "model_dump") else None,
|
|
1090
|
+
)
|
|
1091
|
+
|
|
1092
|
+
except Exception as e:
|
|
1093
|
+
self._handle_api_error(e)
|
|
1094
|
+
raise
|
|
1095
|
+
|
|
1096
|
+
async def chat(self, request: ChatRequest) -> ChatResponse:
|
|
1097
|
+
"""Generate a chat completion using Anthropic's API."""
|
|
1098
|
+
self.validate_request(request)
|
|
1099
|
+
client = self._get_client()
|
|
1100
|
+
model = self.get_model(request.model)
|
|
1101
|
+
|
|
1102
|
+
try:
|
|
1103
|
+
system_message, messages = self._convert_messages(request.messages)
|
|
1104
|
+
|
|
1105
|
+
kwargs: Dict[str, Any] = {
|
|
1106
|
+
"model": model,
|
|
1107
|
+
"messages": messages,
|
|
1108
|
+
"max_tokens": request.max_tokens or self.max_tokens_default,
|
|
1109
|
+
"temperature": request.temperature,
|
|
1110
|
+
"top_p": request.top_p,
|
|
1111
|
+
}
|
|
1112
|
+
|
|
1113
|
+
if system_message:
|
|
1114
|
+
kwargs["system"] = system_message
|
|
1115
|
+
if request.stop:
|
|
1116
|
+
kwargs["stop_sequences"] = request.stop
|
|
1117
|
+
if request.tools:
|
|
1118
|
+
# Convert OpenAI-style tools to Anthropic format
|
|
1119
|
+
kwargs["tools"] = self._convert_tools(request.tools)
|
|
1120
|
+
|
|
1121
|
+
response = await client.messages.create(**kwargs)
|
|
1122
|
+
|
|
1123
|
+
# Parse response content
|
|
1124
|
+
content_text = None
|
|
1125
|
+
tool_calls = []
|
|
1126
|
+
|
|
1127
|
+
for block in response.content:
|
|
1128
|
+
if block.type == "text":
|
|
1129
|
+
content_text = (content_text or "") + block.text
|
|
1130
|
+
elif block.type == "tool_use":
|
|
1131
|
+
import json
|
|
1132
|
+
tool_calls.append(ToolCall(
|
|
1133
|
+
id=block.id,
|
|
1134
|
+
name=block.name,
|
|
1135
|
+
arguments=json.dumps(block.input),
|
|
1136
|
+
))
|
|
1137
|
+
|
|
1138
|
+
usage = TokenUsage(
|
|
1139
|
+
prompt_tokens=response.usage.input_tokens,
|
|
1140
|
+
completion_tokens=response.usage.output_tokens,
|
|
1141
|
+
total_tokens=response.usage.input_tokens + response.usage.output_tokens,
|
|
1142
|
+
)
|
|
1143
|
+
|
|
1144
|
+
return ChatResponse(
|
|
1145
|
+
message=ChatMessage(
|
|
1146
|
+
role=ChatRole.ASSISTANT,
|
|
1147
|
+
content=content_text,
|
|
1148
|
+
tool_calls=tool_calls if tool_calls else None,
|
|
1149
|
+
),
|
|
1150
|
+
finish_reason=self._map_stop_reason(response.stop_reason),
|
|
1151
|
+
usage=usage,
|
|
1152
|
+
model=response.model,
|
|
1153
|
+
raw_response=response.model_dump() if hasattr(response, "model_dump") else None,
|
|
1154
|
+
)
|
|
1155
|
+
|
|
1156
|
+
except Exception as e:
|
|
1157
|
+
self._handle_api_error(e)
|
|
1158
|
+
raise
|
|
1159
|
+
|
|
1160
|
+
async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
|
|
1161
|
+
"""Anthropic does not support embeddings.
|
|
1162
|
+
|
|
1163
|
+
Raises:
|
|
1164
|
+
LLMError: Always, as embeddings are not supported
|
|
1165
|
+
"""
|
|
1166
|
+
raise LLMError(
|
|
1167
|
+
"Anthropic does not support embeddings. Use OpenAI or a local embedding model.",
|
|
1168
|
+
provider=self.name,
|
|
1169
|
+
retryable=False,
|
|
1170
|
+
)
|
|
1171
|
+
|
|
1172
|
+
async def stream_chat(self, request: ChatRequest) -> AsyncIterator[ChatResponse]:
|
|
1173
|
+
"""Stream chat completion tokens from Anthropic."""
|
|
1174
|
+
self.validate_request(request)
|
|
1175
|
+
client = self._get_client()
|
|
1176
|
+
model = self.get_model(request.model)
|
|
1177
|
+
|
|
1178
|
+
try:
|
|
1179
|
+
system_message, messages = self._convert_messages(request.messages)
|
|
1180
|
+
|
|
1181
|
+
kwargs: Dict[str, Any] = {
|
|
1182
|
+
"model": model,
|
|
1183
|
+
"messages": messages,
|
|
1184
|
+
"max_tokens": request.max_tokens or self.max_tokens_default,
|
|
1185
|
+
"temperature": request.temperature,
|
|
1186
|
+
"top_p": request.top_p,
|
|
1187
|
+
}
|
|
1188
|
+
|
|
1189
|
+
if system_message:
|
|
1190
|
+
kwargs["system"] = system_message
|
|
1191
|
+
if request.stop:
|
|
1192
|
+
kwargs["stop_sequences"] = request.stop
|
|
1193
|
+
|
|
1194
|
+
async with client.messages.stream(**kwargs) as stream:
|
|
1195
|
+
async for text in stream.text_stream:
|
|
1196
|
+
yield ChatResponse(
|
|
1197
|
+
message=ChatMessage(
|
|
1198
|
+
role=ChatRole.ASSISTANT,
|
|
1199
|
+
content=text,
|
|
1200
|
+
),
|
|
1201
|
+
finish_reason=FinishReason.STOP,
|
|
1202
|
+
model=model,
|
|
1203
|
+
)
|
|
1204
|
+
|
|
1205
|
+
except Exception as e:
|
|
1206
|
+
self._handle_api_error(e)
|
|
1207
|
+
raise
|
|
1208
|
+
|
|
1209
|
+
def _convert_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|
1210
|
+
"""Convert OpenAI-style tool definitions to Anthropic format."""
|
|
1211
|
+
converted = []
|
|
1212
|
+
for tool in tools:
|
|
1213
|
+
if tool.get("type") == "function":
|
|
1214
|
+
func = tool.get("function", {})
|
|
1215
|
+
converted.append({
|
|
1216
|
+
"name": func.get("name"),
|
|
1217
|
+
"description": func.get("description", ""),
|
|
1218
|
+
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
|
|
1219
|
+
})
|
|
1220
|
+
return converted
|
|
1221
|
+
|
|
1222
|
+
def _map_stop_reason(self, reason: Optional[str]) -> FinishReason:
|
|
1223
|
+
"""Map Anthropic stop reason to FinishReason enum."""
|
|
1224
|
+
if reason is None:
|
|
1225
|
+
return FinishReason.STOP
|
|
1226
|
+
|
|
1227
|
+
mapping = {
|
|
1228
|
+
"end_turn": FinishReason.STOP,
|
|
1229
|
+
"stop_sequence": FinishReason.STOP,
|
|
1230
|
+
"max_tokens": FinishReason.LENGTH,
|
|
1231
|
+
"tool_use": FinishReason.TOOL_CALL,
|
|
1232
|
+
}
|
|
1233
|
+
return mapping.get(reason, FinishReason.STOP)
|
|
1234
|
+
|
|
1235
|
+
def count_tokens(self, text: str, model: Optional[str] = None) -> int:
|
|
1236
|
+
"""Estimate tokens for Anthropic models.
|
|
1237
|
+
|
|
1238
|
+
Uses anthropic's token counting if available, otherwise rough estimate.
|
|
1239
|
+
"""
|
|
1240
|
+
try:
|
|
1241
|
+
from anthropic import Anthropic
|
|
1242
|
+
client = Anthropic(api_key=self.api_key or "dummy")
|
|
1243
|
+
return client.count_tokens(text)
|
|
1244
|
+
except (ImportError, Exception):
|
|
1245
|
+
# Rough estimate: ~4 characters per token
|
|
1246
|
+
return len(text) // 4
|
|
1247
|
+


# =============================================================================
# Local Provider Implementation (Ollama/llama.cpp)
# =============================================================================


class LocalProvider(LLMProvider):
    """Local LLM provider using Ollama or llama.cpp compatible API.

    Supports local models via Ollama's OpenAI-compatible API endpoint.

    Attributes:
        name: Provider identifier ('local')
        default_model: Default model ('llama3.2')
        base_url: Local API endpoint (default: http://localhost:11434/v1)

    Example:
        # Using Ollama (default)
        provider = LocalProvider()
        response = await provider.chat(ChatRequest(
            messages=[ChatMessage(role=ChatRole.USER, content="Hello!")]
        ))

        # Custom endpoint
        provider = LocalProvider(base_url="http://localhost:8080/v1")
    """

    name: str = "local"
    default_model: str = "llama3.2"
    default_embedding_model: str = "nomic-embed-text"

    def __init__(
        self,
        base_url: Optional[str] = None,
        default_model: Optional[str] = None,
        default_embedding_model: Optional[str] = None,
        api_key: Optional[str] = None,
    ):
        """Initialize the local provider.

        Args:
            base_url: API base URL (defaults to Ollama's OpenAI-compatible endpoint)
            default_model: Override default chat model
            default_embedding_model: Override default embedding model
            api_key: Optional API key (some local servers may require it)
        """
        self.base_url = base_url or "http://localhost:11434/v1"
        self.api_key = api_key or "ollama"  # Ollama accepts any key

        if default_model:
            self.default_model = default_model
        if default_embedding_model:
            self.default_embedding_model = default_embedding_model

        self._client: Optional[Any] = None

    def _get_client(self) -> Any:
        """Get or create the OpenAI-compatible client for local server."""
        if self._client is None:
            try:
                from openai import AsyncOpenAI
            except ImportError:
                raise LLMError(
                    "openai package not installed. Install with: pip install openai",
                    provider=self.name,
                )

            self._client = AsyncOpenAI(
                api_key=self.api_key,
                base_url=self.base_url,
            )

        return self._client

    def _handle_api_error(self, error: Exception) -> None:
        """Convert API errors to LLMError types."""
        error_str = str(error)

        if "connection" in error_str.lower() or "refused" in error_str.lower():
            raise LLMError(
                f"Cannot connect to local server at {self.base_url}. "
                "Ensure Ollama is running: ollama serve",
                provider=self.name,
                retryable=True,
            )

        if "model" in error_str.lower() and "not found" in error_str.lower():
            raise ModelNotFoundError(
                f"Model not found. Pull it first: ollama pull {self.default_model}",
                provider=self.name,
            )

        raise LLMError(error_str, provider=self.name, retryable=True)
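
    # Illustrative only (not part of the original module): _handle_api_error
    # re-raises everything as an LLMError subtype, so callers can branch on the
    # exception class instead of parsing server messages (assuming the
    # retryable flag passed above is exposed as an attribute), e.g.:
    #
    #     try:
    #         response = await provider.chat(request)
    #     except ModelNotFoundError:
    #         ...  # run `ollama pull <model>` and retry
    #     except LLMError as exc:
    #         if exc.retryable:
    #             ...  # back off and retry against the local server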

    async def complete(self, request: CompletionRequest) -> CompletionResponse:
        """Generate a text completion using local model."""
        self.validate_request(request)
        client = self._get_client()
        model = self.get_model(request.model)

        try:
            response = await client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": request.prompt}],
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                top_p=request.top_p,
                stop=request.stop,
            )

            choice = response.choices[0]
            usage = TokenUsage(
                prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
                completion_tokens=response.usage.completion_tokens if response.usage else 0,
                total_tokens=response.usage.total_tokens if response.usage else 0,
            )

            return CompletionResponse(
                text=choice.message.content or "",
                finish_reason=self._map_finish_reason(choice.finish_reason),
                usage=usage,
                model=response.model,
            )

        except Exception as e:
            self._handle_api_error(e)
            raise
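
    # Illustrative only (not part of the original module): complete() wraps the
    # prompt in a single user message, so a plain-text call looks like:
    #
    #     response = await provider.complete(CompletionRequest(
    #         prompt="Summarize the release notes in one sentence.",
    #         max_tokens=128,
    #     ))
    #     print(response.text, response.finish_reason)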

    async def chat(self, request: ChatRequest) -> ChatResponse:
        """Generate a chat completion using local model."""
        self.validate_request(request)
        client = self._get_client()
        model = self.get_model(request.model)

        try:
            messages = [msg.to_dict() for msg in request.messages]

            kwargs: Dict[str, Any] = {
                "model": model,
                "messages": messages,
                "max_tokens": request.max_tokens,
                "temperature": request.temperature,
                "top_p": request.top_p,
            }

            if request.stop:
                kwargs["stop"] = request.stop
            if request.tools:
                kwargs["tools"] = request.tools
            if request.tool_choice:
                kwargs["tool_choice"] = request.tool_choice

            response = await client.chat.completions.create(**kwargs)

            choice = response.choices[0]
            message = choice.message

            tool_calls = None
            if message.tool_calls:
                tool_calls = [
                    ToolCall(
                        id=tc.id,
                        name=tc.function.name,
                        arguments=tc.function.arguments,
                    )
                    for tc in message.tool_calls
                ]

            usage = TokenUsage(
                prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
                completion_tokens=response.usage.completion_tokens if response.usage else 0,
                total_tokens=response.usage.total_tokens if response.usage else 0,
            )

            return ChatResponse(
                message=ChatMessage(
                    role=ChatRole.ASSISTANT,
                    content=message.content,
                    tool_calls=tool_calls,
                ),
                finish_reason=self._map_finish_reason(choice.finish_reason),
                usage=usage,
                model=response.model,
            )

        except Exception as e:
            self._handle_api_error(e)
            raise
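
    # Illustrative only (not part of the original module): tool calls from the
    # local model are surfaced on the response message, e.g.:
    #
    #     response = await provider.chat(ChatRequest(
    #         messages=[ChatMessage(role=ChatRole.USER, content="What is 2 + 2?")],
    #         tools=[...],  # OpenAI-style tool definitions, passed through as-is
    #     ))
    #     if response.finish_reason == FinishReason.TOOL_CALL:
    #         for call in response.message.tool_calls or []:
    #             print(call.name, call.arguments)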

    async def embed(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Generate embeddings using local model."""
        self.validate_request(request)
        client = self._get_client()
        model = request.model or self.default_embedding_model

        try:
            response = await client.embeddings.create(
                model=model,
                input=request.texts,
            )

            embeddings = [item.embedding for item in response.data]

            usage = TokenUsage(
                prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
                total_tokens=response.usage.total_tokens if response.usage else 0,
            )

            return EmbeddingResponse(
                embeddings=embeddings,
                usage=usage,
                model=response.model,
                dimensions=len(embeddings[0]) if embeddings else None,
            )

        except Exception as e:
            self._handle_api_error(e)
            raise
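
    # Illustrative only (not part of the original module): embeddings come back
    # in the same order as request.texts, so pairwise similarity is a zip away:
    #
    #     result = await provider.embed(EmbeddingRequest(texts=["alpha", "beta"]))
    #     a, b = result.embeddings
    #     dot = sum(x * y for x, y in zip(a, b))
    #     norm = (sum(x * x for x in a) ** 0.5) * (sum(y * y for y in b) ** 0.5)
    #     cosine = dot / norm if norm else 0.0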

    async def stream_chat(self, request: ChatRequest) -> AsyncIterator[ChatResponse]:
        """Stream chat completion tokens from local model."""
        self.validate_request(request)
        client = self._get_client()
        model = self.get_model(request.model)

        try:
            messages = [msg.to_dict() for msg in request.messages]

            kwargs: Dict[str, Any] = {
                "model": model,
                "messages": messages,
                "max_tokens": request.max_tokens,
                "temperature": request.temperature,
                "top_p": request.top_p,
                "stream": True,
            }

            if request.stop:
                kwargs["stop"] = request.stop

            stream = await client.chat.completions.create(**kwargs)

            async for chunk in stream:
                if not chunk.choices:
                    continue

                choice = chunk.choices[0]
                delta = choice.delta

                if delta.content:
                    yield ChatResponse(
                        message=ChatMessage(
                            role=ChatRole.ASSISTANT,
                            content=delta.content,
                        ),
                        finish_reason=(
                            self._map_finish_reason(choice.finish_reason)
                            if choice.finish_reason
                            else FinishReason.STOP
                        ),
                        model=chunk.model,
                    )

        except Exception as e:
            self._handle_api_error(e)
            raise
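
    # Illustrative only (not part of the original module): each yielded
    # ChatResponse carries one content delta, so a live transcript is just:
    #
    #     async for chunk in provider.stream_chat(request):
    #         print(chunk.message.content, end="", flush=True)
    #     print()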

    def _map_finish_reason(self, reason: Optional[str]) -> FinishReason:
        """Map finish reason to FinishReason enum."""
        if reason is None:
            return FinishReason.STOP

        mapping = {
            "stop": FinishReason.STOP,
            "length": FinishReason.LENGTH,
            "tool_calls": FinishReason.TOOL_CALL,
        }
        return mapping.get(reason, FinishReason.STOP)

    async def health_check(self) -> bool:
        """Check if local server is accessible."""
        try:
            import httpx
            async with httpx.AsyncClient() as client:
                # Check Ollama-style health endpoint
                response = await client.get(
                    self.base_url.replace("/v1", "") + "/api/tags",
                    timeout=5.0,
                )
                return response.status_code == 200
        except Exception:
            return False
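
# Illustrative only (not part of the original module): a minimal end-to-end
# check of LocalProvider, assuming Ollama is running locally and the default
# llama3.2 model has already been pulled:
#
#     import asyncio
#
#     async def _demo() -> None:
#         provider = LocalProvider()
#         if not await provider.health_check():
#             raise SystemExit("Local server is not reachable at http://localhost:11434")
#         response = await provider.chat(ChatRequest(
#             messages=[ChatMessage(role=ChatRole.USER, content="Say hi in one word.")]
#         ))
#         print(response.message.content, response.usage.total_tokens)
#
#     asyncio.run(_demo())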


# =============================================================================
# Exports
# =============================================================================

__all__ = [
    # Enums
    "ChatRole",
    "FinishReason",
    # Data Classes
    "ToolCall",
    "ChatMessage",
    "CompletionRequest",
    "ChatRequest",
    "EmbeddingRequest",
    "TokenUsage",
    "CompletionResponse",
    "ChatResponse",
    "EmbeddingResponse",
    # Exceptions
    "LLMError",
    "RateLimitError",
    "AuthenticationError",
    "InvalidRequestError",
    "ModelNotFoundError",
    "ContentFilterError",
    # Provider ABC
    "LLMProvider",
    # Providers
    "OpenAIProvider",
    "AnthropicProvider",
    "LocalProvider",
]
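
# Illustrative only (not part of the original module): the names above are the
# intended public surface. Assuming this module keeps the path
# foundry_mcp.core.llm_provider, a typical downstream import looks like:
#
#     from foundry_mcp.core.llm_provider import ChatMessage, ChatRole, LocalProvider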