lionagi 0.12.3__py3-none-any.whl → 0.12.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/config.py +123 -0
- lionagi/libs/schema/load_pydantic_model_from_schema.py +259 -0
- lionagi/libs/token_transform/perplexity.py +2 -4
- lionagi/libs/token_transform/synthlang_/translate_to_synthlang.py +1 -1
- lionagi/operations/chat/chat.py +2 -2
- lionagi/operations/communicate/communicate.py +20 -5
- lionagi/operations/parse/parse.py +131 -43
- lionagi/protocols/generic/pile.py +94 -33
- lionagi/protocols/graph/node.py +25 -19
- lionagi/protocols/messages/assistant_response.py +20 -1
- lionagi/service/connections/__init__.py +15 -0
- lionagi/service/connections/api_calling.py +230 -0
- lionagi/service/connections/endpoint.py +410 -0
- lionagi/service/connections/endpoint_config.py +137 -0
- lionagi/service/connections/header_factory.py +56 -0
- lionagi/service/connections/match_endpoint.py +49 -0
- lionagi/service/connections/providers/__init__.py +3 -0
- lionagi/service/connections/providers/anthropic_.py +87 -0
- lionagi/service/connections/providers/exa_.py +33 -0
- lionagi/service/connections/providers/oai_.py +166 -0
- lionagi/service/connections/providers/ollama_.py +122 -0
- lionagi/service/connections/providers/perplexity_.py +29 -0
- lionagi/service/imodel.py +36 -144
- lionagi/service/manager.py +1 -7
- lionagi/service/{endpoints/rate_limited_processor.py → rate_limited_processor.py} +4 -2
- lionagi/service/resilience.py +545 -0
- lionagi/service/third_party/README.md +71 -0
- lionagi/service/third_party/anthropic_models.py +159 -0
- lionagi/service/{providers/exa_/models.py → third_party/exa_models.py} +18 -13
- lionagi/service/third_party/openai_models.py +18241 -0
- lionagi/service/third_party/pplx_models.py +156 -0
- lionagi/service/types.py +5 -4
- lionagi/session/branch.py +12 -7
- lionagi/tools/file/reader.py +1 -1
- lionagi/tools/memory/tools.py +497 -0
- lionagi/version.py +1 -1
- {lionagi-0.12.3.dist-info → lionagi-0.12.5.dist-info}/METADATA +17 -19
- {lionagi-0.12.3.dist-info → lionagi-0.12.5.dist-info}/RECORD +43 -54
- lionagi/adapters/__init__.py +0 -1
- lionagi/adapters/adapter.py +0 -120
- lionagi/adapters/json_adapter.py +0 -181
- lionagi/adapters/pandas_/csv_adapter.py +0 -94
- lionagi/adapters/pandas_/excel_adapter.py +0 -94
- lionagi/adapters/pandas_/pd_dataframe_adapter.py +0 -81
- lionagi/adapters/pandas_/pd_series_adapter.py +0 -57
- lionagi/adapters/toml_adapter.py +0 -204
- lionagi/adapters/types.py +0 -21
- lionagi/service/endpoints/__init__.py +0 -3
- lionagi/service/endpoints/base.py +0 -706
- lionagi/service/endpoints/chat_completion.py +0 -116
- lionagi/service/endpoints/match_endpoint.py +0 -72
- lionagi/service/providers/__init__.py +0 -3
- lionagi/service/providers/anthropic_/__init__.py +0 -3
- lionagi/service/providers/anthropic_/messages.py +0 -99
- lionagi/service/providers/exa_/search.py +0 -80
- lionagi/service/providers/exa_/types.py +0 -7
- lionagi/service/providers/groq_/__init__.py +0 -3
- lionagi/service/providers/groq_/chat_completions.py +0 -56
- lionagi/service/providers/ollama_/__init__.py +0 -3
- lionagi/service/providers/ollama_/chat_completions.py +0 -134
- lionagi/service/providers/openai_/__init__.py +0 -3
- lionagi/service/providers/openai_/chat_completions.py +0 -101
- lionagi/service/providers/openai_/spec.py +0 -14
- lionagi/service/providers/openrouter_/__init__.py +0 -3
- lionagi/service/providers/openrouter_/chat_completions.py +0 -62
- lionagi/service/providers/perplexity_/__init__.py +0 -3
- lionagi/service/providers/perplexity_/chat_completions.py +0 -44
- lionagi/service/providers/perplexity_/models.py +0 -144
- lionagi/service/providers/types.py +0 -17
- /lionagi/{adapters/pandas_/__init__.py → py.typed} +0 -0
- /lionagi/service/{providers/exa_ → third_party}/__init__.py +0 -0
- /lionagi/service/{endpoints/token_calculator.py → token_calculator.py} +0 -0
- {lionagi-0.12.3.dist-info → lionagi-0.12.5.dist-info}/WHEEL +0 -0
- {lionagi-0.12.3.dist-info → lionagi-0.12.5.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,410 @@
|
|
1
|
+
# Copyright (c) 2025, HaiyangLi <quantocean.li at gmail dot com>
|
2
|
+
#
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import logging
|
7
|
+
|
8
|
+
import aiohttp
|
9
|
+
import backoff
|
10
|
+
from aiocache import cached
|
11
|
+
from pydantic import BaseModel
|
12
|
+
|
13
|
+
from lionagi.config import settings
|
14
|
+
from lionagi.service.resilience import (
|
15
|
+
CircuitBreaker,
|
16
|
+
RetryConfig,
|
17
|
+
retry_with_backoff,
|
18
|
+
)
|
19
|
+
from lionagi.utils import to_dict
|
20
|
+
|
21
|
+
from .endpoint_config import EndpointConfig
|
22
|
+
from .header_factory import HeaderFactory
|
23
|
+
|
24
|
+
# Module-level logger named after this module; used by Endpoint for debug output.
logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
class Endpoint:
    """Build request payloads/headers for one API endpoint and execute them.

    The instance holds an ``EndpointConfig`` plus optional resilience helpers
    (a circuit breaker and a retry configuration). No HTTP session is stored
    on the instance: every request opens, uses and closes its own
    ``aiohttp.ClientSession``.
    """

    # Keys that configure lionagi itself rather than the remote API. They are
    # stripped from the payload when no ``request_options`` model is available
    # to filter against.
    _NON_API_PARAMS = frozenset(
        {
            "task",
            "provider",
            "base_url",
            "endpoint",
            "endpoint_params",
            "api_key",
            "queue_capacity",
            "capacity_refresh_time",
            "interval",
            "limit_requests",
            "limit_tokens",
            "invoke_with_endpoint",
            "extra_headers",
            "headers",
            "cache_control",
            "include_token_usage_to_model",
            "chat_model",
            "imodel",
            "branch",
        }
    )

    def __init__(
        self,
        config: dict | EndpointConfig,
        circuit_breaker: CircuitBreaker | None = None,
        retry_config: RetryConfig | None = None,
        **kwargs,
    ):
        """
        Initialize the endpoint.

        Args:
            config: The endpoint configuration, either as a dict of
                ``EndpointConfig`` fields or as an ``EndpointConfig``
                instance (which is deep-copied, never mutated).
            circuit_breaker: Optional circuit breaker for resilience.
            retry_config: Optional retry configuration for resilience.
            **kwargs: Additional keyword arguments that override or extend
                the configuration.

        Raises:
            ValueError: If ``config`` is neither a dict nor an
                ``EndpointConfig``.
        """
        if isinstance(config, dict):
            # Merge before unpacking: ``f(**config, **kwargs)`` raises
            # TypeError when a key appears in both, whereas kwargs are meant
            # to override the base configuration.
            _config = EndpointConfig(**{**config, **kwargs})
        elif isinstance(config, EndpointConfig):
            # Deep copy so the caller's config (and its kwargs dict) is not
            # shared with this endpoint.
            _config = config.model_copy(deep=True)
            _config.update(**kwargs)
        else:
            raise ValueError(
                "Config must be a dict or EndpointConfig instance"
            )
        self.config = _config
        self.circuit_breaker = circuit_breaker
        self.retry_config = retry_config

        logger.debug(
            "Initialized Endpoint with provider=%s, endpoint=%s, "
            "circuit_breaker=%s, retry_config=%s",
            self.config.provider,
            self.config.endpoint,
            circuit_breaker is not None,
            retry_config is not None,
        )

    def _create_http_session(self):
        """Return a new ``aiohttp.ClientSession`` using the configured timeout
        and ``client_kwargs``. Callers create one per request."""
        return aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(self.config.timeout),
            **self.config.client_kwargs,
        )

    @property
    def request_options(self):
        """The request-options model stored on the config (may be None)."""
        return self.config.request_options

    @request_options.setter
    def request_options(self, value):
        # Route through the config's validator so the accepted input forms
        # match those accepted at construction time.
        self.config.request_options = EndpointConfig._validate_request_options(
            value
        )

    def create_payload(
        self,
        request: dict | BaseModel,
        extra_headers: dict | None = None,
        **kwargs,
    ):
        """
        Build the ``(payload, headers)`` pair for a request.

        The payload starts from ``config.kwargs``, is updated with
        ``request`` and then with ``**kwargs``. If the config has a
        ``request_options`` model, the payload is restricted to that model's
        fields and validated; otherwise known non-API keys are dropped.

        Args:
            request: The request parameters, as a dict or a pydantic model
                (dumped with ``exclude_none=True``).
            extra_headers: Headers merged on top of the generated ones.
            **kwargs: Extra payload entries; they take precedence over
                ``request``.

        Returns:
            A tuple ``(payload, headers)`` of two dicts.
        """
        headers = HeaderFactory.get_header(
            auth_type=self.config.auth_type,
            content_type=self.config.content_type,
            api_key=self.config._api_key,
            default_headers=self.config.default_headers,
        )
        if extra_headers:
            headers.update(extra_headers)

        if not isinstance(request, dict):
            request = request.model_dump(exclude_none=True)

        # Layering order: config defaults < request < explicit kwargs.
        payload = self.config.kwargs.copy()
        payload.update(request)
        if kwargs:
            payload.update(kwargs)

        request_model = self.config.request_options
        if request_model is not None:
            valid_fields = set(request_model.model_fields.keys())
            filtered_payload = {
                k: v for k, v in payload.items() if k in valid_fields
            }
            payload = self.config.validate_payload(filtered_payload)
        else:
            payload = {
                k: v
                for k, v in payload.items()
                if k not in self._NON_API_PARAMS
            }

        return (payload, headers)

    async def call(
        self, request: dict | BaseModel, cache_control: bool = False, **kwargs
    ):
        """
        Make a call to the endpoint.

        The raw HTTP call is wrapped, innermost to outermost, by: retry
        (if ``retry_config`` is set), circuit breaker (if ``circuit_breaker``
        is set) and the aiocache ``cached`` decorator (if ``cache_control``
        is true).

        Args:
            request: The request parameters or model.
            cache_control: Whether to route the call through the cache.
            **kwargs: Additional keyword arguments for the request. An
                ``extra_headers`` entry is extracted and used for headers.

        Returns:
            The response from the endpoint.
        """
        extra_headers = kwargs.pop("extra_headers", None)
        payload, headers = self.create_payload(
            request, extra_headers=extra_headers, **kwargs
        )
        # NOTE(review): the remaining kwargs are both merged into the payload
        # above and forwarded to the HTTP request below - confirm that this
        # double use is intended.

        async def _call(payload: dict, headers: dict, **kwargs):
            return await self._call_aiohttp(
                payload=payload, headers=headers, **kwargs
            )

        if self.retry_config:

            async def call_func(p, h, **kw):
                return await retry_with_backoff(
                    _call, p, h, **kw, **self.retry_config.as_kwargs()
                )

        else:
            call_func = _call

        async def _guarded(p, h, **kw):
            # The circuit breaker, when present, wraps the (possibly
            # retry-wrapped) call.
            if self.circuit_breaker:
                return await self.circuit_breaker.execute(
                    call_func, p, h, **kw
                )
            return await call_func(p, h, **kw)

        if not cache_control:
            return await _guarded(payload, headers, **kwargs)

        # Name and signature are kept stable in case the cache derives its
        # keys from them.
        @cached(**settings.aiocache_config.as_kwargs())
        async def _cached_call(payload: dict, headers: dict, **kwargs):
            return await _guarded(payload, headers, **kwargs)

        return await _cached_call(payload, headers, **kwargs)

    async def _call_aiohttp(self, payload: dict, headers: dict, **kwargs):
        """
        Make a call using aiohttp with a fresh session for each attempt.

        Attempts are retried by ``backoff`` on ``aiohttp.ClientError`` and
        ``asyncio.TimeoutError`` up to ``config.max_retries`` tries, except
        for 4xx responses other than 429, which are not retried.

        Args:
            payload: The request payload (sent as JSON).
            headers: The request headers.
            **kwargs: Additional keyword arguments passed to
                ``session.request``.

        Returns:
            The decoded JSON body of a 200 response.

        Raises:
            aiohttp.ClientResponseError: For any non-200 response.
        """

        async def _make_request_with_backoff():
            async with self._create_http_session() as session:
                response = None
                try:
                    response = await session.request(
                        method=self.config.method,
                        url=self.config.full_url,
                        headers=headers,
                        json=payload,
                        **kwargs,
                    )

                    # Rate-limit and server errors are raised as-is so the
                    # backoff handler can retry them.
                    if response.status == 429 or response.status >= 500:
                        response.raise_for_status()
                    elif response.status != 200:
                        # Include the error body in the message when it can
                        # be decoded. Only ordinary exceptions are caught so
                        # task cancellation still propagates.
                        try:
                            error_body = await response.json()
                            error_message = f"Request failed with status {response.status}: {error_body}"
                        except Exception:
                            error_message = (
                                f"Request failed with status {response.status}"
                            )

                        raise aiohttp.ClientResponseError(
                            request_info=response.request_info,
                            history=response.history,
                            status=response.status,
                            message=error_message,
                            headers=response.headers,
                        )

                    return await response.json()
                finally:
                    # Release the response on every exit path, including
                    # errors and cancellation.
                    if response is not None and not response.closed:
                        await response.release()

        def giveup_on_client_error(e):
            # Don't retry on 4xx errors except 429 (rate limit).
            if isinstance(e, aiohttp.ClientResponseError):
                return 400 <= e.status < 500 and e.status != 429
            return False

        # Built per call so it picks up the current config.max_retries.
        backoff_handler = backoff.on_exception(
            backoff.expo,
            (aiohttp.ClientError, asyncio.TimeoutError),
            max_tries=self.config.max_retries,
            giveup=giveup_on_client_error,
            jitter=backoff.full_jitter,
        )

        return await backoff_handler(_make_request_with_backoff)()

    async def stream(
        self,
        request: dict | BaseModel,
        extra_headers: dict | None = None,
        **kwargs,
    ):
        """
        Stream responses from the endpoint.

        Args:
            request: The request parameters or model.
            extra_headers: Additional headers for the request.
            **kwargs: Additional keyword arguments for the request.

        Yields:
            Streaming chunks from the API.
        """
        payload, headers = self.create_payload(
            request, extra_headers, **kwargs
        )

        async for chunk in self._stream_aiohttp(
            payload=payload, headers=headers, **kwargs
        ):
            yield chunk

    async def _stream_aiohttp(self, payload: dict, headers: dict, **kwargs):
        """
        Stream responses using aiohttp with a fresh session.

        Args:
            payload: The request payload; sent with ``"stream": True`` added.
            headers: The request headers.
            **kwargs: Additional keyword arguments passed to
                ``session.request``.

        Yields:
            Non-empty lines of the response body decoded as UTF-8.

        Raises:
            aiohttp.ClientResponseError: If the response status is not 200.
        """
        # Force streaming on a copy so the caller's dict is left untouched.
        payload = {**payload, "stream": True}

        async with self._create_http_session() as session:
            async with session.request(
                method=self.config.method,
                url=self.config.full_url,
                headers=headers,
                json=payload,
                **kwargs,
            ) as response:
                if response.status != 200:
                    raise aiohttp.ClientResponseError(
                        request_info=response.request_info,
                        history=response.history,
                        status=response.status,
                        message=f"Request failed with status {response.status}",
                        headers=response.headers,
                    )

                async for line in response.content:
                    if line:
                        yield line.decode("utf-8")

    def to_dict(self):
        """Return a dict with the ``retry_config``, ``circuit_breaker`` and
        ``config`` entries; unset resilience helpers are None."""
        return {
            "retry_config": (
                self.retry_config.to_dict() if self.retry_config else None
            ),
            "circuit_breaker": (
                self.circuit_breaker.to_dict()
                if self.circuit_breaker
                else None
            ),
            "config": self.config.model_dump(exclude_none=True),
        }

    @classmethod
    def from_dict(cls, data: dict):
        """Rebuild an endpoint from a mapping with optional ``retry_config``,
        ``circuit_breaker`` and ``config`` entries.

        Raises:
            ValueError: If no usable ``config`` entry is present (raised by
                ``__init__``).
        """
        # Here ``to_dict`` is the module-level helper from lionagi.utils, not
        # the method above.
        data = to_dict(data, recursive=True)
        retry_config = data.get("retry_config")
        circuit_breaker = data.get("circuit_breaker")
        config = data.get("config")

        if retry_config:
            retry_config = RetryConfig(**retry_config)
        if circuit_breaker:
            circuit_breaker = CircuitBreaker(**circuit_breaker)
        if config:
            config = EndpointConfig(**config)

        return cls(
            config=config,
            circuit_breaker=circuit_breaker,
            retry_config=retry_config,
        )
|
@@ -0,0 +1,137 @@
|
|
1
|
+
# Copyright (c) 2025, HaiyangLi <quantocean.li at gmail dot com>
|
2
|
+
#
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
import os
|
6
|
+
from typing import Any, TypeVar
|
7
|
+
|
8
|
+
from pydantic import (
|
9
|
+
BaseModel,
|
10
|
+
Field,
|
11
|
+
PrivateAttr,
|
12
|
+
SecretStr,
|
13
|
+
field_serializer,
|
14
|
+
field_validator,
|
15
|
+
model_validator,
|
16
|
+
)
|
17
|
+
|
18
|
+
from .header_factory import AUTH_TYPES
|
19
|
+
|
20
|
+
# Type variable bound to BaseModel *classes* (not instances); used below to
# annotate EndpointConfig.request_options.
B = TypeVar("B", bound=type[BaseModel])
|
21
|
+
|
22
|
+
|
23
|
+
class EndpointConfig(BaseModel):
    """Declarative configuration for a single API endpoint.

    Input keys that are not declared fields are collected into ``kwargs``
    instead of being rejected or dropped.
    """

    name: str
    provider: str
    base_url: str | None = None
    # Path appended to base_url; may contain ``str.format`` placeholders that
    # are filled from ``params`` when ``endpoint_params`` is set.
    endpoint: str
    endpoint_params: list[str] | None = None
    method: str = "POST"
    params: dict[str, str] = Field(default_factory=dict)
    content_type: str = "application/json"
    auth_type: AUTH_TYPES = "bearer"
    default_headers: dict = Field(default_factory=dict)
    # A pydantic model class describing valid request fields, or None.
    request_options: B | None = None
    # Either the secret itself or the name of a setting / environment
    # variable holding it; the resolved value is kept in ``_api_key``.
    api_key: str | SecretStr | None = None
    timeout: int = 300
    max_retries: int = 3
    openai_compatible: bool = False
    requires_tokens: bool = False
    # Catch-all for input keys that are not declared fields.
    kwargs: dict = Field(default_factory=dict)
    client_kwargs: dict = Field(default_factory=dict)
    _api_key: str | None = PrivateAttr(None)

    @model_validator(mode="before")
    @classmethod
    def _validate_kwargs(cls, data: dict):
        """Move every key that is not a declared field into ``kwargs``.

        Works on copies so neither the caller's mapping nor a ``kwargs`` dict
        passed in by the caller is mutated.
        """
        if not isinstance(data, dict):
            # Not a plain mapping; leave it for pydantic to handle.
            return data

        values = dict(data)
        extras = values.pop("kwargs", None)
        if extras is None:
            extras = {}
        elif isinstance(extras, dict):
            extras = dict(extras)
        else:
            # Put it back so field validation reports the wrong type.
            values["kwargs"] = extras
            return values

        field_keys = set(cls.model_json_schema().get("properties", {}).keys())
        for key in list(values):
            if key not in field_keys:
                extras[key] = values.pop(key)
        values["kwargs"] = extras
        return values

    def _resolve_api_key(self) -> None:
        """Populate ``_api_key`` from ``api_key``.

        A ``SecretStr`` is unwrapped. A plain string is treated as the name
        of a secret in lionagi settings, then as an environment variable
        name, and finally as the key itself. ``None`` leaves ``_api_key``
        unchanged.
        """
        if self.api_key is None:
            return

        if isinstance(self.api_key, SecretStr):
            self._api_key = self.api_key.get_secret_value()
        elif isinstance(self.api_key, str):
            # Skip settings lookup for the ollama special case.
            if self.provider == "ollama" and self.api_key == "ollama_key":
                self._api_key = "ollama_key"
            else:
                from lionagi.config import settings

                try:
                    self._api_key = settings.get_secret(self.api_key)
                except (AttributeError, ValueError):
                    self._api_key = os.getenv(self.api_key, self.api_key)

    @model_validator(mode="after")
    def _validate_api_key(self):
        """Resolve the usable API key once the model has been built."""
        self._resolve_api_key()
        return self

    @property
    def full_url(self):
        """``base_url`` joined with ``endpoint``; the endpoint is formatted
        with ``params`` when ``endpoint_params`` is non-empty."""
        if not self.endpoint_params:
            return f"{self.base_url}/{self.endpoint}"
        return f"{self.base_url}/{self.endpoint.format(**self.params)}"

    @field_validator("request_options", mode="before")
    @classmethod
    def _validate_request_options(cls, v):
        """Normalise ``request_options`` to a pydantic model class.

        Accepts None (returned as-is), a ``BaseModel`` subclass, a
        ``BaseModel`` instance (its class is used), or a schema given as a
        dict or str (loaded through ``SchemaUtil``).

        Raises:
            ValueError: If the value is of an unsupported type or loading
                fails.
        """
        if v is None:
            return None

        try:
            if isinstance(v, type) and issubclass(v, BaseModel):
                return v
            if isinstance(v, BaseModel):
                return v.__class__
            if isinstance(v, dict | str):
                from lionagi.libs.schema import SchemaUtil

                return SchemaUtil.load_pydantic_model_from_schema(v)
        except Exception as e:
            raise ValueError("Invalid request options") from e
        raise ValueError(
            "Invalid request options: must be a Pydantic model or a schema dict"
        )

    @field_serializer("request_options")
    def _serialize_request_options(self, v: B | None):
        """Serialise the request-options model as its JSON schema."""
        if v is None:
            return None
        return v.model_json_schema()

    def update(self, **kwargs):
        """Update the config with new values.

        A ``kwargs`` entry is merged into ``self.kwargs``. Keys naming a
        declared field (or private attribute) are assigned directly; values
        are assigned as given, without re-running field validation. Any
        other key is stored in ``self.kwargs``. If ``api_key`` is updated,
        the resolved ``_api_key`` is refreshed as well.
        """
        if "kwargs" in kwargs:
            self.kwargs.update(kwargs.pop("kwargs"))

        # Membership in the declared fields, not ``hasattr``: the latter is
        # also true for methods and properties (e.g. ``copy``, ``json``,
        # ``full_url``), which cannot be assigned on the model.
        declared = type(self).model_fields
        private = self.__private_attributes__
        for key, value in kwargs.items():
            if key in declared or key in private:
                setattr(self, key, value)
            else:
                self.kwargs[key] = value

        if "api_key" in kwargs:
            # Keep the resolved key in sync with the new ``api_key``.
            self._resolve_api_key()

    def validate_payload(self, data: dict[str, Any]) -> dict[str, Any]:
        """Validate payload data against the request_options model.

        Args:
            data: The payload data to validate

        Returns:
            The same ``data`` object, unchanged, when validation passes or
            when no request_options model is set.

        Raises:
            ValueError: If validation against the model fails.
        """
        if not self.request_options:
            return data

        try:
            self.request_options.model_validate(data)
            return data
        except Exception as e:
            raise ValueError("Invalid payload") from e
|
@@ -0,0 +1,56 @@
|
|
1
|
+
# Copyright (c) 2025, HaiyangLi <quantocean.li at gmail dot com>
|
2
|
+
#
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
from typing import Literal
|
6
|
+
|
7
|
+
from pydantic import SecretStr
|
8
|
+
|
9
|
+
# Authentication schemes understood by HeaderFactory.get_header:
# "bearer" -> Authorization header, "x-api-key" -> x-api-key header,
# "none" -> no auth header.
AUTH_TYPES = Literal["bearer", "x-api-key", "none"]
|
10
|
+
|
11
|
+
|
12
|
+
class HeaderFactory:
    """Static helpers that assemble HTTP request headers."""

    @staticmethod
    def get_content_type_header(
        content_type: str = "application/json",
    ) -> dict[str, str]:
        """Return a one-entry dict carrying the Content-Type header."""
        return {"Content-Type": content_type}

    @staticmethod
    def get_bearer_auth_header(api_key: str) -> dict[str, str]:
        """Return an Authorization header using the Bearer scheme."""
        return {"Authorization": f"Bearer {api_key}"}

    @staticmethod
    def get_x_api_key_header(api_key: str) -> dict[str, str]:
        """Return an x-api-key header carrying the key."""
        return {"x-api-key": api_key}

    @staticmethod
    def get_header(
        auth_type: AUTH_TYPES,
        content_type: str = "application/json",
        api_key: str | SecretStr | None = None,
        default_headers: dict[str, str] | None = None,
    ) -> dict[str, str]:
        """Build the full header dict for a request.

        Args:
            auth_type: One of "bearer", "x-api-key" or "none".
            content_type: Value for the Content-Type header.
            api_key: The key to send; required unless auth_type is "none".
            default_headers: Extra headers applied last, so they override
                generated entries with the same name.

        Raises:
            ValueError: If a key is required but missing, or the auth type
                is not supported.
        """
        headers = HeaderFactory.get_content_type_header(content_type)

        if auth_type != "none":
            # The key is checked before the auth type, so a missing key is
            # reported first.
            if not api_key:
                raise ValueError("API key is required for authentication")

            secret = api_key
            if isinstance(secret, SecretStr):
                secret = secret.get_secret_value()

            if auth_type == "bearer":
                auth_header = HeaderFactory.get_bearer_auth_header(secret)
            elif auth_type == "x-api-key":
                auth_header = HeaderFactory.get_x_api_key_header(secret)
            else:
                raise ValueError(f"Unsupported auth type: {auth_type}")
            headers.update(auth_header)

        if default_headers:
            headers.update(default_headers)
        return headers
|
@@ -0,0 +1,49 @@
|
|
1
|
+
# Copyright (c) 2023 - 2025, HaiyangLi <quantocean.li at gmail dot com>
|
2
|
+
#
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
4
|
+
|
5
|
+
from .endpoint import Endpoint
|
6
|
+
|
7
|
+
|
8
|
+
def match_endpoint(
    provider: str,
    endpoint: str,
    **kwargs,
) -> "Endpoint | None":
    """Return the provider-specific endpoint for ``provider``/``endpoint``.

    Matching is by exact provider name and by substring of ``endpoint``
    (e.g. "chat", "response", "search", "messages"). Provider modules are
    imported only when their branch is selected.

    Args:
        provider: Provider name, e.g. "openai" or "anthropic".
        endpoint: Endpoint identifier; tested with substring checks.
        **kwargs: Passed through to the selected endpoint class.

    Returns:
        An instance of the matching endpoint class, or None when no
        provider/endpoint combination matches.
    """
    if provider == "openai":
        # "chat" is tested first, so an endpoint containing both "chat" and
        # "response" resolves to the chat endpoint.
        if "chat" in endpoint:
            from .providers.oai_ import OpenaiChatEndpoint

            return OpenaiChatEndpoint(**kwargs)
        if "response" in endpoint:
            from .providers.oai_ import OpenaiResponseEndpoint

            return OpenaiResponseEndpoint(**kwargs)
    if provider == "openrouter" and "chat" in endpoint:
        from .providers.oai_ import OpenrouterChatEndpoint

        return OpenrouterChatEndpoint(**kwargs)
    if provider == "ollama" and "chat" in endpoint:
        from .providers.ollama_ import OllamaChatEndpoint

        return OllamaChatEndpoint(**kwargs)
    if provider == "exa" and "search" in endpoint:
        from .providers.exa_ import ExaSearchEndpoint

        return ExaSearchEndpoint(**kwargs)
    if provider == "anthropic" and (
        "messages" in endpoint or "chat" in endpoint
    ):
        from .providers.anthropic_ import AnthropicMessagesEndpoint

        return AnthropicMessagesEndpoint(**kwargs)
    if provider == "groq" and "chat" in endpoint:
        from .providers.oai_ import GroqChatEndpoint

        return GroqChatEndpoint(**kwargs)
    if provider == "perplexity" and "chat" in endpoint:
        from .providers.perplexity_ import PerplexityChatEndpoint

        return PerplexityChatEndpoint(**kwargs)

    # No known provider/endpoint combination.
    return None
|