airia 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airia/client/async_client.py +163 -29
- airia/client/base_client.py +32 -1
- airia/client/sync_client.py +140 -12
- airia/types/__init__.py +54 -2
- airia/types/api/get_pipeline_config.py +65 -0
- airia/types/{pipeline_execution.py → api/pipeline_execution.py} +4 -2
- airia/types/request_data.py +2 -2
- airia/types/sse_messages.py +269 -0
- airia/utils/sse_parser.py +91 -0
- {airia-0.1.4.dist-info → airia-0.1.6.dist-info}/METADATA +126 -19
- airia-0.1.6.dist-info/RECORD +19 -0
- airia-0.1.4.dist-info/RECORD +0 -16
- {airia-0.1.4.dist-info → airia-0.1.6.dist-info}/WHEEL +0 -0
- {airia-0.1.4.dist-info → airia-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {airia-0.1.4.dist-info → airia-0.1.6.dist-info}/top_level.txt +0 -0
airia/client/async_client.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import weakref
|
|
1
3
|
from typing import Any, AsyncIterator, Dict, List, Literal, Optional, overload
|
|
2
4
|
from urllib.parse import urljoin
|
|
3
5
|
|
|
@@ -7,12 +9,14 @@ import loguru
|
|
|
7
9
|
from ..exceptions import AiriaAPIError
|
|
8
10
|
from ..types import (
|
|
9
11
|
ApiVersion,
|
|
12
|
+
GetPipelineConfigResponse,
|
|
10
13
|
PipelineExecutionDebugResponse,
|
|
11
14
|
PipelineExecutionResponse,
|
|
12
15
|
PipelineExecutionV1StreamedResponse,
|
|
13
16
|
PipelineExecutionV2AsyncStreamedResponse,
|
|
14
17
|
RequestData,
|
|
15
18
|
)
|
|
19
|
+
from ..utils.sse_parser import async_parse_sse_stream_chunked
|
|
16
20
|
from .base_client import AiriaBaseClient
|
|
17
21
|
|
|
18
22
|
|
|
@@ -21,6 +25,7 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
21
25
|
|
|
22
26
|
def __init__(
|
|
23
27
|
self,
|
|
28
|
+
base_url: str = "https://api.airia.ai/",
|
|
24
29
|
api_key: Optional[str] = None,
|
|
25
30
|
timeout: float = 30.0,
|
|
26
31
|
log_requests: bool = False,
|
|
@@ -30,20 +35,52 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
30
35
|
Initialize the asynchronous Airia API client.
|
|
31
36
|
|
|
32
37
|
Args:
|
|
38
|
+
base_url: Base URL of the Airia API.
|
|
33
39
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
34
40
|
timeout: Request timeout in seconds.
|
|
35
41
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
36
42
|
custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
|
|
37
43
|
"""
|
|
38
|
-
super().__init__(
|
|
44
|
+
super().__init__(
|
|
45
|
+
base_url=base_url,
|
|
46
|
+
api_key=api_key,
|
|
47
|
+
timeout=timeout,
|
|
48
|
+
log_requests=log_requests,
|
|
49
|
+
custom_logger=custom_logger,
|
|
50
|
+
)
|
|
39
51
|
|
|
40
52
|
# Session will be initialized in __aenter__
|
|
41
|
-
self.session = None
|
|
42
53
|
self.headers = {"Content-Type": "application/json"}
|
|
43
|
-
|
|
54
|
+
self.session = aiohttp.ClientSession(headers=self.headers)
|
|
55
|
+
|
|
56
|
+
# Register finalizer to clean up session when client is garbage collected
|
|
57
|
+
self._finalizer = weakref.finalize(self, self._cleanup_session, self.session)
|
|
58
|
+
|
|
59
|
+
@staticmethod
def _cleanup_session(session: aiohttp.ClientSession):
    """
    Close ``session`` when the owning client is garbage collected.

    This is registered with ``weakref.finalize`` in ``__init__``, so it runs
    from the garbage collector or at interpreter exit rather than from user
    code. For that reason it never installs or replaces the thread's event
    loop. Any loop it has to create is private and closed afterwards.

    Args:
        session: The aiohttp session created by the client. ``None`` or an
            already-closed session is ignored.
    """
    if session is None or session.closed:
        return

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is not None:
        # We are inside a running loop (e.g. the client was collected in the
        # middle of a coroutine). Blocking is impossible here, so schedule
        # the close on that loop instead.
        running_loop.create_task(session.close())
        return

    # No loop is running in this thread. Drive the close on a temporary loop
    # and dispose of it, so no loop is leaked and the thread's configured
    # loop is left untouched.
    temp_loop = asyncio.new_event_loop()
    try:
        temp_loop.run_until_complete(session.close())
    finally:
        temp_loop.close()
|
|
78
|
+
|
|
44
79
|
@classmethod
|
|
45
80
|
def with_openai_gateway(
|
|
46
81
|
cls,
|
|
82
|
+
base_url: str = "https://api.airia.ai/",
|
|
83
|
+
gateway_url: str = "https://gateway.airia.ai/openai/v1",
|
|
47
84
|
api_key: Optional[str] = None,
|
|
48
85
|
timeout: float = 30.0,
|
|
49
86
|
log_requests: bool = False,
|
|
@@ -54,6 +91,8 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
54
91
|
Initialize the asynchronous Airia API client with AsyncOpenAI gateway capabilities.
|
|
55
92
|
|
|
56
93
|
Args:
|
|
94
|
+
base_url: Base URL of the Airia API.
|
|
95
|
+
gateway_url: Base URL of the Airia Gateway API.
|
|
57
96
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
58
97
|
timeout: Request timeout in seconds.
|
|
59
98
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
@@ -65,15 +104,17 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
65
104
|
api_key = cls._get_api_key(api_key)
|
|
66
105
|
cls.openai = AsyncOpenAI(
|
|
67
106
|
api_key=api_key,
|
|
68
|
-
base_url=
|
|
107
|
+
base_url=gateway_url,
|
|
69
108
|
**kwargs,
|
|
70
109
|
)
|
|
71
110
|
|
|
72
|
-
return cls(api_key, timeout, log_requests, custom_logger)
|
|
111
|
+
return cls(base_url, api_key, timeout, log_requests, custom_logger)
|
|
73
112
|
|
|
74
113
|
@classmethod
|
|
75
114
|
def with_anthropic_gateway(
|
|
76
115
|
cls,
|
|
116
|
+
base_url: str = "https://api.airia.ai/",
|
|
117
|
+
gateway_url: str = "https://gateway.airia.ai/anthropic",
|
|
77
118
|
api_key: Optional[str] = None,
|
|
78
119
|
timeout: float = 30.0,
|
|
79
120
|
log_requests: bool = False,
|
|
@@ -84,6 +125,8 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
84
125
|
Initialize the asynchronous Airia API client with AsyncAnthropic gateway capabilities.
|
|
85
126
|
|
|
86
127
|
Args:
|
|
128
|
+
base_url: Base URL of the Airia API.
|
|
129
|
+
gateway_url: Base URL of the Airia Gateway API.
|
|
87
130
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
88
131
|
timeout: Request timeout in seconds.
|
|
89
132
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
@@ -95,29 +138,11 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
95
138
|
api_key = cls._get_api_key(api_key)
|
|
96
139
|
cls.anthropic = AsyncAnthropic(
|
|
97
140
|
api_key=api_key,
|
|
98
|
-
base_url=
|
|
141
|
+
base_url=gateway_url,
|
|
99
142
|
**kwargs,
|
|
100
143
|
)
|
|
101
144
|
|
|
102
|
-
return cls(api_key, timeout, log_requests, custom_logger)
|
|
103
|
-
|
|
104
|
-
async def __aenter__(self):
|
|
105
|
-
"""Async context manager entry point."""
|
|
106
|
-
self.session = aiohttp.ClientSession(headers=self.headers)
|
|
107
|
-
return self
|
|
108
|
-
|
|
109
|
-
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
110
|
-
"""Async context manager exit point."""
|
|
111
|
-
if self.session:
|
|
112
|
-
await self.session.close()
|
|
113
|
-
self.session = None
|
|
114
|
-
|
|
115
|
-
def _check_session(self):
|
|
116
|
-
"""Check if the client session is initialized."""
|
|
117
|
-
if not self.session:
|
|
118
|
-
raise RuntimeError(
|
|
119
|
-
"Client session not initialized. Use async with AiriaAsyncClient() as client: ..."
|
|
120
|
-
)
|
|
145
|
+
return cls(base_url, api_key, timeout, log_requests, custom_logger)
|
|
121
146
|
|
|
122
147
|
def _handle_exception(
|
|
123
148
|
self, e: aiohttp.ClientResponseError, url: str, correlation_id: str
|
|
@@ -241,8 +266,10 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
241
266
|
response.raise_for_status()
|
|
242
267
|
|
|
243
268
|
# Yields the response content as a stream if streaming
|
|
244
|
-
async for
|
|
245
|
-
|
|
269
|
+
async for message in async_parse_sse_stream_chunked(
|
|
270
|
+
response.content.iter_any()
|
|
271
|
+
):
|
|
272
|
+
yield message
|
|
246
273
|
|
|
247
274
|
except aiohttp.ClientResponseError as e:
|
|
248
275
|
self._handle_exception(e, request_data.url, request_data.correlation_id)
|
|
@@ -399,8 +426,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
399
426
|
... )
|
|
400
427
|
>>> print(response.result)
|
|
401
428
|
"""
|
|
402
|
-
self._check_session()
|
|
403
|
-
|
|
404
429
|
request_data = self._pre_execute_pipeline(
|
|
405
430
|
pipeline_id=pipeline_id,
|
|
406
431
|
user_input=user_input,
|
|
@@ -446,3 +471,112 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
446
471
|
return PipelineExecutionV1StreamedResponse(**resp)
|
|
447
472
|
|
|
448
473
|
return PipelineExecutionV2AsyncStreamedResponse(stream=resp)
|
|
474
|
+
|
|
475
|
+
async def get_active_pipelines_ids(
    self,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> List[str]:
    """
    Retrieve a list of active pipeline IDs.

    Sends a GET request for the pipeline configurations and extracts
    ``items[*].activeVersion.pipelineId`` from the response. The IDs can be
    used with other methods like execute_pipeline() or get_pipeline_config().

    Args:
        correlation_id (str, optional): Unique identifier for request tracing
            and logging. If not provided, a new UUID will be automatically
            generated.
        api_version (str, optional): API version to use for the request.
            Must be one of the supported versions. Defaults to "v1".

    Returns:
        List[str]: Active pipeline IDs, in the order returned by the API.
        Returns an empty list when ``items`` is missing, null, or empty.
        Items without an ``activeVersion`` are skipped.

    Raises:
        ValueError: If the provided API version is not supported.
        AiriaAPIError: If the API request fails, including network errors,
            authentication failures, or server errors.

    Example:
        >>> client = AiriaAsyncClient(api_key="your_api_key")
        >>> pipeline_ids = await client.get_active_pipelines_ids()
        >>> print(f"Found {len(pipeline_ids)} active pipelines")
        >>> for pipeline_id in pipeline_ids:
        ...     print(f"Pipeline ID: {pipeline_id}")
    """
    request_data = self._pre_get_active_pipelines_ids(
        correlation_id=correlation_id, api_version=api_version
    )
    resp = await self._make_request("GET", request_data)

    # ``or []`` treats a missing key, an explicit null and an empty list alike.
    items = resp.get("items") or []

    # An entry with a missing/null ``activeVersion`` has no active pipeline ID
    # to report; skip it instead of failing with KeyError/TypeError.
    return [
        item["activeVersion"]["pipelineId"]
        for item in items
        if item.get("activeVersion")
    ]
|
|
521
|
+
|
|
522
|
+
async def get_pipeline_config(
    self,
    pipeline_id: str,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> GetPipelineConfigResponse:
    """
    Retrieve configuration details for a specific pipeline.

    This method fetches comprehensive information about a pipeline including its
    deployment details, execution statistics, version information, and metadata.

    Args:
        pipeline_id (str): The unique identifier of the pipeline to retrieve
            configuration for.
        correlation_id (str, optional): A unique identifier for request tracing
            and logging. If not provided, one will be automatically generated.
        api_version (str, optional): The API version to use for the request.
            Defaults to "v1". Valid versions are defined in ApiVersion enum.

    Returns:
        GetPipelineConfigResponse: A response object containing the pipeline
            configuration.

    Raises:
        ValueError: If the provided api_version is not valid.
        AiriaAPIError: If the API request fails, including cases where:
            - The pipeline_id doesn't exist (404)
            - Authentication fails (401)
            - Access is forbidden (403)
            - Server errors (5xx)

    Example:
        ```python
        client = AiriaAsyncClient(api_key="your_api_key")

        # This is a coroutine: it must be awaited inside an async function.
        config = await client.get_pipeline_config(
            pipeline_id="your_pipeline_id"
        )

        print(f"Pipeline: {config.deployment_name}")
        print(f"Description: {config.deployment_description}")
        print(f"Successful executions: {config.execution_stats.success_count}")
        print(f"Active version: {config.active_version.version_number}")
        ```

    Note:
        This method only retrieves configuration information and does not
        execute the pipeline. Use execute_pipeline() to run the pipeline.
    """
    request_data = self._pre_get_pipeline_config(
        pipeline_id=pipeline_id,
        correlation_id=correlation_id,
        api_version=api_version,
    )
    resp = await self._make_request("GET", request_data)

    # The raw JSON uses camelCase keys, which the model maps via field aliases.
    return GetPipelineConfigResponse(**resp)
|
airia/client/base_client.py
CHANGED
|
@@ -11,11 +11,13 @@ from ..types import ApiVersion, RequestData
|
|
|
11
11
|
|
|
12
12
|
class AiriaBaseClient:
|
|
13
13
|
"""Base client containing shared functionality for Airia API clients."""
|
|
14
|
+
|
|
14
15
|
openai = None
|
|
15
16
|
anthropic = None
|
|
16
17
|
|
|
17
18
|
def __init__(
|
|
18
19
|
self,
|
|
20
|
+
base_url: str = "https://api.airia.ai/",
|
|
19
21
|
api_key: Optional[str] = None,
|
|
20
22
|
timeout: float = 30.0,
|
|
21
23
|
log_requests: bool = False,
|
|
@@ -34,7 +36,7 @@ class AiriaBaseClient:
|
|
|
34
36
|
self.api_key = self.__class__._get_api_key(api_key)
|
|
35
37
|
|
|
36
38
|
# Store configuration
|
|
37
|
-
self.base_url =
|
|
39
|
+
self.base_url = base_url
|
|
38
40
|
self.timeout = timeout
|
|
39
41
|
self.log_requests = log_requests
|
|
40
42
|
|
|
@@ -185,3 +187,32 @@ class AiriaBaseClient:
|
|
|
185
187
|
request_data = self._prepare_request(url, payload, correlation_id)
|
|
186
188
|
|
|
187
189
|
return request_data
|
|
190
|
+
|
|
191
|
+
def _pre_get_active_pipelines_ids(
    self,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> RequestData:
    """
    Build the request used to list pipeline configurations.

    Args:
        correlation_id: Optional tracing identifier forwarded to
            ``_prepare_request``.
        api_version: API version path segment; must be one of
            ``ApiVersion.as_list()``.

    Returns:
        RequestData targeting ``{base_url}/{api_version}/PipelinesConfig``.
        No payload argument is passed to ``_prepare_request``.

    Raises:
        ValueError: If ``api_version`` is not a supported version.
    """
    if api_version not in ApiVersion.as_list():
        raise ValueError(
            f"Invalid API version: {api_version}. Valid versions are: {', '.join(ApiVersion.as_list())}"
        )

    # urljoin replaces the last path segment of a base URL that does not end
    # with "/", so a custom base_url such as "https://host/airia" would lose
    # "airia". Force a trailing slash so the base path is always preserved.
    base_url = self.base_url if self.base_url.endswith("/") else f"{self.base_url}/"
    url = urljoin(base_url, f"{api_version}/PipelinesConfig")

    return self._prepare_request(url, correlation_id=correlation_id)
|
|
204
|
+
|
|
205
|
+
def _pre_get_pipeline_config(
    self,
    pipeline_id: str,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> RequestData:
    """
    Build the request used to fetch a single pipeline's configuration.

    Args:
        pipeline_id: Identifier of the pipeline. It is percent-encoded so that
            characters such as "/" or "?" cannot alter the request path.
        correlation_id: Optional tracing identifier forwarded to
            ``_prepare_request``.
        api_version: API version path segment; must be one of
            ``ApiVersion.as_list()``.

    Returns:
        RequestData targeting
        ``{base_url}/{api_version}/PipelinesConfig/{pipeline_id}``.

    Raises:
        ValueError: If ``api_version`` is not supported or ``pipeline_id``
            is empty.
    """
    from urllib.parse import quote

    if api_version not in ApiVersion.as_list():
        raise ValueError(
            f"Invalid API version: {api_version}. Valid versions are: {', '.join(ApiVersion.as_list())}"
        )
    # An empty id would produce ".../PipelinesConfig/", which addresses the
    # list endpoint instead of a single pipeline.
    if not pipeline_id:
        raise ValueError("pipeline_id must be a non-empty string")

    # Keep any path component of a custom base_url (urljoin drops the last
    # segment when the base does not end with "/").
    base_url = self.base_url if self.base_url.endswith("/") else f"{self.base_url}/"
    url = urljoin(
        base_url,
        f"{api_version}/PipelinesConfig/{quote(pipeline_id, safe='')}",
    )

    return self._prepare_request(url, correlation_id=correlation_id)
|
airia/client/sync_client.py
CHANGED
|
@@ -7,12 +7,14 @@ import requests
|
|
|
7
7
|
from ..exceptions import AiriaAPIError
|
|
8
8
|
from ..types import (
|
|
9
9
|
ApiVersion,
|
|
10
|
+
GetPipelineConfigResponse,
|
|
10
11
|
PipelineExecutionDebugResponse,
|
|
11
12
|
PipelineExecutionResponse,
|
|
12
13
|
PipelineExecutionV1StreamedResponse,
|
|
13
14
|
PipelineExecutionV2StreamedResponse,
|
|
14
15
|
RequestData,
|
|
15
16
|
)
|
|
17
|
+
from ..utils.sse_parser import parse_sse_stream_chunked
|
|
16
18
|
from .base_client import AiriaBaseClient
|
|
17
19
|
|
|
18
20
|
|
|
@@ -21,6 +23,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
21
23
|
|
|
22
24
|
def __init__(
|
|
23
25
|
self,
|
|
26
|
+
base_url: str = "https://api.airia.ai/",
|
|
24
27
|
api_key: Optional[str] = None,
|
|
25
28
|
timeout: float = 30.0,
|
|
26
29
|
log_requests: bool = False,
|
|
@@ -30,12 +33,19 @@ class AiriaClient(AiriaBaseClient):
|
|
|
30
33
|
Initialize the synchronous Airia API client.
|
|
31
34
|
|
|
32
35
|
Args:
|
|
36
|
+
base_url: Base URL of the Airia API.
|
|
33
37
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
34
38
|
timeout: Request timeout in seconds.
|
|
35
39
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
36
40
|
custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
|
|
37
41
|
"""
|
|
38
|
-
super().__init__(
|
|
42
|
+
super().__init__(
|
|
43
|
+
base_url=base_url,
|
|
44
|
+
api_key=api_key,
|
|
45
|
+
timeout=timeout,
|
|
46
|
+
log_requests=log_requests,
|
|
47
|
+
custom_logger=custom_logger,
|
|
48
|
+
)
|
|
39
49
|
|
|
40
50
|
# Initialize session for synchronous requests
|
|
41
51
|
self.session = requests.Session()
|
|
@@ -44,6 +54,8 @@ class AiriaClient(AiriaBaseClient):
|
|
|
44
54
|
@classmethod
|
|
45
55
|
def with_openai_gateway(
|
|
46
56
|
cls,
|
|
57
|
+
base_url: str = "https://api.airia.ai/",
|
|
58
|
+
gateway_url: str = "https://gateway.airia.ai/openai/v1",
|
|
47
59
|
api_key: Optional[str] = None,
|
|
48
60
|
timeout: float = 30.0,
|
|
49
61
|
log_requests: bool = False,
|
|
@@ -54,6 +66,8 @@ class AiriaClient(AiriaBaseClient):
|
|
|
54
66
|
Initialize the synchronous Airia API client with OpenAI gateway capabilities.
|
|
55
67
|
|
|
56
68
|
Args:
|
|
69
|
+
base_url: Base URL of the Airia API.
|
|
70
|
+
gateway_url: Base URL of the Airia Gateway API.
|
|
57
71
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
58
72
|
timeout: Request timeout in seconds.
|
|
59
73
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
@@ -65,15 +79,17 @@ class AiriaClient(AiriaBaseClient):
|
|
|
65
79
|
api_key = cls._get_api_key(api_key)
|
|
66
80
|
cls.openai = OpenAI(
|
|
67
81
|
api_key=api_key,
|
|
68
|
-
base_url=
|
|
82
|
+
base_url=gateway_url,
|
|
69
83
|
**kwargs,
|
|
70
84
|
)
|
|
71
85
|
|
|
72
|
-
return cls(api_key, timeout, log_requests, custom_logger)
|
|
86
|
+
return cls(base_url, api_key, timeout, log_requests, custom_logger)
|
|
73
87
|
|
|
74
88
|
@classmethod
|
|
75
89
|
def with_anthropic_gateway(
|
|
76
90
|
cls,
|
|
91
|
+
base_url: str = "https://api.airia.ai/",
|
|
92
|
+
gateway_url: str = "https://gateway.airia.ai/anthropic",
|
|
77
93
|
api_key: Optional[str] = None,
|
|
78
94
|
timeout: float = 30.0,
|
|
79
95
|
log_requests: bool = False,
|
|
@@ -84,6 +100,8 @@ class AiriaClient(AiriaBaseClient):
|
|
|
84
100
|
Initialize the synchronous Airia API client with Anthropic gateway capabilities.
|
|
85
101
|
|
|
86
102
|
Args:
|
|
103
|
+
base_url: Base URL of the Airia API.
|
|
104
|
+
gateway_url: Base URL of the Airia Gateway API.
|
|
87
105
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
88
106
|
timeout: Request timeout in seconds.
|
|
89
107
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
@@ -95,11 +113,11 @@ class AiriaClient(AiriaBaseClient):
|
|
|
95
113
|
api_key = cls._get_api_key(api_key)
|
|
96
114
|
cls.anthropic = Anthropic(
|
|
97
115
|
api_key=api_key,
|
|
98
|
-
base_url=
|
|
116
|
+
base_url=gateway_url,
|
|
99
117
|
**kwargs,
|
|
100
118
|
)
|
|
101
119
|
|
|
102
|
-
return cls(api_key, timeout, log_requests, custom_logger)
|
|
120
|
+
return cls(base_url, api_key, timeout, log_requests, custom_logger)
|
|
103
121
|
|
|
104
122
|
def _handle_exception(self, e: requests.HTTPError, url: str, correlation_id: str):
|
|
105
123
|
# Log the error response if enabled
|
|
@@ -134,7 +152,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
134
152
|
status_code=e.response.status_code, message=sanitized_message
|
|
135
153
|
) from e
|
|
136
154
|
|
|
137
|
-
def _make_request(self, method: str, request_data: RequestData):
|
|
155
|
+
def _make_request(self, method: str, request_data: RequestData) -> Dict[str, Any]:
|
|
138
156
|
"""
|
|
139
157
|
Makes a synchronous HTTP request to the Airia API.
|
|
140
158
|
|
|
@@ -231,8 +249,8 @@ class AiriaClient(AiriaBaseClient):
|
|
|
231
249
|
response.raise_for_status()
|
|
232
250
|
|
|
233
251
|
# Yields the response content as a stream
|
|
234
|
-
for
|
|
235
|
-
yield
|
|
252
|
+
for message in parse_sse_stream_chunked(response.iter_content()):
|
|
253
|
+
yield message
|
|
236
254
|
|
|
237
255
|
except requests.HTTPError as e:
|
|
238
256
|
self._handle_exception(e, request_data.url, request_data.correlation_id)
|
|
@@ -410,10 +428,11 @@ class AiriaClient(AiriaBaseClient):
|
|
|
410
428
|
api_version=api_version,
|
|
411
429
|
)
|
|
412
430
|
stream = async_output and api_version == ApiVersion.V2.value
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
431
|
+
resp = (
|
|
432
|
+
self._make_request_stream("POST", request_data)
|
|
433
|
+
if stream
|
|
434
|
+
else self._make_request("POST", request_data)
|
|
435
|
+
)
|
|
417
436
|
|
|
418
437
|
if not async_output:
|
|
419
438
|
if not debug:
|
|
@@ -434,3 +453,112 @@ class AiriaClient(AiriaBaseClient):
|
|
|
434
453
|
return PipelineExecutionV1StreamedResponse(**resp)
|
|
435
454
|
|
|
436
455
|
return PipelineExecutionV2StreamedResponse(stream=resp)
|
|
456
|
+
|
|
457
|
+
def get_active_pipelines_ids(
    self,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> List[str]:
    """
    Retrieve a list of active pipeline IDs.

    Sends a GET request for the pipeline configurations and extracts
    ``items[*].activeVersion.pipelineId`` from the response. The IDs can be
    used with other methods like execute_pipeline() or get_pipeline_config().

    Args:
        correlation_id (str, optional): Unique identifier for request tracing
            and logging. If not provided, a new UUID will be automatically
            generated.
        api_version (str, optional): API version to use for the request.
            Must be one of the supported versions. Defaults to "v1".

    Returns:
        List[str]: Active pipeline IDs, in the order returned by the API.
        Returns an empty list when ``items`` is missing, null, or empty.
        Items without an ``activeVersion`` are skipped.

    Raises:
        ValueError: If the provided API version is not supported.
        AiriaAPIError: If the API request fails, including network errors,
            authentication failures, or server errors.

    Example:
        >>> client = AiriaClient(api_key="your_api_key")
        >>> pipeline_ids = client.get_active_pipelines_ids()
        >>> print(f"Found {len(pipeline_ids)} active pipelines")
        >>> for pipeline_id in pipeline_ids:
        ...     print(f"Pipeline ID: {pipeline_id}")
    """
    request_data = self._pre_get_active_pipelines_ids(
        correlation_id=correlation_id, api_version=api_version
    )
    resp = self._make_request("GET", request_data)

    # ``or []`` treats a missing key, an explicit null and an empty list alike.
    items = resp.get("items") or []

    # An entry with a missing/null ``activeVersion`` has no active pipeline ID
    # to report; skip it instead of failing with KeyError/TypeError.
    return [
        item["activeVersion"]["pipelineId"]
        for item in items
        if item.get("activeVersion")
    ]
|
|
503
|
+
|
|
504
|
+
def get_pipeline_config(
    self,
    pipeline_id: str,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> GetPipelineConfigResponse:
    """
    Retrieve configuration details for a specific pipeline.

    This method fetches comprehensive information about a pipeline including its
    deployment details, execution statistics, version information, and metadata.

    Args:
        pipeline_id (str): The unique identifier of the pipeline to retrieve
            configuration for.
        correlation_id (str, optional): A unique identifier for request tracing
            and logging. If not provided, one will be automatically generated.
        api_version (str, optional): The API version to use for the request.
            Defaults to "v1". Valid versions are defined in ApiVersion enum.

    Returns:
        GetPipelineConfigResponse: A response object containing the pipeline
            configuration.

    Raises:
        ValueError: If the provided api_version is not valid.
        AiriaAPIError: If the API request fails, including cases where:
            - The pipeline_id doesn't exist (404)
            - Authentication fails (401)
            - Access is forbidden (403)
            - Server errors (5xx)

    Example:
        ```python
        from airia import AiriaClient

        client = AiriaClient(api_key="your_api_key")

        # Get pipeline configuration
        config = client.get_pipeline_config(
            pipeline_id="your_pipeline_id"
        )

        print(f"Pipeline: {config.deployment_name}")
        print(f"Description: {config.deployment_description}")
        print(f"Successful executions: {config.execution_stats.success_count}")
        print(f"Active version: {config.active_version.version_number}")
        ```

    Note:
        This method only retrieves configuration information and does not
        execute the pipeline. Use execute_pipeline() to run the pipeline.
    """
    request_data = self._pre_get_pipeline_config(
        pipeline_id=pipeline_id,
        correlation_id=correlation_id,
        api_version=api_version,
    )
    resp = self._make_request("GET", request_data)

    # The raw JSON uses camelCase keys, which the model maps via field aliases.
    return GetPipelineConfigResponse(**resp)
|
airia/types/__init__.py
CHANGED
|
@@ -1,12 +1,39 @@
|
|
|
1
|
-
from .
|
|
2
|
-
from .pipeline_execution import (
|
|
1
|
+
from .api.get_pipeline_config import GetPipelineConfigResponse
|
|
2
|
+
from .api.pipeline_execution import (
|
|
3
3
|
PipelineExecutionDebugResponse,
|
|
4
4
|
PipelineExecutionResponse,
|
|
5
5
|
PipelineExecutionV1StreamedResponse,
|
|
6
6
|
PipelineExecutionV2AsyncStreamedResponse,
|
|
7
7
|
PipelineExecutionV2StreamedResponse,
|
|
8
8
|
)
|
|
9
|
+
from .api_version import ApiVersion
|
|
9
10
|
from .request_data import RequestData
|
|
11
|
+
from .sse_messages import (
|
|
12
|
+
AgentAgentCardMessage,
|
|
13
|
+
AgentAgentCardStreamEndMessage,
|
|
14
|
+
AgentAgentCardStreamErrorMessage,
|
|
15
|
+
AgentAgentCardStreamFragmentMessage,
|
|
16
|
+
AgentAgentCardStreamStartMessage,
|
|
17
|
+
AgentDatasearchMessage,
|
|
18
|
+
AgentEndMessage,
|
|
19
|
+
AgentInvocationMessage,
|
|
20
|
+
AgentModelMessage,
|
|
21
|
+
AgentModelStreamEndMessage,
|
|
22
|
+
AgentModelStreamErrorMessage,
|
|
23
|
+
AgentModelStreamFragmentMessage,
|
|
24
|
+
AgentModelStreamStartMessage,
|
|
25
|
+
AgentModelStreamUsageMessage,
|
|
26
|
+
AgentOutputMessage,
|
|
27
|
+
AgentPingMessage,
|
|
28
|
+
AgentPythonCodeMessage,
|
|
29
|
+
AgentStartMessage,
|
|
30
|
+
AgentStepEndMessage,
|
|
31
|
+
AgentStepHaltMessage,
|
|
32
|
+
AgentStepStartMessage,
|
|
33
|
+
AgentToolActionMessage,
|
|
34
|
+
AgentToolRequestMessage,
|
|
35
|
+
AgentToolResponseMessage,
|
|
36
|
+
)
|
|
10
37
|
|
|
11
38
|
__all__ = [
|
|
12
39
|
"ApiVersion",
|
|
@@ -15,5 +42,30 @@ __all__ = [
|
|
|
15
42
|
"PipelineExecutionV1StreamedResponse",
|
|
16
43
|
"PipelineExecutionV2AsyncStreamedResponse",
|
|
17
44
|
"PipelineExecutionV2StreamedResponse",
|
|
45
|
+
"GetPipelineConfigResponse",
|
|
18
46
|
"RequestData",
|
|
47
|
+
"AgentPingMessage",
|
|
48
|
+
"AgentStartMessage",
|
|
49
|
+
"AgentEndMessage",
|
|
50
|
+
"AgentStepStartMessage",
|
|
51
|
+
"AgentStepHaltMessage",
|
|
52
|
+
"AgentStepEndMessage",
|
|
53
|
+
"AgentOutputMessage",
|
|
54
|
+
"AgentAgentCardMessage",
|
|
55
|
+
"AgentDatasearchMessage",
|
|
56
|
+
"AgentInvocationMessage",
|
|
57
|
+
"AgentModelMessage",
|
|
58
|
+
"AgentPythonCodeMessage",
|
|
59
|
+
"AgentToolActionMessage",
|
|
60
|
+
"AgentModelStreamStartMessage",
|
|
61
|
+
"AgentModelStreamEndMessage",
|
|
62
|
+
"AgentModelStreamErrorMessage",
|
|
63
|
+
"AgentModelStreamUsageMessage",
|
|
64
|
+
"AgentModelStreamFragmentMessage",
|
|
65
|
+
"AgentAgentCardStreamStartMessage",
|
|
66
|
+
"AgentAgentCardStreamErrorMessage",
|
|
67
|
+
"AgentAgentCardStreamFragmentMessage",
|
|
68
|
+
"AgentAgentCardStreamEndMessage",
|
|
69
|
+
"AgentToolRequestMessage",
|
|
70
|
+
"AgentToolResponseMessage",
|
|
19
71
|
]
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Any, Dict, List, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Version(BaseModel):
    """
    One version of a pipeline.

    Used for the entries of ``GetPipelineConfigResponse.versions`` and for
    ``GetPipelineConfigResponse.active_version``. Aliases map the API's
    camelCase keys onto snake_case attributes.
    """

    # Identity of the pipeline this version belongs to, and its numbering.
    pipeline_id: str = Field(..., alias="pipelineId")
    major_version: int = Field(..., alias="majorVersion")
    minor_version: int = Field(..., alias="minorVersion")
    version_number: str = Field(..., alias="versionNumber")

    # Version flags.
    is_draft_version: bool = Field(..., alias="isDraftVersion")
    is_latest: bool = Field(..., alias="isLatest")

    # Step definitions are kept as untyped dicts; the key may be absent.
    steps: Optional[List[Dict[str, Any]]] = Field(default=None)
    alignment: str

    # Common record metadata.
    id: str
    tenant_id: str = Field(..., alias="tenantId")
    project_id: str = Field(..., alias="projectId")
    created_at: datetime = Field(..., alias="createdAt")
    updated_at: datetime = Field(..., alias="updatedAt")
    user_id: str = Field(..., alias="userId")
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ExecutionStats(BaseModel):
    """Execution counts reported under ``executionStats`` of a pipeline config."""

    # Both fields are required; aliases match the API's camelCase keys.
    success_count: int = Field(..., alias="successCount")
    failure_count: int = Field(..., alias="failureCount")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class GetPipelineConfigResponse(BaseModel):
    """
    Pipeline configuration returned for a single pipeline.

    The clients build it from the raw response of
    ``GET {api_version}/PipelinesConfig/{pipeline_id}``. Field aliases map the
    API's camelCase JSON keys onto snake_case attributes.

    Every ``Optional`` field defaults to ``None``. Under pydantic v2, an
    ``Optional[...]`` annotation without a default is still a required field,
    so a response omitting such a key would otherwise fail validation.
    """

    # Deployment information (may be absent for undeployed pipelines).
    deployment_id: Optional[str] = Field(default=None, alias="deploymentId")
    deployment_name: Optional[str] = Field(default=None, alias="deploymentName")
    deployment_description: Optional[str] = Field(
        default=None, alias="deploymentDescription"
    )
    user_keys: Dict[str, Any] = Field(alias="userKeys")
    group_keys: Dict[str, Any] = Field(alias="groupKeys")
    agent_icon: Optional[str] = Field(default=None, alias="agentIcon")
    external: bool
    active_version_id: str = Field(alias="activeVersionId")
    name: str
    execution_name: str = Field(alias="executionName")
    description: str
    video_link: Optional[str] = Field(default=None, alias="videoLink")
    agent_icon_id: Optional[str] = Field(default=None, alias="agentIconId")
    versions: List[Version]
    execution_stats: ExecutionStats = Field(alias="executionStats")
    industry: Optional[str] = None
    sub_industries: List[str] = Field(alias="subIndustries")
    agent_details: Dict[str, Any] = Field(alias="agentDetails")
    agent_details_tags: List[str] = Field(alias="agentDetailsTags")
    active_version: Version = Field(alias="activeVersion")
    backup_pipeline_id: Optional[str] = Field(default=None, alias="backupPipelineId")
    deployment: Optional[Any] = None
    library_agent_id: Optional[str] = Field(default=None, alias="libraryAgentId")
    library_imported_hash: Optional[str] = Field(
        default=None, alias="libraryImportedHash"
    )
    library_imported_version: Optional[str] = Field(
        default=None, alias="libraryImportedVersion"
    )
    is_deleted: Optional[bool] = Field(default=None, alias="isDeleted")
    agent_trigger: Optional[Any] = Field(default=None, alias="agentTrigger")
    api_key_id: Optional[str] = Field(default=None, alias="apiKeyId")
    is_seeded: bool = Field(alias="isSeeded")
    behaviours: List[Any]

    # Common record metadata.
    id: str
    tenant_id: str = Field(alias="tenantId")
    project_id: str = Field(alias="projectId")
    created_at: datetime = Field(alias="createdAt")
    updated_at: datetime = Field(alias="updatedAt")
    user_id: str = Field(alias="userId")
|
|
@@ -2,6 +2,8 @@ from typing import Any, AsyncIterator, Dict, Iterator
|
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel, ConfigDict
|
|
4
4
|
|
|
5
|
+
from ..sse_messages import SSEMessage
|
|
6
|
+
|
|
5
7
|
|
|
6
8
|
class PipelineExecutionResponse(BaseModel):
|
|
7
9
|
result: str
|
|
@@ -22,10 +24,10 @@ class PipelineExecutionV1StreamedResponse(BaseModel):
|
|
|
22
24
|
class PipelineExecutionV2StreamedResponse(BaseModel):
|
|
23
25
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
24
26
|
|
|
25
|
-
stream: Iterator[
|
|
27
|
+
stream: Iterator[SSEMessage]
|
|
26
28
|
|
|
27
29
|
|
|
28
30
|
class PipelineExecutionV2AsyncStreamedResponse(BaseModel):
|
|
29
31
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
30
32
|
|
|
31
|
-
stream: AsyncIterator[
|
|
33
|
+
stream: AsyncIterator[SSEMessage]
|
airia/types/request_data.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from typing import Any, Dict
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class RequestData(BaseModel):
|
|
7
7
|
url: str
|
|
8
|
-
payload: Dict[str, Any]
|
|
8
|
+
payload: Optional[Dict[str, Any]]
|
|
9
9
|
headers: Dict[str, Any]
|
|
10
10
|
correlation_id: str
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
from datetime import datetime, time
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Names of the SSE event types emitted during a streamed pipeline execution.
# Each value is the string carried on an event's ``event:`` line (SSEDict in
# this module is keyed by these values). Mixing in ``str`` makes every member
# compare equal to its raw string value.
class MessageType(str, Enum):
    AGENT_PING = "AgentPingMessage"
    AGENT_START = "AgentStartMessage"
    AGENT_END = "AgentEndMessage"
    AGENT_STEP_START = "AgentStepStartMessage"
    AGENT_STEP_HALT = "AgentStepHaltMessage"
    AGENT_STEP_END = "AgentStepEndMessage"
    AGENT_OUTPUT = "AgentOutputMessage"
    AGENT_AGENT_CARD = "AgentAgentCardMessage"
    AGENT_DATASEARCH = "AgentDatasearchMessage"
    AGENT_INVOCATION = "AgentInvocationMessage"
    AGENT_MODEL = "AgentModelMessage"
    AGENT_PYTHON_CODE = "AgentPythonCodeMessage"
    AGENT_TOOL_ACTION = "AgentToolActionMessage"
    AGENT_MODEL_STREAM_START = "AgentModelStreamStartMessage"
    AGENT_MODEL_STREAM_END = "AgentModelStreamEndMessage"
    AGENT_MODEL_STREAM_ERROR = "AgentModelStreamErrorMessage"
    AGENT_MODEL_STREAM_USAGE = "AgentModelStreamUsageMessage"
    AGENT_MODEL_STREAM_FRAGMENT = "AgentModelStreamFragmentMessage"
    # NOTE(review): unlike every other member this value has no "Message"
    # suffix, and this module defines no model class and no SSEDict entry
    # for it — confirm whether the server still emits this event.
    MODEL_STREAM_FRAGMENT = "ModelStreamFragment"
    AGENT_AGENT_CARD_STREAM_START = "AgentAgentCardStreamStartMessage"
    AGENT_AGENT_CARD_STREAM_ERROR = "AgentAgentCardStreamErrorMessage"
    AGENT_AGENT_CARD_STREAM_FRAGMENT = "AgentAgentCardStreamFragmentMessage"
    AGENT_AGENT_CARD_STREAM_END = "AgentAgentCardStreamEndMessage"
    AGENT_TOOL_REQUEST = "AgentToolRequestMessage"
    AGENT_TOOL_RESPONSE = "AgentToolResponseMessage"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Common base of every SSE message model in this module.
class BaseSSEMessage(BaseModel):
    # use_enum_values: when message_type is validated from input data, the
    # plain string value is stored rather than the MessageType member.
    model_config = ConfigDict(use_enum_values=True)
    # Concrete subclasses give this a default naming their own event type.
    message_type: MessageType


# Ping event (presumably a keep-alive); carries only a timestamp.
class AgentPingMessage(BaseSSEMessage):
    message_type: MessageType = MessageType.AGENT_PING
    timestamp: datetime
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
### Agent Messages ###


# Base for every message tied to a particular agent execution: all
# subclasses carry the agent id and the execution id.
class BaseAgentMessage(BaseSSEMessage):
    agent_id: str
    execution_id: str


# Agent-execution start event; adds no fields to BaseAgentMessage.
class AgentStartMessage(BaseAgentMessage):
    message_type: MessageType = MessageType.AGENT_START


# Agent-execution end event; adds no fields to BaseAgentMessage.
class AgentEndMessage(BaseAgentMessage):
    message_type: MessageType = MessageType.AGENT_END
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
### Step Messages ###


# Base for messages about one step of an agent execution; step_title is
# optional and defaults to None.
class BaseStepMessage(BaseAgentMessage):
    step_id: str
    step_type: str
    step_title: Optional[str] = None


# A step has started at start_time.
class AgentStepStartMessage(BaseStepMessage):
    message_type: MessageType = MessageType.AGENT_STEP_START
    start_time: datetime


# A step has halted; approval_id identifies the pending approval
# (presumably required to resume the step — confirm against the API).
class AgentStepHaltMessage(BaseStepMessage):
    message_type: MessageType = MessageType.AGENT_STEP_HALT
    approval_id: str


# A step has finished at end_time with the given status string.
class AgentStepEndMessage(BaseStepMessage):
    message_type: MessageType = MessageType.AGENT_STEP_END
    end_time: datetime
    # NOTE(review): typed as datetime.time (a time of day), so it cannot
    # represent 24 hours or more and needs an HH:MM:SS-style value; a
    # timedelta may have been intended — confirm the server's format.
    duration: time
    status: str


# Carries a step's result as a string.
class AgentOutputMessage(BaseStepMessage):
    message_type: MessageType = MessageType.AGENT_OUTPUT
    step_result: str
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
### Status Messages ###


# Marker base for step status messages; adds no fields to BaseStepMessage.
class BaseStatusMessage(BaseStepMessage):
    pass


# Agent-card status for a step, identified by step_name.
class AgentAgentCardMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD
    step_name: str


# Data-search status; identifies the datastore by id, type and name.
class AgentDatasearchMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_DATASEARCH
    datastore_id: str
    datastore_type: str
    datastore_name: str


# Invocation status naming the invoked agent.
class AgentInvocationMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_INVOCATION
    agent_name: str


# Model status naming the model in use.
class AgentModelMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_MODEL
    model_name: str


# Python-code status for a step, identified by step_name.
class AgentPythonCodeMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_PYTHON_CODE
    step_name: str


# Tool-action status naming both the step and the tool.
class AgentToolActionMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_TOOL_ACTION
    step_name: str
    tool_name: str
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
### Model Stream Messages ###


# Base for messages about one streamed model response. It derives from
# BaseAgentMessage directly (not BaseStepMessage), so it has step_id but no
# step_type/step_title. stream_id presumably ties together the start,
# fragment, usage and end messages of one stream — confirm.
class BaseModelStreamMessage(BaseAgentMessage):
    step_id: str
    stream_id: str


# A model stream has started for the named model.
class AgentModelStreamStartMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_START
    model_name: str


# A model stream reported an error.
class AgentModelStreamErrorMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_ERROR
    error_message: str


# One chunk of streamed model output; content may be absent (None).
# NOTE(review): the package README filters these with ``index != -1``,
# suggesting -1 marks a non-output fragment — confirm with the API.
class AgentModelStreamFragmentMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_FRAGMENT
    index: int
    content: Optional[str] = None


# A model stream has ended.
class AgentModelStreamEndMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_END
    content_id: str
    # NOTE(review): a float here, but datetime.time on AgentStepEndMessage and
    # AgentToolResponseMessage; units are not shown — confirm.
    duration: Optional[float] = None


# Usage report for a model stream; both figures are optional.
class AgentModelStreamUsageMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_USAGE
    token: Optional[int] = None
    tokens_cost: Optional[float] = None
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
### Agent Card Messages ###


# Base for streamed agent-card output. Like BaseModelStreamMessage it adds
# step_id and stream_id directly on top of BaseAgentMessage.
class BaseAgentAgentCardStreamMessage(BaseAgentMessage):
    step_id: str
    stream_id: str


# An agent-card stream has started; content is optional.
class AgentAgentCardStreamStartMessage(BaseAgentAgentCardStreamMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD_STREAM_START
    content: Optional[str] = None


# An agent-card stream reported an error.
class AgentAgentCardStreamErrorMessage(BaseAgentAgentCardStreamMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD_STREAM_ERROR
    error_message: str


# One chunk of streamed agent-card output.
class AgentAgentCardStreamFragmentMessage(BaseAgentAgentCardStreamMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD_STREAM_FRAGMENT
    index: int
    # Defaults to None like every other optional ``content`` field in this
    # module. In pydantic v2 an Optional annotation without a default is
    # still a *required* field, so a fragment event that omitted "content"
    # would otherwise fail validation.
    content: Optional[str] = None


# An agent-card stream has ended; content is optional.
class AgentAgentCardStreamEndMessage(BaseAgentAgentCardStreamMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD_STREAM_END
    content: Optional[str] = None
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
### Tool Messages ###


# Base for tool-call messages within a step. ``id`` and ``name`` presumably
# identify the tool call and the tool — confirm against the API.
class BaseAgentToolMessage(BaseStepMessage):
    id: str
    name: str


# A tool call has been requested; adds no fields to BaseAgentToolMessage.
class AgentToolRequestMessage(BaseAgentToolMessage):
    message_type: MessageType = MessageType.AGENT_TOOL_REQUEST


# A tool call has returned; success reports whether it succeeded.
class AgentToolResponseMessage(BaseAgentToolMessage):
    message_type: MessageType = MessageType.AGENT_TOOL_RESPONSE
    # NOTE(review): datetime.time (time of day) used for a duration; it cannot
    # exceed 24h — a timedelta may have been intended. Confirm server format.
    duration: time
    success: bool
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# Union of every concrete SSE message model; used as the item type of the
# streamed pipeline-execution responses. When adding an event type, update
# this union together with MessageType and SSEDict.
SSEMessage = Union[
    AgentPingMessage,
    AgentStartMessage,
    AgentEndMessage,
    AgentStepStartMessage,
    AgentStepHaltMessage,
    AgentStepEndMessage,
    AgentOutputMessage,
    AgentAgentCardMessage,
    AgentDatasearchMessage,
    AgentInvocationMessage,
    AgentModelMessage,
    AgentPythonCodeMessage,
    AgentToolActionMessage,
    AgentModelStreamStartMessage,
    AgentModelStreamEndMessage,
    AgentModelStreamErrorMessage,
    AgentModelStreamUsageMessage,
    AgentModelStreamFragmentMessage,
    AgentAgentCardStreamStartMessage,
    AgentAgentCardStreamErrorMessage,
    AgentAgentCardStreamFragmentMessage,
    AgentAgentCardStreamEndMessage,
    AgentToolRequestMessage,
    AgentToolResponseMessage,
]
|
|
243
|
+
|
|
244
|
+
# Registry from SSE event name (the MessageType string value found on an
# event's ``event:`` line) to the model class that event's data is parsed
# into; the SSE parser looks models up here by event name. MessageType
# members without an entry (currently MODEL_STREAM_FRAGMENT) have no model.
SSEDict = {
    MessageType.AGENT_PING.value: AgentPingMessage,
    MessageType.AGENT_START.value: AgentStartMessage,
    MessageType.AGENT_END.value: AgentEndMessage,
    MessageType.AGENT_STEP_START.value: AgentStepStartMessage,
    MessageType.AGENT_STEP_HALT.value: AgentStepHaltMessage,
    MessageType.AGENT_STEP_END.value: AgentStepEndMessage,
    MessageType.AGENT_OUTPUT.value: AgentOutputMessage,
    MessageType.AGENT_AGENT_CARD.value: AgentAgentCardMessage,
    MessageType.AGENT_DATASEARCH.value: AgentDatasearchMessage,
    MessageType.AGENT_INVOCATION.value: AgentInvocationMessage,
    MessageType.AGENT_MODEL.value: AgentModelMessage,
    MessageType.AGENT_PYTHON_CODE.value: AgentPythonCodeMessage,
    MessageType.AGENT_TOOL_ACTION.value: AgentToolActionMessage,
    MessageType.AGENT_MODEL_STREAM_START.value: AgentModelStreamStartMessage,
    MessageType.AGENT_MODEL_STREAM_END.value: AgentModelStreamEndMessage,
    MessageType.AGENT_MODEL_STREAM_ERROR.value: AgentModelStreamErrorMessage,
    MessageType.AGENT_MODEL_STREAM_USAGE.value: AgentModelStreamUsageMessage,
    MessageType.AGENT_MODEL_STREAM_FRAGMENT.value: AgentModelStreamFragmentMessage,
    MessageType.AGENT_AGENT_CARD_STREAM_START.value: AgentAgentCardStreamStartMessage,
    MessageType.AGENT_AGENT_CARD_STREAM_ERROR.value: AgentAgentCardStreamErrorMessage,
    MessageType.AGENT_AGENT_CARD_STREAM_FRAGMENT.value: AgentAgentCardStreamFragmentMessage,
    MessageType.AGENT_AGENT_CARD_STREAM_END.value: AgentAgentCardStreamEndMessage,
    MessageType.AGENT_TOOL_REQUEST.value: AgentToolRequestMessage,
    MessageType.AGENT_TOOL_RESPONSE.value: AgentToolResponseMessage,
}
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import codecs
import json
import re
from typing import (
    AsyncIterable,
    AsyncIterator,
    Iterable,
    Iterator,
    List,
    Optional,
)

from ..types.sse_messages import SSEDict, SSEMessage
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _to_snake_case(name: str):
|
|
9
|
+
return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _SSEEventSplitter:
    """Turn raw SSE byte chunks into complete, blank-line-terminated event blocks.

    Chunk boundaries from a streaming HTTP response can fall anywhere.

    - An incremental UTF-8 decoder stops a multi-byte character split across
      two chunks from raising ``UnicodeDecodeError``, which decoding each
      chunk on its own would do.
    - A trailing ``"\\r"`` is held back until the next chunk shows whether it
      begins a ``"\\r\\n"`` pair.
    - CRLF and lone-CR line endings, both allowed by the SSE format, are
      normalised to LF before events are split on ``"\\n\\n"``.
    """

    def __init__(self) -> None:
        self._decoder = codecs.getincrementaldecoder("utf-8")()
        self._buffer = ""
        self._pending_cr = False

    def feed(self, chunk: bytes) -> List[str]:
        """Decode ``chunk`` and return the event blocks it completes, in order."""
        return self._push(self._decoder.decode(chunk))

    def close(self) -> List[str]:
        """Signal end of stream; return any block completed by a held-back ``"\\r"``.

        Text after the last blank line is an unterminated event and is
        discarded, as the SSE format requires.
        """
        if not self._pending_cr:
            return []
        self._pending_cr = False
        # The held-back "\r" turned out to be a lone CR line ending.
        return self._push("\n")

    def _push(self, text: str) -> List[str]:
        """Append decoded ``text`` to the buffer and split off finished blocks."""
        if self._pending_cr:
            text = "\r" + text
        # A "\r" at the very end may be the first half of a "\r\n" pair whose
        # "\n" has not arrived yet; normalising it now could fake a blank line.
        self._pending_cr = text.endswith("\r")
        if self._pending_cr:
            text = text[:-1]
        self._buffer += text.replace("\r\n", "\n").replace("\r", "\n")
        # Every "\n\n" ends an event; the last piece is still incomplete.
        *blocks, self._buffer = self._buffer.split("\n\n")
        return blocks


def _parse_event_block(event_block: str) -> Optional[SSEMessage]:
    """Build the message model for one event block, or return None to skip it.

    The block's ``event:`` line names the model class, which is looked up in
    ``SSEDict``. Its ``data:`` lines are joined with newlines, as the SSE
    format specifies, and must form a JSON object. That object's camelCase
    keys are converted to snake_case field names.

    A block is skipped (None is returned) when it has no event name, no data,
    an empty JSON payload, or an event name with no ``SSEDict`` entry. Other
    lines (``id:``, ``retry:``, ``:`` comments) are ignored.

    Raises:
        json.JSONDecodeError: If the joined data is not valid JSON.
    """
    event_name = None
    data_lines = []
    for raw_line in event_block.split("\n"):
        line = raw_line.strip()
        if line.startswith("event:"):
            event_name = line[6:].strip()
        elif line.startswith("data:"):
            data_lines.append(line[5:].strip())

    if not event_name or not data_lines:
        return None

    event_data = json.loads("\n".join(data_lines))
    if not event_data:
        return None

    # Skip (instead of raising KeyError on) event types this SDK has no model
    # for, such as MessageType.MODEL_STREAM_FRAGMENT or types added on the
    # server later, so one unrecognised event does not abort the whole stream.
    message_cls = SSEDict.get(event_name)
    if message_cls is None:
        return None
    return message_cls(**{_to_snake_case(k): v for k, v in event_data.items()})


def _messages_from_blocks(event_blocks: Iterable[str]) -> Iterator[SSEMessage]:
    """Yield the parsed message for each block, dropping blocks that produce none."""
    for event_block in event_blocks:
        message = _parse_event_block(event_block)
        if message is not None:
            yield message


def parse_sse_stream_chunked(stream_chunks: Iterable[bytes]) -> Iterator[SSEMessage]:
    """
    Parse a Server-Sent Events byte stream into typed SSE messages.

    Args:
        stream_chunks: Iterable of raw ``bytes`` chunks, e.g. from a streaming
            HTTP response. Chunk boundaries may fall anywhere, including
            inside a multi-byte UTF-8 character or between ``"\\r"`` and
            ``"\\n"``.

    Yields:
        SSEMessage: One message model per complete event. The model class is
        chosen from ``SSEDict`` by the event's ``event:`` name. It is built
        from the event's JSON ``data:`` payload, with keys converted from
        camelCase to snake_case.

    Some events produce no message and are skipped:

    - events with no name or data, or an empty JSON payload;
    - events whose name has no ``SSEDict`` entry;
    - an unterminated event at the end of the stream.

    Raises:
        json.JSONDecodeError: If an event's data is not valid JSON.
        UnicodeDecodeError: If the stream contains invalid UTF-8.
        Errors raised by the message model's constructor (pydantic
        validation) propagate unchanged.
    """
    splitter = _SSEEventSplitter()
    for chunk in stream_chunks:
        yield from _messages_from_blocks(splitter.feed(chunk))
    yield from _messages_from_blocks(splitter.close())


async def async_parse_sse_stream_chunked(
    stream_chunks: AsyncIterable[bytes],
) -> AsyncIterator[SSEMessage]:
    """
    Parse an asynchronous Server-Sent Events byte stream into typed SSE messages.

    Async counterpart of ``parse_sse_stream_chunked``. It applies the same
    decoding, skipping and error rules; see that function for details.

    Args:
        stream_chunks: Async iterable of raw ``bytes`` chunks, e.g. from a
            streaming HTTP response. Boundaries may fall anywhere.

    Yields:
        SSEMessage: One message model per complete, recognised event.
    """
    splitter = _SSEEventSplitter()
    async for chunk in stream_chunks:
        # ``yield from`` is not allowed inside an async generator.
        for message in _messages_from_blocks(splitter.feed(chunk)):
            yield message
    for message in _messages_from_blocks(splitter.close()):
        yield message
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: airia
|
|
3
|
-
Version: 0.1.4
|
|
3
|
+
Version: 0.1.6
|
|
4
4
|
Summary: Python SDK for Airia API
|
|
5
5
|
Author-email: Airia LLC <support@airia.com>
|
|
6
6
|
License: MIT
|
|
@@ -176,6 +176,20 @@ This will create both wheel and source distribution in the `dist/` directory.
|
|
|
176
176
|
|
|
177
177
|
## Quick Start
|
|
178
178
|
|
|
179
|
+
### Client Instantiation
|
|
180
|
+
|
|
181
|
+
```python
|
|
182
|
+
from airia import AiriaClient
|
|
183
|
+
|
|
184
|
+
client = AiriaClient(
|
|
185
|
+
base_url="https://api.airia.ai", # Default: "https://api.airia.ai"
|
|
186
|
+
api_key=None, # Or set AIRIA_API_KEY environment variable
|
|
187
|
+
timeout=30.0, # Request timeout in seconds (default: 30.0)
|
|
188
|
+
log_requests=False, # Enable request/response logging (default: False)
|
|
189
|
+
custom_logger=None # Use custom logger (default: None - uses built-in)
|
|
190
|
+
)
|
|
191
|
+
```
|
|
192
|
+
|
|
179
193
|
### Synchronous Usage
|
|
180
194
|
|
|
181
195
|
```python
|
|
@@ -208,8 +222,8 @@ response = client.execute_pipeline(
|
|
|
208
222
|
async_output=True
|
|
209
223
|
)
|
|
210
224
|
|
|
211
|
-
for c in
|
|
212
|
-
print(c
|
|
225
|
+
for c in response.stream:
|
|
226
|
+
print(c)
|
|
213
227
|
```
|
|
214
228
|
|
|
215
229
|
### Asynchronous Usage
|
|
@@ -219,13 +233,12 @@ import asyncio
|
|
|
219
233
|
from airia import AiriaAsyncClient
|
|
220
234
|
|
|
221
235
|
async def main():
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
print(response.result)
|
|
236
|
+
client = AiriaAsyncClient(api_key="your_api_key")
|
|
237
|
+
response = await client.execute_pipeline(
|
|
238
|
+
pipeline_id="your_pipeline_id",
|
|
239
|
+
user_input="Tell me about quantum computing"
|
|
240
|
+
)
|
|
241
|
+
print(response.result)
|
|
229
242
|
|
|
230
243
|
asyncio.run(main())
|
|
231
244
|
```
|
|
@@ -237,19 +250,111 @@ import asyncio
|
|
|
237
250
|
from airia import AiriaAsyncClient
|
|
238
251
|
|
|
239
252
|
async def main():
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
print(c, end="")
|
|
253
|
+
client = AiriaAsyncClient(api_key="your_api_key")
|
|
254
|
+
response = await client.execute_pipeline(
|
|
255
|
+
pipeline_id="your_pipeline_id",
|
|
256
|
+
user_input="Tell me about quantum computing",
|
|
257
|
+
async_output=True
|
|
258
|
+
)
|
|
259
|
+
async for c in response.stream:
|
|
260
|
+
print(c)
|
|
249
261
|
|
|
250
262
|
asyncio.run(main())
|
|
251
263
|
```
|
|
252
264
|
|
|
265
|
+
## Streaming Event Parsing
|
|
266
|
+
|
|
267
|
+
When using streaming mode (`async_output=True`), the API returns Server-Sent Events (SSE) that contain different types of messages throughout the pipeline execution. You can parse and filter these events to extract specific information.
|
|
268
|
+
|
|
269
|
+
### Available Message Types
|
|
270
|
+
|
|
271
|
+
The streaming response includes various message types defined in `airia.types`. Here are the key ones:
|
|
272
|
+
|
|
273
|
+
- `AgentModelStreamFragmentMessage` - Contains actual LLM output chunks
|
|
274
|
+
- `AgentModelStreamStartMessage` - Indicates LLM streaming has started
|
|
275
|
+
- `AgentModelStreamEndMessage` - Indicates LLM streaming has ended
|
|
276
|
+
- `AgentStepStartMessage` - Indicates a pipeline step has started
|
|
277
|
+
- `AgentStepEndMessage` - Indicates a pipeline step has ended
|
|
278
|
+
- `AgentOutputMessage` - Contains step output
|
|
279
|
+
|
|
280
|
+
<details>
|
|
281
|
+
<summary>Click to expand the full list of message types</summary>
|
|
282
|
+
|
|
283
|
+
```python
|
|
284
|
+
[
|
|
285
|
+
AgentPingMessage,
|
|
286
|
+
AgentStartMessage,
|
|
287
|
+
AgentEndMessage,
|
|
288
|
+
AgentStepStartMessage,
|
|
289
|
+
AgentStepHaltMessage,
|
|
290
|
+
AgentStepEndMessage,
|
|
291
|
+
AgentOutputMessage,
|
|
292
|
+
AgentAgentCardMessage,
|
|
293
|
+
AgentDatasearchMessage,
|
|
294
|
+
AgentInvocationMessage,
|
|
295
|
+
AgentModelMessage,
|
|
296
|
+
AgentPythonCodeMessage,
|
|
297
|
+
AgentToolActionMessage,
|
|
298
|
+
AgentModelStreamStartMessage,
|
|
299
|
+
AgentModelStreamEndMessage,
|
|
300
|
+
AgentModelStreamErrorMessage,
|
|
301
|
+
AgentModelStreamUsageMessage,
|
|
302
|
+
AgentModelStreamFragmentMessage,
|
|
303
|
+
AgentAgentCardStreamStartMessage,
|
|
304
|
+
AgentAgentCardStreamErrorMessage,
|
|
305
|
+
AgentAgentCardStreamFragmentMessage,
|
|
306
|
+
AgentAgentCardStreamEndMessage,
|
|
307
|
+
AgentToolRequestMessage,
|
|
308
|
+
AgentToolResponseMessage,
|
|
309
|
+
]
|
|
310
|
+
```
|
|
311
|
+
|
|
312
|
+
</details>
|
|
313
|
+
|
|
314
|
+
### Filtering LLM Output
|
|
315
|
+
|
|
316
|
+
To extract only the actual LLM output text from the stream:
|
|
317
|
+
|
|
318
|
+
```python
|
|
319
|
+
from airia import AiriaClient
|
|
320
|
+
from airia.types import AgentModelStreamFragmentMessage
|
|
321
|
+
|
|
322
|
+
client = AiriaClient(api_key="your_api_key")
|
|
323
|
+
|
|
324
|
+
response = client.execute_pipeline(
|
|
325
|
+
pipeline_id="your_pipeline_id",
|
|
326
|
+
user_input="Tell me about quantum computing",
|
|
327
|
+
async_output=True
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
# Filter and display only LLM output
|
|
331
|
+
for event in response.stream:
|
|
332
|
+
if isinstance(event, AgentModelStreamFragmentMessage) and event.index != -1:
|
|
333
|
+
print(event.content, end="", flush=True)
|
|
334
|
+
```
|
|
335
|
+
|
|
336
|
+
## Pipeline Configuration Retrieval
|
|
337
|
+
|
|
338
|
+
You can retrieve detailed configuration information about a pipeline using the `get_pipeline_config` method:
|
|
339
|
+
|
|
340
|
+
> To get a list of all active pipeline ids, run the `get_active_pipelines_ids` method.
|
|
341
|
+
|
|
342
|
+
```python
|
|
343
|
+
from airia import AiriaClient
|
|
344
|
+
|
|
345
|
+
client = AiriaClient(api_key="your_api_key")
|
|
346
|
+
|
|
347
|
+
# Get pipeline configuration
|
|
348
|
+
config = client.get_pipeline_config(pipeline_id="your_pipeline_id")
|
|
349
|
+
|
|
350
|
+
# Access configuration details
|
|
351
|
+
print(f"Pipeline Name: {config.deployment_name}")
|
|
352
|
+
print(f"Description: {config.deployment_description}")
|
|
353
|
+
print(f"Active Version: {config.active_version.version_number}")
|
|
354
|
+
print(f"Success Count: {config.execution_stats.success_count}")
|
|
355
|
+
print(f"Failure Count: {config.execution_stats.failure_count}")
|
|
356
|
+
```
|
|
357
|
+
|
|
253
358
|
## Gateway Usage
|
|
254
359
|
|
|
255
360
|
Airia provides gateway capabilities for popular AI services like OpenAI and Anthropic, allowing you to use your Airia API key with these services.
|
|
@@ -292,6 +397,8 @@ response = client.anthropic.messages.create(
|
|
|
292
397
|
print(response.content[0].text)
|
|
293
398
|
```
|
|
294
399
|
|
|
400
|
+
You can set the Gateway URL by passing the `gateway_url` parameter when using the gateway constructors. The default values are `https://gateway.airia.ai/openai/v1` for OpenAI and `https://gateway.airia.ai/anthropic` for Anthropic.
|
|
401
|
+
|
|
295
402
|
### Asynchronous Gateway Usage
|
|
296
403
|
|
|
297
404
|
Both gateways also support asynchronous usage:
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
airia/__init__.py,sha256=T39gO8E5T5zxlw-JP78ruxOu7-LeKOJCJzz6t40kdQo,231
|
|
2
|
+
airia/exceptions.py,sha256=4Z55n-cRJrtTa5-pZBIK2oZD4-Z99aUtKx_kfTFYY5o,1146
|
|
3
|
+
airia/logs.py,sha256=17YZ4IuzOF0m5bgofj9-QYlJ2BYR2kRZbBVQfFSLFEk,5441
|
|
4
|
+
airia/client/__init__.py,sha256=6gSQ9bl7j79q1HPE0o5py3IRdkwWWuU_7J4h05Dd2o8,127
|
|
5
|
+
airia/client/async_client.py,sha256=QtnH1MNfF_ONx2YHsTz7Y4pFBTrggHWz0LLJiITREwc,23694
|
|
6
|
+
airia/client/base_client.py,sha256=IQYksAw6xCFmAarPUGxyp2bHm95iTzdZUv8FNUq4fHk,7962
|
|
7
|
+
airia/client/sync_client.py,sha256=iUqRn7Bv5hi-Qcg3Vgfc-Z1gx5T42f9Vq4pUTPsyG5M,22750
|
|
8
|
+
airia/types/__init__.py,sha256=s6_uMrrKzhk2wg1lpaRRIl4AQ8hUxKbCNLfz-VK7tAs,2166
|
|
9
|
+
airia/types/api_version.py,sha256=Uzom6O2ZG92HN_Z2h-lTydmO2XYX9RVs4Yi4DJmXytE,255
|
|
10
|
+
airia/types/request_data.py,sha256=m8lBwFliK2_kZU2TgeLdhiO_7v826lVRhA5_8NYB_NM,206
|
|
11
|
+
airia/types/sse_messages.py,sha256=wSdowY07AjEO8R73SJrFPJtkfIBS4satUpNytjKQq2U,8305
|
|
12
|
+
airia/types/api/get_pipeline_config.py,sha256=hHyIfzpH-miayAAUtgf5ywXfqOg53dUMiEfDHW1I-tg,2837
|
|
13
|
+
airia/types/api/pipeline_execution.py,sha256=4zGM5W4aXiAibAdgWQCIZBC_blW6QYAlAHoVUMbnCF0,752
|
|
14
|
+
airia/utils/sse_parser.py,sha256=h3TcBvXqUIngTqgY6yYxZCGLnC1eI6meQzYr13aFAb8,2791
|
|
15
|
+
airia-0.1.6.dist-info/licenses/LICENSE,sha256=R3ClUMMKPRItIcZ0svzyj2taZZnFYw568YDNzN9KQ1Q,1066
|
|
16
|
+
airia-0.1.6.dist-info/METADATA,sha256=1gvnnQKQhkJR31GkL4RG9Vw91UTi9Tclcjv5d9nxpMk,13471
|
|
17
|
+
airia-0.1.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
18
|
+
airia-0.1.6.dist-info/top_level.txt,sha256=qUQEKfs_hdOYTwjKj1JZbRhS5YeXDNaKQaVTrzabS6w,6
|
|
19
|
+
airia-0.1.6.dist-info/RECORD,,
|
airia-0.1.4.dist-info/RECORD
DELETED
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
airia/__init__.py,sha256=T39gO8E5T5zxlw-JP78ruxOu7-LeKOJCJzz6t40kdQo,231
|
|
2
|
-
airia/exceptions.py,sha256=4Z55n-cRJrtTa5-pZBIK2oZD4-Z99aUtKx_kfTFYY5o,1146
|
|
3
|
-
airia/logs.py,sha256=17YZ4IuzOF0m5bgofj9-QYlJ2BYR2kRZbBVQfFSLFEk,5441
|
|
4
|
-
airia/client/__init__.py,sha256=6gSQ9bl7j79q1HPE0o5py3IRdkwWWuU_7J4h05Dd2o8,127
|
|
5
|
-
airia/client/async_client.py,sha256=Xd1GAGxQvXVh0-hH8vE4HOeaNY_Fawh5iRxeS26CInY,18173
|
|
6
|
-
airia/client/base_client.py,sha256=3oUcCWWq-DmPbiJke_HHPXm9UzDvD-MzrnoiKnBjcdk,6810
|
|
7
|
-
airia/client/sync_client.py,sha256=p_auqkfyP-9YuCoppl43SyMdKek30kA6M05QI0LYunE,17689
|
|
8
|
-
airia/types/__init__.py,sha256=OOFtJ0VIbO98px89u7cq645iL7CYDOel43g85IAjRRw,562
|
|
9
|
-
airia/types/api_version.py,sha256=Uzom6O2ZG92HN_Z2h-lTydmO2XYX9RVs4Yi4DJmXytE,255
|
|
10
|
-
airia/types/pipeline_execution.py,sha256=obp8KOz-ShNxEciF1YQCSpHXPoJUyEa_9uPv-VHf6YA,699
|
|
11
|
-
airia/types/request_data.py,sha256=RbVPWPRIYuT7FaolJKn19IVzuQKbLM4VhXM5r7UXTUY,186
|
|
12
|
-
airia-0.1.4.dist-info/licenses/LICENSE,sha256=R3ClUMMKPRItIcZ0svzyj2taZZnFYw568YDNzN9KQ1Q,1066
|
|
13
|
-
airia-0.1.4.dist-info/METADATA,sha256=VOWxozh4xZXAHUXH6_ZyCUTIU95LXxly3Yfs-ayi0sk,9966
|
|
14
|
-
airia-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
-
airia-0.1.4.dist-info/top_level.txt,sha256=qUQEKfs_hdOYTwjKj1JZbRhS5YeXDNaKQaVTrzabS6w,6
|
|
16
|
-
airia-0.1.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|