llmbridgekit 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llmbridgekit/__init__.py +77 -0
- llmbridgekit/client.py +609 -0
- llmbridgekit/config.py +209 -0
- llmbridgekit/exceptions.py +151 -0
- llmbridgekit/fallback.py +161 -0
- llmbridgekit/json_utils.py +210 -0
- llmbridgekit/models.py +265 -0
- llmbridgekit/pricing.py +167 -0
- llmbridgekit/providers/__init__.py +38 -0
- llmbridgekit/providers/anthropic_provider.py +219 -0
- llmbridgekit/providers/azure_openai_provider.py +211 -0
- llmbridgekit/providers/base.py +260 -0
- llmbridgekit/providers/cohere_provider.py +212 -0
- llmbridgekit/providers/custom_http_provider.py +235 -0
- llmbridgekit/providers/gemini_provider.py +252 -0
- llmbridgekit/providers/groq_provider.py +204 -0
- llmbridgekit/providers/huggingface_provider.py +183 -0
- llmbridgekit/providers/lmstudio_provider.py +181 -0
- llmbridgekit/providers/mistral_provider.py +193 -0
- llmbridgekit/providers/ollama_provider.py +193 -0
- llmbridgekit/providers/openai_provider.py +257 -0
- llmbridgekit/providers/openrouter_provider.py +230 -0
- llmbridgekit/providers/together_provider.py +214 -0
- llmbridgekit/redaction.py +170 -0
- llmbridgekit/registry.py +271 -0
- llmbridgekit/retry.py +206 -0
- llmbridgekit/usage.py +174 -0
- llmbridgekit/validation.py +234 -0
- llmbridgekit-1.0.0.dist-info/METADATA +529 -0
- llmbridgekit-1.0.0.dist-info/RECORD +32 -0
- llmbridgekit-1.0.0.dist-info/WHEEL +5 -0
- llmbridgekit-1.0.0.dist-info/top_level.txt +1 -0
llmbridgekit/__init__.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLMBridgeKit - Universal Python LLM Provider Abstraction Library
|
|
3
|
+
|
|
4
|
+
A provider-agnostic library for calling multiple LLM providers through one consistent interface.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _distribution_version

try:
    # The installed distribution's metadata is the single source of truth
    # for the package version.
    __version__ = _distribution_version("llmbridgekit")
except PackageNotFoundError:
    # Not installed as a distribution (e.g. imported straight from a source
    # checkout); fall back to a placeholder instead of failing the import.
    __version__ = "0.0.0"
|
|
12
|
+
|
|
13
|
+
# Core imports
|
|
14
|
+
from llmbridgekit.client import LLMClient
|
|
15
|
+
from llmbridgekit.config import LLMConfig, RedactionConfig
|
|
16
|
+
from llmbridgekit.models import (
|
|
17
|
+
LLMRequest,
|
|
18
|
+
LLMResponse,
|
|
19
|
+
MessageRole,
|
|
20
|
+
ResponseMode,
|
|
21
|
+
ToolCall,
|
|
22
|
+
Usage,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Exceptions
|
|
26
|
+
from llmbridgekit.exceptions import (
|
|
27
|
+
LLMGatewayError,
|
|
28
|
+
ProviderError,
|
|
29
|
+
ProviderNotFoundError,
|
|
30
|
+
InvalidConfigError,
|
|
31
|
+
InvalidRequestError,
|
|
32
|
+
RateLimitError,
|
|
33
|
+
TimeoutError,
|
|
34
|
+
AuthenticationError,
|
|
35
|
+
InvalidResponseError,
|
|
36
|
+
SchemaValidationError,
|
|
37
|
+
ToolCallError,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
# Registry
|
|
41
|
+
from llmbridgekit.registry import provider_registry
|
|
42
|
+
|
|
43
|
+
# Provider base
|
|
44
|
+
from llmbridgekit.providers.base import BaseProvider, ProviderFeatures
|
|
45
|
+
|
|
46
|
+
# Public API re-exported by ``from llmbridgekit import *``.
__all__ = [
    "__version__",
    # Client and configuration objects.
    "LLMClient", "LLMConfig", "RedactionConfig",
    # Request/response data models.
    "LLMRequest", "LLMResponse", "MessageRole", "ResponseMode", "ToolCall", "Usage",
    # Exception types. Note: "TimeoutError" here is llmbridgekit's own exception
    # (imported from llmbridgekit.exceptions) and shadows the builtin of the
    # same name in any module that star-imports this package.
    "LLMGatewayError", "ProviderError", "ProviderNotFoundError",
    "InvalidConfigError", "InvalidRequestError", "RateLimitError",
    "TimeoutError", "AuthenticationError", "InvalidResponseError",
    "SchemaValidationError", "ToolCallError",
    # Provider registry and the provider base types.
    "provider_registry", "BaseProvider", "ProviderFeatures",
]
|
llmbridgekit/client.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LLMBridgeKit Client
|
|
3
|
+
|
|
4
|
+
Main client for interacting with LLM providers.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
import copy
import logging
import time
from typing import Any, Callable, Generator

from llmbridgekit.config import (
    GatewayConfig,
    LLMConfig,
    RedactionConfig,
    RetryConfig,
    ValidationConfig,
)
from llmbridgekit.exceptions import (
    InvalidConfigError,
    InvalidRequestError,
    ProviderNotFoundError,
)
from llmbridgekit.fallback import create_fallback_chain
from llmbridgekit.models import LLMResponse, Message, StreamChunk, Usage
from llmbridgekit.pricing import estimate_cost
from llmbridgekit.redaction import redact_messages, redact_text
from llmbridgekit.registry import provider_registry
from llmbridgekit.retry import RetryConfig as RC, retry_with_backoff
from llmbridgekit.usage import record_usage
from llmbridgekit.validation import validate_config, validate_request
|
|
30
|
+
|
|
31
|
+
# Module-level logger, named after this module (llmbridgekit.client).
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class LLMClient:
    """
    Main LLM Gateway client.

    Provides a unified interface for calling multiple LLM providers.

    Every call builds its own effective configuration. Precedence, lowest to
    highest:

    1. the client defaults
    2. the per-call ``config``
    3. explicit ``provider``/``model``
    4. extra keyword arguments

    Per-call settings never leak back into the client's defaults.

    Example:
        >>> client = LLMClient(provider="gemini")
        >>> response = client.generate("Explain quantum computing")
        >>> print(response.text)
    """

    def __init__(
        self,
        provider: str | None = None,
        model: str | None = None,
        config: LLMConfig | None = None,
        gateway_config: GatewayConfig | None = None,
        **kwargs: Any,
    ):
        """
        Initialize LLM client.

        Args:
            provider: Default provider name. Defaults to
                ``gateway_config.default_provider``. When ``config`` is also
                given, a non-empty value overrides ``config.provider``.
            model: Default model name. Defaults to
                ``gateway_config.default_model``. When ``config`` is also
                given, a non-empty value overrides ``config.model``.
            config: Default LLM configuration. If any override must be applied
                to it, a copy is modified; the caller's object is not touched.
            gateway_config: Gateway-level configuration. A default
                ``GatewayConfig()`` is used when omitted.
            **kwargs: Additional config parameters.
        """
        self.gateway_config = gateway_config or GatewayConfig()

        if config is not None:
            if provider or model or kwargs:
                # Honor explicit overrides without mutating the caller's object.
                config = copy.copy(config)
                self._apply_overrides(config, provider, model, kwargs)
            self.default_config = config
        else:
            self.default_config = LLMConfig(
                provider=provider or self.gateway_config.default_provider,
                model=model or self.gateway_config.default_model,
                **kwargs,
            )

        # Gateway-level policies applied on every call.
        self.redaction_config = self.gateway_config.redaction
        self.validation_config = self.gateway_config.validation
        self.retry_config = self.gateway_config.retry

    def generate(
        self,
        prompt: str,
        provider: str | None = None,
        model: str | None = None,
        config: LLMConfig | None = None,
        fallback_chain: list[dict[str, Any]] | str | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate text from a prompt.

        Args:
            prompt: Input prompt.
            provider: Provider name (overrides default).
            model: Model name (overrides default).
            config: Generation configuration (overrides default).
            fallback_chain: Fallback providers, either inline or by chain name.
            **kwargs: Additional config parameters.

        Returns:
            LLM response.

        Example:
            >>> response = client.generate(
            ...     "Explain machine learning",
            ...     provider="gemini",
            ...     temperature=0.7
            ... )
        """
        merged_config = self._merge_config(provider, model, config, **kwargs)

        validate_request(prompt, None, merged_config, self.validation_config)
        validate_config(merged_config)

        if self.redaction_config.enabled:
            prompt = redact_text(prompt, self.redaction_config)

        chain = self._build_fallback_chain(fallback_chain)
        if chain:
            return chain.execute(
                self._execute_generate,
                prompt=prompt,
                base_config=merged_config,
            )

        return self._execute_generate(
            merged_config.provider,
            merged_config.model,
            merged_config,
            prompt=prompt,
        )

    def chat(
        self,
        messages: list[dict[str, Any] | Message],
        provider: str | None = None,
        model: str | None = None,
        config: LLMConfig | None = None,
        fallback_chain: list[dict[str, Any]] | str | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Chat with a list of messages.

        Args:
            messages: Chat messages, as dicts or ``Message`` objects.
                ``Message`` objects are reduced to their role and content.
            provider: Provider name.
            model: Model name.
            config: Generation configuration.
            fallback_chain: Fallback providers.
            **kwargs: Additional config parameters.

        Returns:
            LLM response.

        Example:
            >>> response = client.chat([
            ...     {"role": "system", "content": "You are a helpful assistant."},
            ...     {"role": "user", "content": "What is Python?"}
            ... ])
        """
        merged_config = self._merge_config(provider, model, config, **kwargs)

        messages_dicts = [
            msg if isinstance(msg, dict) else {"role": msg.role, "content": msg.content}
            for msg in messages
        ]

        validate_request(None, messages_dicts, merged_config, self.validation_config)
        validate_config(merged_config)

        if self.redaction_config.enabled:
            messages_dicts = redact_messages(messages_dicts, self.redaction_config)

        chain = self._build_fallback_chain(fallback_chain)
        if chain:
            return chain.execute(
                self._execute_chat,
                messages=messages_dicts,
                base_config=merged_config,
            )

        return self._execute_chat(
            merged_config.provider,
            merged_config.model,
            merged_config,
            messages=messages_dicts,
        )

    def generate_json(
        self,
        prompt: str,
        schema: dict[str, Any],
        provider: str | None = None,
        model: str | None = None,
        config: LLMConfig | None = None,
        fallback_chain: list[dict[str, Any]] | str | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate structured JSON output.

        Args:
            prompt: Input prompt.
            schema: JSON Schema. It is stored on the per-call config as
                ``json_schema``, and ``response_mode`` is set to
                ``"json_schema"``.
            provider: Provider name.
            model: Model name.
            config: Generation configuration.
            fallback_chain: Fallback providers.
            **kwargs: Additional config parameters.

        Returns:
            LLM response with json_data populated.

        Example:
            >>> schema = {
            ...     "type": "object",
            ...     "properties": {
            ...         "category": {"type": "string"},
            ...         "confidence": {"type": "number"}
            ...     }
            ... }
            >>> response = client.generate_json(
            ...     "Categorize: Python tutorial",
            ...     schema=schema
            ... )
        """
        merged_config = self._merge_config(provider, model, config, **kwargs)
        merged_config.json_schema = schema
        merged_config.response_mode = "json_schema"

        validate_request(prompt, None, merged_config, self.validation_config)
        validate_config(merged_config)

        if self.redaction_config.enabled:
            prompt = redact_text(prompt, self.redaction_config)

        chain = self._build_fallback_chain(fallback_chain)
        if chain:
            # The chain forwards only the prompt; _execute_generate_json reads
            # the schema back from the config's json_schema attribute.
            return chain.execute(
                self._execute_generate_json,
                prompt=prompt,
                base_config=merged_config,
            )

        return self._execute_generate_json(
            merged_config.provider,
            merged_config.model,
            merged_config,
            prompt=prompt,
            schema=schema,
        )

    def stream(
        self,
        prompt: str,
        provider: str | None = None,
        model: str | None = None,
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> Generator[StreamChunk, None, None]:
        """
        Stream generation.

        Because this method is a generator, validation, redaction and the
        provider lookup run when iteration starts, not when ``stream()`` is
        called. No retry, fallback or usage tracking is applied.

        Args:
            prompt: Input prompt.
            provider: Provider name.
            model: Model name.
            config: Generation configuration.
            **kwargs: Additional config parameters.

        Yields:
            Stream chunks.

        Example:
            >>> for chunk in client.stream("Write a story"):
            ...     print(chunk.text, end="", flush=True)
        """
        merged_config = self._merge_config(provider, model, config, **kwargs)
        merged_config.stream = True

        validate_request(prompt, None, merged_config, self.validation_config)
        validate_config(merged_config)

        if self.redaction_config.enabled:
            prompt = redact_text(prompt, self.redaction_config)

        provider_instance = provider_registry.get(merged_config.provider, merged_config)

        yield from provider_instance.stream(prompt, merged_config)

    def generate_with_tools(
        self,
        prompt: str,
        tools: list[dict[str, Any]],
        provider: str | None = None,
        model: str | None = None,
        config: LLMConfig | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Generate with tool/function calling.

        Args:
            prompt: Input prompt.
            tools: Tool definitions. They are also stored on the per-call
                config as ``tools``.
            provider: Provider name.
            model: Model name.
            config: Generation configuration.
            **kwargs: Additional config parameters.

        Returns:
            LLM response with tool_calls populated.

        Example:
            >>> tools = [{
            ...     "name": "get_weather",
            ...     "description": "Get weather",
            ...     "parameters": {
            ...         "type": "object",
            ...         "properties": {"location": {"type": "string"}}
            ...     }
            ... }]
            >>> response = client.generate_with_tools(
            ...     "What's the weather in NYC?",
            ...     tools=tools
            ... )
        """
        merged_config = self._merge_config(provider, model, config, **kwargs)
        merged_config.tools = tools

        validate_request(prompt, None, merged_config, self.validation_config)
        validate_config(merged_config)

        if self.redaction_config.enabled:
            prompt = redact_text(prompt, self.redaction_config)

        provider_instance = provider_registry.get(merged_config.provider, merged_config)

        return self._call_with_retry(
            lambda: provider_instance.generate_with_tools(prompt, tools, merged_config),
            merged_config,
        )

    def run_task(
        self,
        task: str,
        prompt: str | None = None,
        messages: list[dict[str, Any]] | None = None,
        **kwargs: Any,
    ) -> LLMResponse:
        """
        Run a task predefined in ``gateway_config.tasks``.

        The task's ``provider`` and ``model`` entries, plus its ``config``
        mapping, build the call configuration. The task name is stored as
        ``metadata["task"]`` so usage records can be attributed to it.

        Args:
            task: Task name.
            prompt: Optional prompt (takes priority over ``messages``).
            messages: Optional messages.
            **kwargs: Additional config parameters.

        Returns:
            LLM response.

        Raises:
            InvalidRequestError: If the task is unknown, or if neither
                ``prompt`` nor ``messages`` is given.

        Example:
            >>> response = client.run_task(
            ...     "text_to_sql",
            ...     prompt="Show revenue by category"
            ... )
        """
        if task not in self.gateway_config.tasks:
            raise InvalidRequestError(f"Task '{task}' not found in configuration")

        task_config = self.gateway_config.tasks[task]

        config = LLMConfig(
            provider=task_config.get("provider", self.default_config.provider),
            model=task_config.get("model"),
            **task_config.get("config", {}),
        )

        # _post_process_response reads metadata["task"] for usage tracking.
        # Set it here unless the task config already provides one.
        metadata = getattr(config, "metadata", None) or {}
        if "task" not in metadata:
            config.metadata = {**metadata, "task": task}

        if prompt:
            return self.generate(prompt, config=config, **kwargs)
        elif messages:
            return self.chat(messages, config=config, **kwargs)
        else:
            raise InvalidRequestError("Either prompt or messages required")

    # Async methods
    async def agenerate(self, *args, **kwargs) -> LLMResponse:
        """
        Async generate.

        Runs the blocking ``generate`` in a worker thread so the event loop
        stays responsive. Accepts the same arguments as ``generate``.
        """
        return await asyncio.to_thread(self.generate, *args, **kwargs)

    async def achat(self, *args, **kwargs) -> LLMResponse:
        """
        Async chat.

        Runs the blocking ``chat`` in a worker thread so the event loop stays
        responsive. Accepts the same arguments as ``chat``.
        """
        return await asyncio.to_thread(self.chat, *args, **kwargs)

    # Private methods
    def _merge_config(
        self,
        provider: str | None,
        model: str | None,
        config: LLMConfig | None,
        **kwargs: Any,
    ) -> LLMConfig:
        """
        Merge configurations by precedence into a fresh config object.

        Precedence, lowest to highest:

        1. ``self.default_config``
        2. ``config``, combined via its ``merge`` method
        3. non-empty ``provider``/``model``
        4. ``kwargs`` (only keys that exist as attributes)

        The result is a shallow copy, so callers may assign attributes on it
        (``stream``, ``tools``, ``json_schema``, ...) without changing the
        client defaults. Mutable attribute values (lists/dicts) are still
        shared: replace them rather than mutating them in place.
        """
        merged = copy.copy(self.default_config)
        if config is not None:
            # Copy again in case merge() hands back one of its operands.
            merged = copy.copy(merged.merge(config))
        self._apply_overrides(merged, provider, model, kwargs)
        return merged

    @staticmethod
    def _apply_overrides(
        target: LLMConfig,
        provider: str | None,
        model: str | None,
        overrides: dict[str, Any],
    ) -> None:
        """Assign non-empty ``provider``/``model`` and known ``overrides`` onto ``target`` in place."""
        if provider:
            target.provider = provider
        if model:
            target.model = model
        for key, value in overrides.items():
            if hasattr(target, key):
                setattr(target, key, value)
            else:
                # Surface unknown keys (e.g. typos) instead of dropping them silently.
                logger.warning("Ignoring unknown LLM config parameter %r", key)

    def _build_fallback_chain(
        self,
        fallback_chain: list[dict[str, Any]] | str | None,
    ) -> Any:
        """
        Resolve ``fallback_chain`` into a chain object.

        Returns None when no chain was requested. A string presumably names an
        entry in ``gateway_config.fallback_chains``; resolution is delegated to
        ``create_fallback_chain``.
        """
        if not fallback_chain:
            return None
        return create_fallback_chain(fallback_chain, self.gateway_config.fallback_chains)

    @staticmethod
    def _config_for_target(
        provider: str,
        model: str | None,
        config: LLMConfig,
    ) -> LLMConfig:
        """
        Return a config whose provider/model match the requested target.

        The single-provider path passes ``config``'s own provider and model,
        so ``config`` itself is returned. When a different provider or model
        is requested (e.g. by a fallback chain sharing one base config), a
        retargeted shallow copy is returned. The base config is never modified.
        """
        if provider == config.provider and (not model or model == config.model):
            return config
        target = copy.copy(config)
        target.provider = provider
        if model:
            target.model = model
        return target

    def _call_with_retry(
        self,
        call: Callable[[], LLMResponse],
        config: LLMConfig,
    ) -> LLMResponse:
        """Run ``call`` under the retry policy, then record latency, cost and usage."""
        started = time.perf_counter()  # monotonic clock for durations
        response = retry_with_backoff(call, self.retry_config)
        # Latency spans every retry attempt, not just the successful one.
        response.latency_ms = int((time.perf_counter() - started) * 1000)
        self._post_process_response(response, config)
        return response

    def _execute_generate(
        self,
        provider: str,
        model: str | None,
        config: LLMConfig,
        prompt: str,
    ) -> LLMResponse:
        """Execute generate against ``provider``/``model``, with retry and post-processing."""
        config = self._config_for_target(provider, model, config)
        provider_instance = provider_registry.get(provider, config)
        return self._call_with_retry(
            lambda: provider_instance.generate(prompt, config),
            config,
        )

    def _execute_chat(
        self,
        provider: str,
        model: str | None,
        config: LLMConfig,
        messages: list[dict[str, Any]],
    ) -> LLMResponse:
        """Execute chat against ``provider``/``model``, with retry and post-processing."""
        config = self._config_for_target(provider, model, config)
        provider_instance = provider_registry.get(provider, config)
        return self._call_with_retry(
            lambda: provider_instance.chat(messages, config),
            config,
        )

    def _execute_generate_json(
        self,
        provider: str,
        model: str | None,
        config: LLMConfig,
        prompt: str,
        schema: dict[str, Any] | None = None,
    ) -> LLMResponse:
        """
        Execute generate_json against ``provider``/``model``, with retry and post-processing.

        ``schema`` defaults to ``config.json_schema``. Fallback chains call
        this method with only ``prompt`` (see ``generate_json``).
        """
        if schema is None:
            schema = getattr(config, "json_schema", None)
        config = self._config_for_target(provider, model, config)
        provider_instance = provider_registry.get(provider, config)
        return self._call_with_retry(
            lambda: provider_instance.generate_json(prompt, schema, config),
            config,
        )

    def _post_process_response(
        self,
        response: LLMResponse,
        config: LLMConfig,
    ) -> None:
        """
        Attach an estimated cost and record usage.

        Does nothing unless the response reports a positive token count.
        """
        usage = response.usage
        if usage is None or usage.total_tokens <= 0:
            return

        # Prefer what the provider reports; fall back to what was requested.
        model_name = response.model or config.model or ""

        if self.gateway_config.cost.enabled:
            response.cost = estimate_cost(
                usage,
                model_name,
                self.gateway_config.cost.custom_pricing,
            )

        metadata = getattr(config, "metadata", None) or {}
        record_usage(
            usage,
            response.provider or config.provider,
            model_name,
            task=metadata.get("task"),
        )

    @classmethod
    def from_yaml(cls, config_path: str) -> "LLMClient":
        """
        Create a client from a YAML configuration file.

        Recognized top-level keys: ``default_provider``, ``default_model``,
        ``default_timeout``, ``default_retries``, ``providers``,
        ``fallback_chains`` and ``tasks``. An empty file yields the defaults.

        Args:
            config_path: Path to the YAML config file (read as UTF-8).

        Returns:
            Configured LLM client.

        Raises:
            ImportError: If PyYAML is not installed.
            InvalidConfigError: If the YAML document is not a mapping.

        Example:
            >>> client = LLMClient.from_yaml("config.yaml")
        """
        try:
            import yaml
        except ImportError as exc:
            raise ImportError(
                "PyYAML required for YAML config. Install with: pip install pyyaml"
            ) from exc

        with open(config_path, "r", encoding="utf-8") as f:
            # safe_load only builds plain Python objects, never arbitrary classes.
            loaded = yaml.safe_load(f)

        config_dict = {} if loaded is None else loaded
        if not isinstance(config_dict, dict):
            raise InvalidConfigError(
                f"YAML config at {config_path!r} must be a mapping, "
                f"got {type(config_dict).__name__}"
            )

        gateway_config = GatewayConfig(
            default_provider=config_dict.get("default_provider", "openai"),
            default_model=config_dict.get("default_model"),
            default_timeout=config_dict.get("default_timeout", 60.0),
            default_retries=config_dict.get("default_retries", 2),
            providers=config_dict.get("providers", {}),
            fallback_chains=config_dict.get("fallback_chains", {}),
            tasks=config_dict.get("tasks", {}),
        )

        return cls(gateway_config=gateway_config)

    def __repr__(self) -> str:
        return f"<LLMClient provider='{self.default_config.provider}' model='{self.default_config.model}'>"
|