airia 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airia/client/__init__.py +4 -0
- airia/client/async_client.py +448 -0
- airia/client/base_client.py +187 -0
- airia/client/sync_client.py +436 -0
- airia/types/__init__.py +19 -0
- airia/types/api_version.py +13 -0
- airia/types/pipeline_execution.py +31 -0
- airia/types/request_data.py +10 -0
- {airia-0.1.3.dist-info → airia-0.1.4.dist-info}/METADATA +1 -1
- airia-0.1.4.dist-info/RECORD +16 -0
- airia-0.1.3.dist-info/RECORD +0 -8
- {airia-0.1.3.dist-info → airia-0.1.4.dist-info}/WHEEL +0 -0
- {airia-0.1.3.dist-info → airia-0.1.4.dist-info}/licenses/LICENSE +0 -0
- {airia-0.1.3.dist-info → airia-0.1.4.dist-info}/top_level.txt +0 -0
airia/client/async_client.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
from typing import Any, AsyncIterator, Dict, List, Literal, Optional, overload
|
|
2
|
+
from urllib.parse import urljoin
|
|
3
|
+
|
|
4
|
+
import aiohttp
|
|
5
|
+
import loguru
|
|
6
|
+
|
|
7
|
+
from ..exceptions import AiriaAPIError
|
|
8
|
+
from ..types import (
|
|
9
|
+
ApiVersion,
|
|
10
|
+
PipelineExecutionDebugResponse,
|
|
11
|
+
PipelineExecutionResponse,
|
|
12
|
+
PipelineExecutionV1StreamedResponse,
|
|
13
|
+
PipelineExecutionV2AsyncStreamedResponse,
|
|
14
|
+
RequestData,
|
|
15
|
+
)
|
|
16
|
+
from .base_client import AiriaBaseClient
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AiriaAsyncClient(AiriaBaseClient):
    """Asynchronous client for interacting with the Airia API.

    The ``aiohttp`` session is created in ``__aenter__`` and closed in
    ``__aexit__``, so API calls must be made inside ``async with``::

        async with AiriaAsyncClient(api_key="...") as client:
            response = await client.execute_pipeline(...)
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        timeout: float = 30.0,
        log_requests: bool = False,
        custom_logger: Optional["loguru.Logger"] = None,
    ):
        """
        Initialize the asynchronous Airia API client.

        Args:
            api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
            timeout: Request timeout in seconds (applied as the total per-request timeout).
            log_requests: Whether to log API requests and responses. Default is False.
            custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
        """
        super().__init__(api_key, timeout, log_requests, custom_logger)

        # Session will be initialized in __aenter__
        self.session = None
        self.headers = {"Content-Type": "application/json"}

    @classmethod
    def with_openai_gateway(
        cls,
        api_key: Optional[str] = None,
        timeout: float = 30.0,
        log_requests: bool = False,
        custom_logger: Optional["loguru.Logger"] = None,
        **kwargs,
    ):
        """
        Initialize the asynchronous Airia API client with AsyncOpenAI gateway capabilities.

        The returned client exposes an ``AsyncOpenAI`` instance pointed at the
        Airia OpenAI gateway as ``client.openai``.

        Args:
            api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
            timeout: Request timeout in seconds.
            log_requests: Whether to log API requests and responses. Default is False.
            custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
            **kwargs: Additional keyword arguments to pass to the AsyncOpenAI client initialization.

        Raises:
            ValueError: If no API key can be resolved.
        """
        # Imported here, presumably so `openai` is only needed when this factory is used.
        from openai import AsyncOpenAI

        api_key = cls._get_api_key(api_key)
        openai_client = AsyncOpenAI(
            api_key=api_key,
            base_url="https://gateway.airia.ai/openai/v1",
            **kwargs,
        )

        client = cls(api_key, timeout, log_requests, custom_logger)
        # Attach to this instance only: assigning `cls.openai` would share the
        # gateway client (and the API key baked into it) with every instance.
        client.openai = openai_client
        return client

    @classmethod
    def with_anthropic_gateway(
        cls,
        api_key: Optional[str] = None,
        timeout: float = 30.0,
        log_requests: bool = False,
        custom_logger: Optional["loguru.Logger"] = None,
        **kwargs,
    ):
        """
        Initialize the asynchronous Airia API client with AsyncAnthropic gateway capabilities.

        The returned client exposes an ``AsyncAnthropic`` instance pointed at
        the Airia Anthropic gateway as ``client.anthropic``.

        Args:
            api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
            timeout: Request timeout in seconds.
            log_requests: Whether to log API requests and responses. Default is False.
            custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
            **kwargs: Additional keyword arguments to pass to the AsyncAnthropic client initialization.

        Raises:
            ValueError: If no API key can be resolved.
        """
        # Imported here, presumably so `anthropic` is only needed when this factory is used.
        from anthropic import AsyncAnthropic

        api_key = cls._get_api_key(api_key)
        anthropic_client = AsyncAnthropic(
            api_key=api_key,
            base_url="https://gateway.airia.ai/anthropic",
            **kwargs,
        )

        client = cls(api_key, timeout, log_requests, custom_logger)
        # Attach to this instance only: assigning `cls.anthropic` would share the
        # gateway client (and the API key baked into it) with every instance.
        client.anthropic = anthropic_client
        return client

    async def __aenter__(self):
        """Async context manager entry point: open the HTTP session."""
        # Re-entering must not orphan (and leak) a session that is still open.
        if self.session is None or self.session.closed:
            self.session = aiohttp.ClientSession(headers=self.headers)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit point: close the HTTP session."""
        if self.session:
            await self.session.close()
            self.session = None

    def _check_session(self):
        """Raise RuntimeError unless an open client session is available."""
        if self.session is None or self.session.closed:
            raise RuntimeError(
                "Client session not initialized. Use async with AiriaAsyncClient() as client: ..."
            )

    def _request_timeout(self) -> aiohttp.ClientTimeout:
        """Build the per-request aiohttp timeout from ``self.timeout`` (seconds)."""
        return aiohttp.ClientTimeout(total=self.timeout)

    def _handle_exception(
        self, e: aiohttp.ClientResponseError, url: str, correlation_id: str
    ):
        """
        Log (if enabled) and re-raise an aiohttp response error as ``AiriaAPIError``.

        Args:
            e: The response error caught from an API call.
            url: The request URL (used for logging only).
            correlation_id: The request correlation ID (used for logging only).

        Raises:
            AiriaAPIError: Always, chained from ``e``, with the API key redacted
                from the message.
        """
        # Log the error response if enabled
        if self.log_requests:
            self.logger.error(
                f"API Error: {e.status} {e.message}\n"
                f"URL: {url}\n"
                f"Correlation ID: {correlation_id}"
            )

        # Coerce defensively: the redaction below needs a string.
        error_message = str(e.message or "")

        # Make sure API key is not included in error messages
        sanitized_message = error_message.replace(self.api_key, "[REDACTED]")

        # Raise custom exception with status code and sanitized message
        raise AiriaAPIError(status_code=e.status, message=sanitized_message) from e

    async def _make_request(
        self, method: str, request_data: RequestData
    ) -> Dict[str, Any]:
        """
        Makes an asynchronous HTTP request to the Airia API.

        Args:
            method (str): The HTTP method (e.g., 'GET', 'POST')
            request_data: Prepared request (see ``_prepare_request``) providing:
                - url: The endpoint URL for the request
                - headers: HTTP headers to include in the request
                - payload: The JSON payload/body for the request
                - correlation_id: Unique identifier for request tracing

        Returns:
            The decoded JSON response body.

        Raises:
            AiriaAPIError: If the API returns an error status (converted from
                ``aiohttp.ClientResponseError``, with the API key redacted).
            aiohttp.ClientError: Other client errors (e.g. connection failures)
                propagate unchanged.

        Note:
            This is an internal method used by other client methods to make API requests.
            It handles logging, error handling, and API key redaction in error messages.
        """
        try:
            # Make the request
            async with self.session.request(
                method=method,
                url=request_data.url,
                json=request_data.payload,
                headers=request_data.headers,
                timeout=self._request_timeout(),
            ) as response:
                # Log the response if enabled
                if self.log_requests:
                    self.logger.info(
                        f"API Response: {response.status} {response.reason}\n"
                        f"URL: {request_data.url}\n"
                        f"Correlation ID: {request_data.correlation_id}"
                    )

                # Check for HTTP errors
                response.raise_for_status()

                # Return the response as a dictionary
                return await response.json()

        except aiohttp.ClientResponseError as e:
            self._handle_exception(e, request_data.url, request_data.correlation_id)

    async def _make_request_stream(
        self, method: str, request_data: RequestData
    ) -> AsyncIterator[str]:
        """
        Makes an asynchronous HTTP request to the Airia API and streams the body.

        This is an async generator: the request is only sent once iteration starts.

        Args:
            method (str): The HTTP method (e.g., 'GET', 'POST')
            request_data: Prepared request (see ``_prepare_request``) providing:
                - url: The endpoint URL for the request
                - headers: HTTP headers to include in the request
                - payload: The JSON payload/body for the request
                - correlation_id: Unique identifier for request tracing

        Yields:
            str: Decoded text of the response body as it is received.

        Raises:
            AiriaAPIError: If the API returns an error status (converted from
                ``aiohttp.ClientResponseError``, with the API key redacted).
            UnicodeDecodeError: If the body is not valid UTF-8.

        Note:
            This is an internal method used by other client methods to make API requests.
            It handles logging, error handling, and API key redaction in error messages.
        """
        import codecs

        try:
            # Make the request
            async with self.session.request(
                method=method,
                url=request_data.url,
                json=request_data.payload,
                headers=request_data.headers,
                timeout=self._request_timeout(),
                # NOTE(review): `chunked` sets the *request* body's transfer
                # encoding; it is not needed to stream the response. Confirm the
                # API requires it before removing.
                chunked=True,
            ) as response:
                # Log the response if enabled
                if self.log_requests:
                    self.logger.info(
                        f"API Response: {response.status} {response.reason}\n"
                        f"URL: {request_data.url}\n"
                        f"Correlation ID: {request_data.correlation_id}"
                    )

                # Check for HTTP errors
                response.raise_for_status()

                # Network chunks may split a multi-byte UTF-8 sequence; the
                # incremental decoder buffers partial bytes until the rest arrives.
                decoder = codecs.getincrementaldecoder("utf-8")()
                async for chunk in response.content.iter_any():
                    text = decoder.decode(chunk)
                    if text:
                        yield text
                tail = decoder.decode(b"", final=True)
                if tail:
                    yield tail

        except aiohttp.ClientResponseError as e:
            self._handle_exception(e, request_data.url, request_data.correlation_id)

    @overload
    async def execute_pipeline(
        self,
        pipeline_id: str,
        user_input: str,
        debug: Literal[False] = False,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        async_output: Literal[False] = False,
        include_tools_response: bool = False,
        images: Optional[List[str]] = None,
        files: Optional[List[str]] = None,
        data_source_folders: Optional[Dict[str, Any]] = None,
        data_source_files: Optional[Dict[str, Any]] = None,
        in_memory_messages: Optional[List[Dict[str, str]]] = None,
        current_date_time: Optional[str] = None,
        save_history: bool = True,
        additional_info: Optional[List[Any]] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
        api_version: str = ApiVersion.V2.value,
    ) -> PipelineExecutionResponse: ...

    @overload
    async def execute_pipeline(
        self,
        pipeline_id: str,
        user_input: str,
        debug: Literal[True] = True,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        async_output: Literal[False] = False,
        include_tools_response: bool = False,
        images: Optional[List[str]] = None,
        files: Optional[List[str]] = None,
        data_source_folders: Optional[Dict[str, Any]] = None,
        data_source_files: Optional[Dict[str, Any]] = None,
        in_memory_messages: Optional[List[Dict[str, str]]] = None,
        current_date_time: Optional[str] = None,
        save_history: bool = True,
        additional_info: Optional[List[Any]] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
        api_version: str = ApiVersion.V2.value,
    ) -> PipelineExecutionDebugResponse: ...

    @overload
    async def execute_pipeline(
        self,
        pipeline_id: str,
        user_input: str,
        debug: bool = False,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        async_output: Literal[True] = True,
        include_tools_response: bool = False,
        images: Optional[List[str]] = None,
        files: Optional[List[str]] = None,
        data_source_folders: Optional[Dict[str, Any]] = None,
        data_source_files: Optional[Dict[str, Any]] = None,
        in_memory_messages: Optional[List[Dict[str, str]]] = None,
        current_date_time: Optional[str] = None,
        save_history: bool = True,
        additional_info: Optional[List[Any]] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
        api_version: Literal["v2"] = ApiVersion.V2.value,
    ) -> PipelineExecutionV2AsyncStreamedResponse: ...

    @overload
    async def execute_pipeline(
        self,
        pipeline_id: str,
        user_input: str,
        debug: bool = False,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        async_output: Literal[True] = True,
        include_tools_response: bool = False,
        images: Optional[List[str]] = None,
        files: Optional[List[str]] = None,
        data_source_folders: Optional[Dict[str, Any]] = None,
        data_source_files: Optional[Dict[str, Any]] = None,
        in_memory_messages: Optional[List[Dict[str, str]]] = None,
        current_date_time: Optional[str] = None,
        save_history: bool = True,
        additional_info: Optional[List[Any]] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
        api_version: Literal["v1"] = ApiVersion.V1.value,
    ) -> PipelineExecutionV1StreamedResponse: ...

    async def execute_pipeline(
        self,
        pipeline_id: str,
        user_input: str,
        debug: bool = False,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        async_output: bool = False,
        include_tools_response: bool = False,
        images: Optional[List[str]] = None,
        files: Optional[List[str]] = None,
        data_source_folders: Optional[Dict[str, Any]] = None,
        data_source_files: Optional[Dict[str, Any]] = None,
        in_memory_messages: Optional[List[Dict[str, str]]] = None,
        current_date_time: Optional[str] = None,
        save_history: bool = True,
        additional_info: Optional[List[Any]] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
        api_version: str = ApiVersion.V2.value,
    ) -> Dict[str, Any]:
        """
        Execute a pipeline with the provided input asynchronously.

        Args:
            pipeline_id: The ID of the pipeline to execute.
            user_input: input text to process.
            debug: Whether debug mode execution is enabled. Default is False.
            user_id: Optional ID of the user making the request (guid).
            conversation_id: Optional conversation ID (guid).
            async_output: Whether to stream the response. Default is False.
            include_tools_response: Whether to return the initial LLM tool result. Default is False.
            images: Optional list of images formatted as base64 strings.
            files: Optional list of files formatted as base64 strings.
            data_source_folders: Optional data source folders information.
            data_source_files: Optional data source files information.
            in_memory_messages: Optional list of in-memory messages, each with a role and message.
            current_date_time: Optional current date and time in ISO format.
            save_history: Whether to save the userInput and output to conversation history. Default is True.
            additional_info: Optional additional information.
            prompt_variables: Optional variables to be used in the prompt.
            correlation_id: Optional correlation ID for request tracing. If not provided,
                        one will be generated automatically.
            api_version: API version to use. Default is `v2`

        Returns:
            - ``PipelineExecutionResponse`` when ``async_output`` and ``debug`` are False.
            - ``PipelineExecutionDebugResponse`` when ``async_output`` is False and ``debug`` is True.
            - ``PipelineExecutionV1StreamedResponse`` when ``async_output`` is True and
              ``api_version`` is ``"v1"``.
            - ``PipelineExecutionV2AsyncStreamedResponse`` wrapping an async iterator of
              text chunks when ``async_output`` is True and ``api_version`` is ``"v2"``.
              That request is only sent when the stream is iterated, so consume it
              before leaving the ``async with`` block (which closes the session).

        Raises:
            RuntimeError: If called outside ``async with``.
            ValueError: If ``api_version`` is not supported.
            AiriaAPIError: If the API request fails with details about the error.
            aiohttp.ClientError: For other request-related errors.

        Example:
            >>> async with AiriaAsyncClient(api_key="your_api_key") as client:
            ...     response = await client.execute_pipeline(
            ...         pipeline_id="pipeline_123",
            ...         user_input="Tell me about quantum computing"
            ...     )
            >>> print(response.result)
        """
        self._check_session()

        request_data = self._pre_execute_pipeline(
            pipeline_id=pipeline_id,
            user_input=user_input,
            debug=debug,
            user_id=user_id,
            conversation_id=conversation_id,
            async_output=async_output,
            include_tools_response=include_tools_response,
            images=images,
            files=files,
            data_source_folders=data_source_folders,
            data_source_files=data_source_files,
            in_memory_messages=in_memory_messages,
            current_date_time=current_date_time,
            save_history=save_history,
            additional_info=additional_info,
            prompt_variables=prompt_variables,
            correlation_id=correlation_id,
            api_version=api_version,
        )

        # Only v2 async output is streamed over this HTTP response; v1 async
        # output returns a socket identifier that is exchanged for a URL below.
        stream = async_output and api_version == ApiVersion.V2.value
        if stream:
            resp = self._make_request_stream(method="POST", request_data=request_data)
        else:
            resp = await self._make_request("POST", request_data)

        if not async_output:
            if not debug:
                return PipelineExecutionResponse(**resp)
            return PipelineExecutionDebugResponse(**resp)

        if api_version == ApiVersion.V1.value:
            url = urljoin(
                self.base_url, f"{api_version}/StreamSocketConfig/GenerateUrl"
            )
            request_data = self._prepare_request(
                url,
                {"socketIdentifier": resp},
                request_data.headers["X-Correlation-ID"],
            )
            resp = await self._make_request("POST", request_data)

            return PipelineExecutionV1StreamedResponse(**resp)

        return PipelineExecutionV2AsyncStreamedResponse(stream=resp)
|
|
airia/client/base_client.py
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any, Dict, List, Optional
|
|
4
|
+
from urllib.parse import urljoin
|
|
5
|
+
|
|
6
|
+
import loguru
|
|
7
|
+
|
|
8
|
+
from ..logs import configure_logging, set_correlation_id
|
|
9
|
+
from ..types import ApiVersion, RequestData
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AiriaBaseClient:
    """Base client containing shared functionality for Airia API clients."""

    # Optional gateway SDK clients; None unless a `with_*_gateway` factory sets them.
    openai = None
    anthropic = None

    def __init__(
        self,
        api_key: Optional[str] = None,
        timeout: float = 30.0,
        log_requests: bool = False,
        custom_logger: Optional["loguru.Logger"] = None,
    ):
        """
        Initialize the Airia API client base class.

        Args:
            api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
            timeout: Request timeout in seconds.
            log_requests: Whether to log API requests and responses. Default is False.
            custom_logger: Optional custom logger object to use for logging. If not provided,
                the logger returned by ``configure_logging()`` is used. Request logging
                itself is gated on ``log_requests``.

        Raises:
            ValueError: If no API key can be resolved.
        """
        # Resolve API key: parameter takes precedence over environment variable
        self.api_key = self.__class__._get_api_key(api_key)

        # Store configuration
        self.base_url = "https://api.airia.ai/"
        self.timeout = timeout
        self.log_requests = log_requests

        # Initialize logger
        self.logger = configure_logging() if custom_logger is None else custom_logger

    @staticmethod
    def _get_api_key(api_key: Optional[str] = None) -> str:
        """
        Get the API key from either the provided parameter or environment variable.

        Args:
            api_key (Optional[str]): The API key provided as a parameter. Defaults to None.
                An empty string is treated like None.

        Returns:
            str: The resolved API key.

        Raises:
            ValueError: If no API key is provided through either method.
        """
        api_key = api_key or os.environ.get("AIRIA_API_KEY")

        if not api_key:
            raise ValueError(
                "API key must be provided either as a parameter or through the AIRIA_API_KEY environment variable."
            )

        return api_key

    def _prepare_request(
        self,
        url: str,
        payload: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
        method: str = "POST",
    ) -> RequestData:
        """
        Prepare the request parameters for an API call.

        Args:
            url (str): The endpoint URL for the API request.
            payload (Optional[Dict[str, Any]]): The request payload/body to be sent.
            correlation_id (Optional[str]): A unique identifier for tracing the request. If None, one will be generated.
            method (str): HTTP method shown in the request log. Defaults to "POST".

        Returns:
            RequestData: The prepared request, with fields:
                - url: The request URL
                - payload: The request payload
                - headers: Request headers including API key and correlation ID
                - correlation_id: The correlation ID used for the request

        Note:
            This method handles:
            - Setting/generating correlation IDs
            - Adding authentication headers
            - Logging requests (if enabled) with the API key redacted and
              images/files replaced by their counts
        """
        # Set correlation ID if provided or generate a new one
        correlation_id = set_correlation_id(correlation_id)

        # Add the X-API-KEY header and correlation ID
        headers = {
            "X-API-KEY": self.api_key,
            "X-Correlation-ID": correlation_id,
            "Content-Type": "application/json",
        }

        # Log the request if enabled
        if self.log_requests:
            # Sanitized copy of the headers: never log the API key
            log_headers = {**headers, "X-API-KEY": "[REDACTED]"}

            # Summarize potentially large base64 attachments instead of logging them
            log_payload = None
            if payload is not None:
                summarized = payload.copy()
                for key in ("images", "files"):
                    if summarized.get(key) is not None:
                        summarized[key] = f"{len(summarized[key])} {key}"
                log_payload = json.dumps(summarized)

            self.logger.info(
                f"API Request: {method} {url}\n"
                f"Headers: {json.dumps(log_headers)}\n"
                f"Payload: {log_payload}"
            )

        return RequestData(
            **{
                "url": url,
                "payload": payload,
                "headers": headers,
                "correlation_id": correlation_id,
            }
        )

    def _pre_execute_pipeline(
        self,
        pipeline_id: str,
        user_input: str,
        debug: bool = False,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        async_output: bool = False,
        include_tools_response: bool = False,
        images: Optional[List[str]] = None,
        files: Optional[List[str]] = None,
        data_source_folders: Optional[Dict[str, Any]] = None,
        data_source_files: Optional[Dict[str, Any]] = None,
        in_memory_messages: Optional[List[Dict[str, str]]] = None,
        current_date_time: Optional[str] = None,
        save_history: bool = True,
        additional_info: Optional[List[Any]] = None,
        prompt_variables: Optional[Dict[str, Any]] = None,
        correlation_id: Optional[str] = None,
        api_version: str = ApiVersion.V2.value,
    ) -> RequestData:
        """
        Validate arguments and build the request for a pipeline execution.

        Args:
            pipeline_id: ID of the pipeline to execute. It is percent-encoded as a
                single URL path segment.
            user_input, debug, user_id, conversation_id, async_output,
            include_tools_response, images, files, data_source_folders,
            data_source_files, in_memory_messages, current_date_time, save_history,
            additional_info, prompt_variables: Sent unchanged in the JSON payload
                under the camelCase keys below.
            correlation_id: Optional correlation ID; generated if None.
            api_version: API version path segment; must be in ``ApiVersion.as_list()``.

        Returns:
            RequestData: The prepared request (see ``_prepare_request``).

        Raises:
            ValueError: If ``api_version`` is not a supported version.
        """
        from urllib.parse import quote

        if api_version not in ApiVersion.as_list():
            raise ValueError(
                f"Invalid API version: {api_version}. Valid versions are: {', '.join(ApiVersion.as_list())}"
            )

        # Encode the ID as one path segment so "/", "?", "#" or "../" in it cannot
        # redirect the authenticated request to a different endpoint.
        encoded_pipeline_id = quote(str(pipeline_id), safe="")
        url = urljoin(
            self.base_url, f"{api_version}/PipelineExecution/{encoded_pipeline_id}"
        )

        payload = {
            "userInput": user_input,
            "debug": debug,
            "userId": user_id,
            "conversationId": conversation_id,
            "asyncOutput": async_output,
            "includeToolsResponse": include_tools_response,
            "images": images,
            "files": files,
            "dataSourceFolders": data_source_folders,
            "dataSourceFiles": data_source_files,
            "inMemoryMessages": in_memory_messages,
            "currentDateTime": current_date_time,
            "saveHistory": save_history,
            "additionalInfo": additional_info,
            "promptVariables": prompt_variables,
        }

        return self._prepare_request(url, payload, correlation_id)
|
|
airia/client/sync_client.py
@@ -0,0 +1,436 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Literal, Optional, overload
|
|
2
|
+
from urllib.parse import urljoin
|
|
3
|
+
|
|
4
|
+
import loguru
|
|
5
|
+
import requests
|
|
6
|
+
|
|
7
|
+
from ..exceptions import AiriaAPIError
|
|
8
|
+
from ..types import (
|
|
9
|
+
ApiVersion,
|
|
10
|
+
PipelineExecutionDebugResponse,
|
|
11
|
+
PipelineExecutionResponse,
|
|
12
|
+
PipelineExecutionV1StreamedResponse,
|
|
13
|
+
PipelineExecutionV2StreamedResponse,
|
|
14
|
+
RequestData,
|
|
15
|
+
)
|
|
16
|
+
from .base_client import AiriaBaseClient
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class AiriaClient(AiriaBaseClient):
|
|
20
|
+
"""Synchronous client for interacting with the Airia API."""
|
|
21
|
+
|
|
22
|
+
def __init__(
    self,
    api_key: Optional[str] = None,
    timeout: float = 30.0,
    log_requests: bool = False,
    custom_logger: Optional["loguru.Logger"] = None,
):
    """
    Create a synchronous Airia API client backed by a ``requests.Session``.

    Args:
        api_key: API key for authentication. Falls back to the AIRIA_API_KEY
            environment variable when omitted.
        timeout: Request timeout in seconds.
        log_requests: Whether to log API requests and responses. Default is False.
        custom_logger: Optional custom logger object. When omitted, a default
            logger is used (logging only happens when `log_requests` is True).
    """
    super().__init__(
        api_key=api_key,
        timeout=timeout,
        log_requests=log_requests,
        custom_logger=custom_logger,
    )

    # HTTP session for this client, defaulting every request to a JSON body.
    http_session = requests.Session()
    http_session.headers["Content-Type"] = "application/json"
    self.session = http_session
|
|
43
|
+
|
|
44
|
+
@classmethod
def with_openai_gateway(
    cls,
    api_key: Optional[str] = None,
    timeout: float = 30.0,
    log_requests: bool = False,
    custom_logger: Optional["loguru.Logger"] = None,
    **kwargs,
):
    """
    Initialize the synchronous Airia API client with OpenAI gateway capabilities.

    The returned client exposes an ``OpenAI`` instance pointed at the Airia
    OpenAI gateway as ``client.openai``.

    Args:
        api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
        timeout: Request timeout in seconds.
        log_requests: Whether to log API requests and responses. Default is False.
        custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
        **kwargs: Additional keyword arguments to pass to the OpenAI client initialization.

    Raises:
        ValueError: If no API key can be resolved.
    """
    # Imported here, presumably so `openai` is only needed when this factory is used.
    from openai import OpenAI

    api_key = cls._get_api_key(api_key)
    openai_client = OpenAI(
        api_key=api_key,
        base_url="https://gateway.airia.ai/openai/v1",
        **kwargs,
    )

    client = cls(api_key, timeout, log_requests, custom_logger)
    # Attach to this instance only: assigning `cls.openai` would share the
    # gateway client (and the API key baked into it) with every instance.
    client.openai = openai_client
    return client
|
|
73
|
+
|
|
74
|
+
@classmethod
def with_anthropic_gateway(
    cls,
    api_key: Optional[str] = None,
    timeout: float = 30.0,
    log_requests: bool = False,
    custom_logger: Optional["loguru.Logger"] = None,
    **kwargs,
):
    """
    Initialize the synchronous Airia API client with Anthropic gateway capabilities.

    The returned client exposes an ``Anthropic`` instance pointed at the Airia
    Anthropic gateway as ``client.anthropic``.

    Args:
        api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
        timeout: Request timeout in seconds.
        log_requests: Whether to log API requests and responses. Default is False.
        custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
        **kwargs: Additional keyword arguments to pass to the Anthropic client initialization.

    Raises:
        ValueError: If no API key can be resolved.
    """
    # Imported here, presumably so `anthropic` is only needed when this factory is used.
    from anthropic import Anthropic

    api_key = cls._get_api_key(api_key)
    anthropic_client = Anthropic(
        api_key=api_key,
        base_url="https://gateway.airia.ai/anthropic",
        **kwargs,
    )

    client = cls(api_key, timeout, log_requests, custom_logger)
    # Attach to this instance only: assigning `cls.anthropic` would share the
    # gateway client (and the API key baked into it) with every instance.
    client.anthropic = anthropic_client
    return client
|
|
103
|
+
|
|
104
|
+
def _handle_exception(self, e: requests.HTTPError, url: str, correlation_id: str):
    """
    Convert an HTTP error raised by ``requests`` into an ``AiriaAPIError``.

    Logs the failure when request logging is enabled. Extracts a readable
    message from the JSON error body when possible, redacts the client's API
    key from it, and raises.

    Args:
        e: The HTTP error raised by ``response.raise_for_status()``.
        url: The request URL (used only for logging).
        correlation_id: The request's correlation ID (used only for logging).

    Raises:
        AiriaAPIError: Always. It carries the response status code and the
            sanitized message, and is chained to ``e``.
    """
    # Log the error response if enabled
    if self.log_requests:
        self.logger.error(
            f"API Error: {e.response.status_code} {e.response.reason}\n"
            f"URL: {url}\n"
            f"Correlation ID: {correlation_id}"
        )

    # Extract error details from response if possible
    error_message = "API request failed"
    try:
        error_data = e.response.json()
        if isinstance(error_data, dict):
            # Prefer "message", then "error". A null value is skipped rather than
            # used, so a {"message": null} body cannot crash redaction below.
            detail = error_data.get("message")
            if detail is None:
                detail = error_data.get("error")
            if detail is not None:
                # Nested objects (e.g. {"error": {...}}) are rendered as text so
                # the exception message is always a string.
                error_message = str(detail)
    except (ValueError, KeyError):
        # If JSON parsing fails or expected keys are missing
        error_message = f"API request failed: {str(e)}"

    # Make sure API key is not included in error messages. The truthiness guard
    # avoids a TypeError when no key is set, and avoids str.replace("", ...)
    # inserting the marker between every character when the key is empty.
    if self.api_key and self.api_key in error_message:
        error_message = error_message.replace(self.api_key, "[REDACTED]")

    # Raise custom exception with status code and sanitized message
    raise AiriaAPIError(
        status_code=e.response.status_code, message=error_message
    ) from e
|
|
136
|
+
|
|
137
|
+
def _make_request(self, method: str, request_data: RequestData):
    """
    Send a blocking HTTP request to the Airia API and decode its JSON body.

    Args:
        method (str): HTTP verb to use, e.g. ``'GET'`` or ``'POST'``.
        request_data: The prepared request. Its ``url``, ``headers``,
            ``payload`` (sent as the JSON body) and ``correlation_id``
            attributes are used.

    Returns:
        The decoded JSON body of a successful response.

    Raises:
        AiriaAPIError: When the server replies with an HTTP error status.
            The ``requests.HTTPError`` is handed to ``_handle_exception``,
            which redacts the API key from the message.
        Errors raised by ``response.json()`` (e.g. a non-JSON body) and other
        ``requests`` exceptions propagate unchanged.

    Note:
        Internal helper shared by the public client methods. The response
        status line is logged when ``log_requests`` is enabled.
    """
    try:
        response = self.session.request(
            method=method,
            url=request_data.url,
            json=request_data.payload,
            headers=request_data.headers,
            timeout=self.timeout,
        )

        if self.log_requests:
            status_line = f"API Response: {response.status_code} {response.reason}\n"
            target_line = f"URL: {request_data.url}\n"
            trace_line = f"Correlation ID: {request_data.correlation_id}\n"
            self.logger.info(status_line + target_line + trace_line)

        response.raise_for_status()
    except requests.HTTPError as http_error:
        self._handle_exception(
            http_error, request_data.url, request_data.correlation_id
        )
    else:
        return response.json()
|
|
186
|
+
|
|
187
|
+
def _make_request_stream(self, method: str, request_data: RequestData):
    """
    Makes a synchronous, streaming HTTP request to the Airia API.

    Args:
        method (str): The HTTP method (e.g., 'GET', 'POST')
        request_data: The prepared request, providing:
            - url: The endpoint URL for the request
            - headers: HTTP headers to include in the request
            - payload: The JSON payload/body for the request
            - correlation_id: Unique identifier for request tracing

    Yields:
        str: UTF-8-decoded text chunks of the response body as they arrive.

    Raises:
        AiriaAPIError: If the API returns an error response, with details about the error.
        UnicodeDecodeError: If the body is not valid UTF-8.

    Note:
        This is an internal generator used by other client methods. The request
        is only sent when iteration starts. The response is closed when the
        generator finishes or is closed, so abandoning the stream early does not
        leave the connection open.
    """
    import codecs

    response = None
    try:
        # Make the request
        response = self.session.request(
            method=method,
            url=request_data.url,
            json=request_data.payload,
            headers=request_data.headers,
            timeout=self.timeout,
            stream=True,
        )

        # Log the response if enabled
        if self.log_requests:
            self.logger.info(
                f"API Response: {response.status_code} {response.reason}\n"
                f"URL: {request_data.url}\n"
                f"Correlation ID: {request_data.correlation_id}\n"
            )

        # Check for HTTP errors
        response.raise_for_status()

        # chunk_size=None yields data as it is received, instead of the
        # iter_content() default of one byte per chunk. The incremental decoder
        # buffers multi-byte UTF-8 sequences that straddle chunk boundaries;
        # decoding each chunk on its own would raise UnicodeDecodeError on
        # any non-ASCII character.
        decoder = codecs.getincrementaldecoder("utf-8")()
        for chunk in response.iter_content(chunk_size=None):
            text = decoder.decode(chunk)
            if text:
                yield text
        tail = decoder.decode(b"", final=True)
        if tail:
            yield tail

    except requests.HTTPError as e:
        # Runs before the ``finally`` below, so the error body can still be read.
        self._handle_exception(e, request_data.url, request_data.correlation_id)
    finally:
        # Release the streamed connection even if the consumer stops early.
        if response is not None:
            response.close()
|
|
239
|
+
|
|
240
|
+
# Typing-only overloads for execute_pipeline. They let type checkers infer the
# return type from the literal values of `debug`, `async_output` and
# `api_version`. At runtime the final, undecorated `execute_pipeline`
# definition replaces them and is the code that actually runs.


# Non-streaming, non-debug call -> PipelineExecutionResponse.
@overload
def execute_pipeline(
    self,
    pipeline_id: str,
    user_input: str,
    debug: Literal[False] = False,
    user_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    async_output: Literal[False] = False,
    include_tools_response: bool = False,
    images: Optional[List[str]] = None,
    files: Optional[List[str]] = None,
    data_source_folders: Optional[Dict[str, Any]] = None,
    data_source_files: Optional[Dict[str, Any]] = None,
    in_memory_messages: Optional[List[Dict[str, str]]] = None,
    current_date_time: Optional[str] = None,
    save_history: bool = True,
    additional_info: Optional[List[Any]] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V2.value,
) -> PipelineExecutionResponse: ...


# Non-streaming call with debug=True -> PipelineExecutionDebugResponse.
@overload
def execute_pipeline(
    self,
    pipeline_id: str,
    user_input: str,
    debug: Literal[True] = True,
    user_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    async_output: Literal[False] = False,
    include_tools_response: bool = False,
    images: Optional[List[str]] = None,
    files: Optional[List[str]] = None,
    data_source_folders: Optional[Dict[str, Any]] = None,
    data_source_files: Optional[Dict[str, Any]] = None,
    in_memory_messages: Optional[List[Dict[str, str]]] = None,
    current_date_time: Optional[str] = None,
    save_history: bool = True,
    additional_info: Optional[List[Any]] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V2.value,
) -> PipelineExecutionDebugResponse: ...


# Streaming call on API v2 -> PipelineExecutionV2StreamedResponse (any `debug`).
@overload
def execute_pipeline(
    self,
    pipeline_id: str,
    user_input: str,
    debug: bool = False,
    user_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    async_output: Literal[True] = True,
    include_tools_response: bool = False,
    images: Optional[List[str]] = None,
    files: Optional[List[str]] = None,
    data_source_folders: Optional[Dict[str, Any]] = None,
    data_source_files: Optional[Dict[str, Any]] = None,
    in_memory_messages: Optional[List[Dict[str, str]]] = None,
    current_date_time: Optional[str] = None,
    save_history: bool = True,
    additional_info: Optional[List[Any]] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    correlation_id: Optional[str] = None,
    api_version: Literal["v2"] = ApiVersion.V2.value,
) -> PipelineExecutionV2StreamedResponse: ...


# Streaming call on API v1 -> PipelineExecutionV1StreamedResponse (WebSocket URL).
@overload
def execute_pipeline(
    self,
    pipeline_id: str,
    user_input: str,
    debug: bool = False,
    user_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    async_output: Literal[True] = True,
    include_tools_response: bool = False,
    images: Optional[List[str]] = None,
    files: Optional[List[str]] = None,
    data_source_folders: Optional[Dict[str, Any]] = None,
    data_source_files: Optional[Dict[str, Any]] = None,
    in_memory_messages: Optional[List[Dict[str, str]]] = None,
    current_date_time: Optional[str] = None,
    save_history: bool = True,
    additional_info: Optional[List[Any]] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    correlation_id: Optional[str] = None,
    api_version: Literal["v1"] = ApiVersion.V1.value,
) -> PipelineExecutionV1StreamedResponse: ...
|
|
331
|
+
|
|
332
|
+
def execute_pipeline(
    self,
    pipeline_id: str,
    user_input: str,
    debug: bool = False,
    user_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
    async_output: bool = False,
    include_tools_response: bool = False,
    images: Optional[List[str]] = None,
    files: Optional[List[str]] = None,
    data_source_folders: Optional[Dict[str, Any]] = None,
    data_source_files: Optional[Dict[str, Any]] = None,
    in_memory_messages: Optional[List[Dict[str, str]]] = None,
    current_date_time: Optional[str] = None,
    save_history: bool = True,
    additional_info: Optional[List[Any]] = None,
    prompt_variables: Optional[Dict[str, Any]] = None,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V2.value,
):
    """
    Run a pipeline against the given user input.

    Args:
        pipeline_id: The ID of the pipeline to execute.
        user_input: Input text to process.
        debug: Whether debug mode execution is enabled. Default is False.
        user_id: Optional ID of the user making the request (guid).
        conversation_id: Optional conversation ID (guid).
        async_output: Whether to stream the response. Default is False.
        include_tools_response: Whether to return the initial LLM tool result. Default is False.
        images: Optional list of images formatted as base64 strings.
        files: Optional list of files formatted as base64 strings.
        data_source_folders: Optional data source folders information.
        data_source_files: Optional data source files information.
        in_memory_messages: Optional list of in-memory messages, each with a role and message.
        current_date_time: Optional current date and time in ISO format.
        save_history: Whether to save the userInput and output to conversation history. Default is True.
        additional_info: Optional additional information.
        prompt_variables: Optional variables to be used in the prompt.
        correlation_id: Optional correlation ID for request tracing. If not provided,
            one will be generated automatically.
        api_version: API version to use. Default is `v2`.

    Returns:
        One of the following response models:

        - ``PipelineExecutionResponse`` when ``async_output`` is False and ``debug`` is False.
        - ``PipelineExecutionDebugResponse`` when ``async_output`` is False and ``debug`` is True.
        - ``PipelineExecutionV1StreamedResponse`` (holding a WebSocket URL) when
          ``async_output`` is True and ``api_version`` is ``v1``.
        - ``PipelineExecutionV2StreamedResponse`` wrapping the streamed output otherwise.

    Raises:
        AiriaAPIError: If the API request fails with details about the error.
        requests.RequestException: For other request-related errors.

    Example:
        >>> client = AiriaClient(api_key="your_api_key")
        >>> response = client.execute_pipeline(
        ...     pipeline_id="pipeline_123",
        ...     user_input="Tell me about quantum computing"
        ... )
        >>> print(response.result)
    """
    request_data = self._pre_execute_pipeline(
        pipeline_id=pipeline_id,
        user_input=user_input,
        debug=debug,
        user_id=user_id,
        conversation_id=conversation_id,
        async_output=async_output,
        include_tools_response=include_tools_response,
        images=images,
        files=files,
        data_source_folders=data_source_folders,
        data_source_files=data_source_files,
        in_memory_messages=in_memory_messages,
        current_date_time=current_date_time,
        save_history=save_history,
        additional_info=additional_info,
        prompt_variables=prompt_variables,
        correlation_id=correlation_id,
        api_version=api_version,
    )

    # Blocking execution: a single JSON body, parsed into the matching model.
    if not async_output:
        body = self._make_request("POST", request_data)
        model_cls = PipelineExecutionDebugResponse if debug else PipelineExecutionResponse
        return model_cls(**body)

    # v2 streams the pipeline output back over the same HTTP response.
    if api_version == ApiVersion.V2.value:
        chunks = self._make_request_stream("POST", request_data)
        return PipelineExecutionV2StreamedResponse(stream=chunks)

    initial_response = self._make_request("POST", request_data)

    # Any version other than v1 is wrapped as-is (existing behaviour kept;
    # presumably _pre_execute_pipeline rejects such versions — confirm).
    if api_version != ApiVersion.V1.value:
        return PipelineExecutionV2StreamedResponse(stream=initial_response)

    # v1 answers with a socket identifier that must be exchanged for a
    # WebSocket URL via a second request using the same correlation ID.
    generate_url = urljoin(
        self.base_url, f"{api_version}/StreamSocketConfig/GenerateUrl"
    )
    url_request = self._prepare_request(
        generate_url,
        {"socketIdentifier": initial_response},
        request_data.headers["X-Correlation-ID"],
    )
    socket_info = self._make_request("POST", url_request)
    return PipelineExecutionV1StreamedResponse(**socket_info)
|
airia/types/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .api_version import ApiVersion
|
|
2
|
+
from .pipeline_execution import (
|
|
3
|
+
PipelineExecutionDebugResponse,
|
|
4
|
+
PipelineExecutionResponse,
|
|
5
|
+
PipelineExecutionV1StreamedResponse,
|
|
6
|
+
PipelineExecutionV2AsyncStreamedResponse,
|
|
7
|
+
PipelineExecutionV2StreamedResponse,
|
|
8
|
+
)
|
|
9
|
+
from .request_data import RequestData
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"ApiVersion",
|
|
13
|
+
"PipelineExecutionDebugResponse",
|
|
14
|
+
"PipelineExecutionResponse",
|
|
15
|
+
"PipelineExecutionV1StreamedResponse",
|
|
16
|
+
"PipelineExecutionV2AsyncStreamedResponse",
|
|
17
|
+
"PipelineExecutionV2StreamedResponse",
|
|
18
|
+
"RequestData",
|
|
19
|
+
]
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Any, AsyncIterator, Dict, Iterator
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PipelineExecutionResponse(BaseModel):
    """Pipeline result returned by the sync client's ``execute_pipeline`` when
    neither ``debug`` nor ``async_output`` is set.

    Field names are camelCase because the model is built directly from the
    API's JSON body (``PipelineExecutionResponse(**resp)``).
    """

    # The pipeline's output text.
    result: str
    # Annotated as ``None``: only a null value validates here; the debug
    # variant of this response carries a dict instead.
    report: None
    # NOTE(review): presumably True when a backup pipeline produced the result — confirm with API docs.
    isBackupPipeline: bool
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class PipelineExecutionDebugResponse(BaseModel):
    """Pipeline result returned by the sync client's ``execute_pipeline`` when
    ``debug`` is set and ``async_output`` is not.

    Built directly from the API's JSON body, hence the camelCase field names.
    """

    # The pipeline's output text.
    result: str
    # Debug report as a free-form JSON object; its structure is not modelled here.
    report: Dict[str, Any]
    # NOTE(review): presumably True when a backup pipeline produced the result — confirm with API docs.
    isBackupPipeline: bool
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PipelineExecutionV1StreamedResponse(BaseModel):
    """Streaming handle for API v1: the body of the
    ``StreamSocketConfig/GenerateUrl`` call made by ``execute_pipeline`` when
    ``async_output`` is set with ``api_version="v1"``.
    """

    # URL of the WebSocket to connect to for the streamed output.
    webSocketUrl: str
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PipelineExecutionV2StreamedResponse(BaseModel):
    """Streaming handle for API v2 in the synchronous client: wraps the
    iterator of text chunks produced by the streamed HTTP response.
    """

    # Allows the non-pydantic iterator type below to be used as a field type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Text chunks of the pipeline output, consumed lazily by the caller.
    stream: Iterator[str]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class PipelineExecutionV2AsyncStreamedResponse(BaseModel):
    """Asynchronous counterpart of ``PipelineExecutionV2StreamedResponse``:
    wraps an async iterator of text chunks instead of a regular iterator.
    """

    # Allows the non-pydantic AsyncIterator type below to be used as a field type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Text chunks of the pipeline output, consumed with ``async for``.
    stream: AsyncIterator[str]
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
airia/__init__.py,sha256=T39gO8E5T5zxlw-JP78ruxOu7-LeKOJCJzz6t40kdQo,231
|
|
2
|
+
airia/exceptions.py,sha256=4Z55n-cRJrtTa5-pZBIK2oZD4-Z99aUtKx_kfTFYY5o,1146
|
|
3
|
+
airia/logs.py,sha256=17YZ4IuzOF0m5bgofj9-QYlJ2BYR2kRZbBVQfFSLFEk,5441
|
|
4
|
+
airia/client/__init__.py,sha256=6gSQ9bl7j79q1HPE0o5py3IRdkwWWuU_7J4h05Dd2o8,127
|
|
5
|
+
airia/client/async_client.py,sha256=Xd1GAGxQvXVh0-hH8vE4HOeaNY_Fawh5iRxeS26CInY,18173
|
|
6
|
+
airia/client/base_client.py,sha256=3oUcCWWq-DmPbiJke_HHPXm9UzDvD-MzrnoiKnBjcdk,6810
|
|
7
|
+
airia/client/sync_client.py,sha256=p_auqkfyP-9YuCoppl43SyMdKek30kA6M05QI0LYunE,17689
|
|
8
|
+
airia/types/__init__.py,sha256=OOFtJ0VIbO98px89u7cq645iL7CYDOel43g85IAjRRw,562
|
|
9
|
+
airia/types/api_version.py,sha256=Uzom6O2ZG92HN_Z2h-lTydmO2XYX9RVs4Yi4DJmXytE,255
|
|
10
|
+
airia/types/pipeline_execution.py,sha256=obp8KOz-ShNxEciF1YQCSpHXPoJUyEa_9uPv-VHf6YA,699
|
|
11
|
+
airia/types/request_data.py,sha256=RbVPWPRIYuT7FaolJKn19IVzuQKbLM4VhXM5r7UXTUY,186
|
|
12
|
+
airia-0.1.4.dist-info/licenses/LICENSE,sha256=R3ClUMMKPRItIcZ0svzyj2taZZnFYw568YDNzN9KQ1Q,1066
|
|
13
|
+
airia-0.1.4.dist-info/METADATA,sha256=VOWxozh4xZXAHUXH6_ZyCUTIU95LXxly3Yfs-ayi0sk,9966
|
|
14
|
+
airia-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
+
airia-0.1.4.dist-info/top_level.txt,sha256=qUQEKfs_hdOYTwjKj1JZbRhS5YeXDNaKQaVTrzabS6w,6
|
|
16
|
+
airia-0.1.4.dist-info/RECORD,,
|
airia-0.1.3.dist-info/RECORD
DELETED
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
airia/__init__.py,sha256=T39gO8E5T5zxlw-JP78ruxOu7-LeKOJCJzz6t40kdQo,231
|
|
2
|
-
airia/exceptions.py,sha256=4Z55n-cRJrtTa5-pZBIK2oZD4-Z99aUtKx_kfTFYY5o,1146
|
|
3
|
-
airia/logs.py,sha256=17YZ4IuzOF0m5bgofj9-QYlJ2BYR2kRZbBVQfFSLFEk,5441
|
|
4
|
-
airia-0.1.3.dist-info/licenses/LICENSE,sha256=R3ClUMMKPRItIcZ0svzyj2taZZnFYw568YDNzN9KQ1Q,1066
|
|
5
|
-
airia-0.1.3.dist-info/METADATA,sha256=9LEQFb6HGNqoTfhn6aX8elSCa5Eaogk43cj2j_pcx4o,9966
|
|
6
|
-
airia-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
7
|
-
airia-0.1.3.dist-info/top_level.txt,sha256=qUQEKfs_hdOYTwjKj1JZbRhS5YeXDNaKQaVTrzabS6w,6
|
|
8
|
-
airia-0.1.3.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|