airia 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airia/client/async_client.py +141 -23
- airia/client/base_client.py +30 -0
- airia/client/sync_client.py +119 -7
- airia/types/__init__.py +54 -2
- airia/types/api/get_pipeline_config.py +148 -0
- airia/types/{pipeline_execution.py → api/pipeline_execution.py} +4 -2
- airia/types/request_data.py +2 -2
- airia/types/sse_messages.py +269 -0
- airia/utils/sse_parser.py +91 -0
- {airia-0.1.5.dist-info → airia-0.1.7.dist-info}/METADATA +106 -19
- airia-0.1.7.dist-info/RECORD +19 -0
- airia-0.1.5.dist-info/RECORD +0 -16
- {airia-0.1.5.dist-info → airia-0.1.7.dist-info}/WHEEL +0 -0
- {airia-0.1.5.dist-info → airia-0.1.7.dist-info}/licenses/LICENSE +0 -0
- {airia-0.1.5.dist-info → airia-0.1.7.dist-info}/top_level.txt +0 -0
airia/client/async_client.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import weakref
|
|
1
3
|
from typing import Any, AsyncIterator, Dict, List, Literal, Optional, overload
|
|
2
4
|
from urllib.parse import urljoin
|
|
3
5
|
|
|
@@ -7,12 +9,14 @@ import loguru
|
|
|
7
9
|
from ..exceptions import AiriaAPIError
|
|
8
10
|
from ..types import (
|
|
9
11
|
ApiVersion,
|
|
12
|
+
GetPipelineConfigResponse,
|
|
10
13
|
PipelineExecutionDebugResponse,
|
|
11
14
|
PipelineExecutionResponse,
|
|
12
15
|
PipelineExecutionV1StreamedResponse,
|
|
13
16
|
PipelineExecutionV2AsyncStreamedResponse,
|
|
14
17
|
RequestData,
|
|
15
18
|
)
|
|
19
|
+
from ..utils.sse_parser import async_parse_sse_stream_chunked
|
|
16
20
|
from .base_client import AiriaBaseClient
|
|
17
21
|
|
|
18
22
|
|
|
@@ -46,8 +50,31 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
46
50
|
)
|
|
47
51
|
|
|
48
52
|
# Session will be initialized in __aenter__
|
|
49
|
-
self.session = None
|
|
50
53
|
self.headers = {"Content-Type": "application/json"}
|
|
54
|
+
self.session = aiohttp.ClientSession(headers=self.headers)
|
|
55
|
+
|
|
56
|
+
# Register finalizer to clean up session when client is garbage collected
|
|
57
|
+
self._finalizer = weakref.finalize(self, self._cleanup_session, self.session)
|
|
58
|
+
|
|
59
|
+
@staticmethod
def _cleanup_session(session: aiohttp.ClientSession):
    """Close an aiohttp session when its owning client is garbage collected.

    Called by the ``weakref.finalize`` hook registered in ``__init__``. It is a
    static method so the finalizer does not hold a reference to the client.

    Strategy:
      * If the thread's current event loop is available and open, reuse it.
        If that loop is running, the close is scheduled as a task on it.
        Otherwise the close is run to completion on it.
      * If no usable loop exists, the close runs on a private, temporary loop.
        That loop is closed afterwards and is never installed as the thread's
        current loop. A GC pass therefore neither leaks a loop nor changes
        global asyncio state.

    Args:
        session: The session to close. ``None`` or an already-closed session
            is ignored.
    """
    if session is None or session.closed:
        return

    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # No usable current loop in this thread.
        loop = None

    if loop is not None and not loop.is_closed():
        if loop.is_running():
            # Cannot block inside a running loop; close in the background.
            loop.create_task(session.close())
        else:
            loop.run_until_complete(session.close())
        return

    # Fallback: a throwaway loop that is always closed and never made current,
    # unlike the previous new_event_loop() + set_event_loop() which leaked it.
    temp_loop = asyncio.new_event_loop()
    try:
        temp_loop.run_until_complete(session.close())
    finally:
        temp_loop.close()
|
|
51
78
|
|
|
52
79
|
@classmethod
|
|
53
80
|
def with_openai_gateway(
|
|
@@ -117,24 +144,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
117
144
|
|
|
118
145
|
return cls(base_url, api_key, timeout, log_requests, custom_logger)
|
|
119
146
|
|
|
120
|
-
async def __aenter__(self):
    """Enter an ``async with`` block.

    The HTTP session is already created in ``__init__``, so nothing needs to
    be opened here. This method exists so that code written for earlier
    releases (``async with AiriaAsyncClient(...) as client:``) keeps working.

    Returns:
        This client instance.
    """
    return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit an ``async with`` block by closing the underlying HTTP session.

    Closing here, inside the running event loop, is the deterministic cleanup
    path. The garbage-collection finalizer skips sessions that are already
    closed. Exceptions raised inside the block are not suppressed.
    """
    if self.session is not None and not self.session.closed:
        await self.session.close()
|
|
137
|
-
|
|
138
147
|
def _handle_exception(
|
|
139
148
|
self, e: aiohttp.ClientResponseError, url: str, correlation_id: str
|
|
140
149
|
):
|
|
@@ -257,8 +266,10 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
257
266
|
response.raise_for_status()
|
|
258
267
|
|
|
259
268
|
# Yields the response content as a stream if streaming
|
|
260
|
-
async for
|
|
261
|
-
|
|
269
|
+
async for message in async_parse_sse_stream_chunked(
|
|
270
|
+
response.content.iter_any()
|
|
271
|
+
):
|
|
272
|
+
yield message
|
|
262
273
|
|
|
263
274
|
except aiohttp.ClientResponseError as e:
|
|
264
275
|
self._handle_exception(e, request_data.url, request_data.correlation_id)
|
|
@@ -415,8 +426,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
415
426
|
... )
|
|
416
427
|
>>> print(response.result)
|
|
417
428
|
"""
|
|
418
|
-
self._check_session()
|
|
419
|
-
|
|
420
429
|
request_data = self._pre_execute_pipeline(
|
|
421
430
|
pipeline_id=pipeline_id,
|
|
422
431
|
user_input=user_input,
|
|
@@ -462,3 +471,112 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
462
471
|
return PipelineExecutionV1StreamedResponse(**resp)
|
|
463
472
|
|
|
464
473
|
return PipelineExecutionV2AsyncStreamedResponse(stream=resp)
|
|
474
|
+
|
|
475
|
+
async def get_active_pipelines_ids(
    self,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> List[str]:
    """
    Retrieve a list of active pipeline IDs.

    This method fetches the pipeline configurations from the Airia API and
    returns the pipeline ID of each item's active version. These IDs can be
    used with other methods like execute_pipeline() or get_pipeline_config().

    Args:
        correlation_id (str, optional): Unique identifier for request tracing
            and logging. If not provided, a new UUID will be automatically
            generated.
        api_version (str, optional): API version to use for the request.
            Must be one of the supported versions. Defaults to "v1".

    Returns:
        List[str]: The active pipeline IDs, in response order. Returns an
            empty list when the response's ``items`` field is missing, null
            or empty. Items without an ``activeVersion.pipelineId`` are
            skipped.

    Raises:
        ValueError: If the provided API version is not supported.
        AiriaAPIError: If the API request fails, including network errors,
            authentication failures, or server errors.

    Example:
        >>> client = AiriaAsyncClient(api_key="your_api_key")
        >>> pipeline_ids = await client.get_active_pipelines_ids()
        >>> print(f"Found {len(pipeline_ids)} active pipelines")
        >>> for pipeline_id in pipeline_ids:
        ...     print(f"Pipeline ID: {pipeline_id}")
    """
    request_data = self._pre_get_active_pipelines_ids(
        correlation_id=correlation_id, api_version=api_version
    )
    resp = await self._make_request("GET", request_data)

    # "items" may be absent or explicitly null; both mean "no pipelines".
    items = resp.get("items") or []

    pipeline_ids: List[str] = []
    for item in items:
        # A pipeline without an active version has no ID to report. Skip it
        # instead of failing the whole listing with KeyError/TypeError.
        active_version = item.get("activeVersion") or {}
        pipeline_id = active_version.get("pipelineId")
        if pipeline_id:
            pipeline_ids.append(pipeline_id)

    return pipeline_ids
|
|
521
|
+
|
|
522
|
+
async def get_pipeline_config(
    self,
    pipeline_id: str,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> GetPipelineConfigResponse:
    """
    Retrieve the exported configuration of a specific pipeline.

    Issues a GET request to the pipeline's ``PipelinesConfig/export`` endpoint
    and parses the JSON body into a GetPipelineConfigResponse. The response
    exposes the export metadata, the agent definition, and the prompts, tools,
    models, data sources, Python code blocks and routers the pipeline uses.

    Args:
        pipeline_id (str): The unique identifier of the pipeline to retrieve
            configuration for.
        correlation_id (str, optional): A unique identifier for request tracing
            and logging. If not provided, one will be automatically generated.
        api_version (str, optional): The API version to use for the request.
            Defaults to "v1". Valid versions are defined in ApiVersion enum.

    Returns:
        GetPipelineConfigResponse: The parsed pipeline configuration.

    Raises:
        ValueError: If the provided api_version is not valid.
        AiriaAPIError: If the API request fails, including cases where:
            - The pipeline_id doesn't exist (404)
            - Authentication fails (401)
            - Access is forbidden (403)
            - Server errors (5xx)

    Example:
        ```python
        client = AiriaAsyncClient(api_key="your_api_key")

        config = await client.get_pipeline_config(pipeline_id="your_pipeline_id")

        print(f"Agent: {config.agent.name}")
        print(f"Description: {config.agent.agent_description}")
        print(f"Export version: {config.metadata.export_version}")
        print(f"Models: {[model.display_name for model in config.models or []]}")
        ```

    Note:
        This method only retrieves configuration information and does not
        execute the pipeline. Use execute_pipeline() to run the pipeline.
    """
    request_data = self._pre_get_pipeline_config(
        pipeline_id=pipeline_id,
        correlation_id=correlation_id,
        api_version=api_version,
    )
    resp = await self._make_request("GET", request_data)

    # The payload uses camelCase keys; the model maps them through field aliases.
    return GetPipelineConfigResponse(**resp)
|
airia/client/base_client.py
CHANGED
|
@@ -11,6 +11,7 @@ from ..types import ApiVersion, RequestData
|
|
|
11
11
|
|
|
12
12
|
class AiriaBaseClient:
|
|
13
13
|
"""Base client containing shared functionality for Airia API clients."""
|
|
14
|
+
|
|
14
15
|
openai = None
|
|
15
16
|
anthropic = None
|
|
16
17
|
|
|
@@ -186,3 +187,32 @@ class AiriaBaseClient:
|
|
|
186
187
|
request_data = self._prepare_request(url, payload, correlation_id)
|
|
187
188
|
|
|
188
189
|
return request_data
|
|
190
|
+
|
|
191
|
+
def _pre_get_active_pipelines_ids(
    self,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
):
    """Validate inputs and build the request that lists pipeline configurations.

    Args:
        correlation_id: Optional tracing ID, forwarded to ``_prepare_request``.
        api_version: API version path segment. Must appear in
            ``ApiVersion.as_list()``.

    Returns:
        The request object produced by ``_prepare_request`` for
        ``urljoin(base_url, "<api_version>/PipelinesConfig")``.

    Raises:
        ValueError: If ``api_version`` is not a supported version.
    """
    supported = ApiVersion.as_list()
    if api_version in supported:
        endpoint = urljoin(self.base_url, f"{api_version}/PipelinesConfig")
        return self._prepare_request(endpoint, correlation_id=correlation_id)

    raise ValueError(
        f"Invalid API version: {api_version}. Valid versions are: {', '.join(supported)}"
    )
|
|
204
|
+
|
|
205
|
+
def _pre_get_pipeline_config(
    self,
    pipeline_id: str,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
):
    """Validate inputs and build the request for a pipeline's exported configuration.

    Args:
        pipeline_id: ID of the pipeline whose configuration is requested. It is
            percent-encoded so it always forms a single URL path segment.
        correlation_id: Optional tracing ID, forwarded to ``_prepare_request``.
        api_version: API version path segment. Must appear in
            ``ApiVersion.as_list()``.

    Returns:
        The request object produced by ``_prepare_request`` for
        ``urljoin(base_url, "<api_version>/PipelinesConfig/export/<pipeline_id>")``.

    Raises:
        ValueError: If ``api_version`` is not a supported version.
    """
    from urllib.parse import quote

    if api_version not in ApiVersion.as_list():
        raise ValueError(
            f"Invalid API version: {api_version}. Valid versions are: {', '.join(ApiVersion.as_list())}"
        )

    # Percent-encode the ID as one path segment. safe="" also escapes "/", so
    # values containing "/", "?" or "#" cannot redirect the request to another
    # endpoint or inject a query string. str() keeps the previous f-string
    # behaviour for non-str inputs.
    encoded_id = quote(str(pipeline_id), safe="")
    url = urljoin(self.base_url, f"{api_version}/PipelinesConfig/export/{encoded_id}")
    request_data = self._prepare_request(url, correlation_id=correlation_id)

    return request_data
|
airia/client/sync_client.py
CHANGED
|
@@ -7,12 +7,14 @@ import requests
|
|
|
7
7
|
from ..exceptions import AiriaAPIError
|
|
8
8
|
from ..types import (
|
|
9
9
|
ApiVersion,
|
|
10
|
+
GetPipelineConfigResponse,
|
|
10
11
|
PipelineExecutionDebugResponse,
|
|
11
12
|
PipelineExecutionResponse,
|
|
12
13
|
PipelineExecutionV1StreamedResponse,
|
|
13
14
|
PipelineExecutionV2StreamedResponse,
|
|
14
15
|
RequestData,
|
|
15
16
|
)
|
|
17
|
+
from ..utils.sse_parser import parse_sse_stream_chunked
|
|
16
18
|
from .base_client import AiriaBaseClient
|
|
17
19
|
|
|
18
20
|
|
|
@@ -150,7 +152,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
150
152
|
status_code=e.response.status_code, message=sanitized_message
|
|
151
153
|
) from e
|
|
152
154
|
|
|
153
|
-
def _make_request(self, method: str, request_data: RequestData):
|
|
155
|
+
def _make_request(self, method: str, request_data: RequestData) -> Dict[str, Any]:
|
|
154
156
|
"""
|
|
155
157
|
Makes a synchronous HTTP request to the Airia API.
|
|
156
158
|
|
|
@@ -247,8 +249,8 @@ class AiriaClient(AiriaBaseClient):
|
|
|
247
249
|
response.raise_for_status()
|
|
248
250
|
|
|
249
251
|
# Yields the response content as a stream
|
|
250
|
-
for
|
|
251
|
-
yield
|
|
252
|
+
for message in parse_sse_stream_chunked(response.iter_content()):
|
|
253
|
+
yield message
|
|
252
254
|
|
|
253
255
|
except requests.HTTPError as e:
|
|
254
256
|
self._handle_exception(e, request_data.url, request_data.correlation_id)
|
|
@@ -426,10 +428,11 @@ class AiriaClient(AiriaBaseClient):
|
|
|
426
428
|
api_version=api_version,
|
|
427
429
|
)
|
|
428
430
|
stream = async_output and api_version == ApiVersion.V2.value
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
431
|
+
resp = (
|
|
432
|
+
self._make_request_stream("POST", request_data)
|
|
433
|
+
if stream
|
|
434
|
+
else self._make_request("POST", request_data)
|
|
435
|
+
)
|
|
433
436
|
|
|
434
437
|
if not async_output:
|
|
435
438
|
if not debug:
|
|
@@ -450,3 +453,112 @@ class AiriaClient(AiriaBaseClient):
|
|
|
450
453
|
return PipelineExecutionV1StreamedResponse(**resp)
|
|
451
454
|
|
|
452
455
|
return PipelineExecutionV2StreamedResponse(stream=resp)
|
|
456
|
+
|
|
457
|
+
def get_active_pipelines_ids(
    self,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> List[str]:
    """
    Retrieve a list of active pipeline IDs.

    This method fetches the pipeline configurations from the Airia API and
    returns the pipeline ID of each item's active version. These IDs can be
    used with other methods like execute_pipeline() or get_pipeline_config().

    Args:
        correlation_id (str, optional): Unique identifier for request tracing
            and logging. If not provided, a new UUID will be automatically
            generated.
        api_version (str, optional): API version to use for the request.
            Must be one of the supported versions. Defaults to "v1".

    Returns:
        List[str]: The active pipeline IDs, in response order. Returns an
            empty list when the response's ``items`` field is missing, null
            or empty. Items without an ``activeVersion.pipelineId`` are
            skipped.

    Raises:
        ValueError: If the provided API version is not supported.
        AiriaAPIError: If the API request fails, including network errors,
            authentication failures, or server errors.

    Example:
        >>> client = AiriaClient(api_key="your_api_key")
        >>> pipeline_ids = client.get_active_pipelines_ids()
        >>> print(f"Found {len(pipeline_ids)} active pipelines")
        >>> for pipeline_id in pipeline_ids:
        ...     print(f"Pipeline ID: {pipeline_id}")
    """
    request_data = self._pre_get_active_pipelines_ids(
        correlation_id=correlation_id, api_version=api_version
    )
    resp = self._make_request("GET", request_data)

    # "items" may be absent or explicitly null; both mean "no pipelines".
    items = resp.get("items") or []

    pipeline_ids: List[str] = []
    for item in items:
        # A pipeline without an active version has no ID to report. Skip it
        # instead of failing the whole listing with KeyError/TypeError.
        active_version = item.get("activeVersion") or {}
        pipeline_id = active_version.get("pipelineId")
        if pipeline_id:
            pipeline_ids.append(pipeline_id)

    return pipeline_ids
|
|
503
|
+
|
|
504
|
+
def get_pipeline_config(
    self,
    pipeline_id: str,
    correlation_id: Optional[str] = None,
    api_version: str = ApiVersion.V1.value,
) -> GetPipelineConfigResponse:
    """
    Retrieve the exported configuration of a specific pipeline.

    Issues a GET request to the pipeline's ``PipelinesConfig/export`` endpoint
    and parses the JSON body into a GetPipelineConfigResponse. The response
    exposes the export metadata, the agent definition, and the prompts, tools,
    models, data sources, Python code blocks and routers the pipeline uses.

    Args:
        pipeline_id (str): The unique identifier of the pipeline to retrieve
            configuration for.
        correlation_id (str, optional): A unique identifier for request tracing
            and logging. If not provided, one will be automatically generated.
        api_version (str, optional): The API version to use for the request.
            Defaults to "v1". Valid versions are defined in ApiVersion enum.

    Returns:
        GetPipelineConfigResponse: The parsed pipeline configuration.

    Raises:
        ValueError: If the provided api_version is not valid.
        AiriaAPIError: If the API request fails, including cases where:
            - The pipeline_id doesn't exist (404)
            - Authentication fails (401)
            - Access is forbidden (403)
            - Server errors (5xx)

    Example:
        ```python
        from airia import AiriaClient

        client = AiriaClient(api_key="your_api_key")

        config = client.get_pipeline_config(pipeline_id="your_pipeline_id")

        print(f"Agent: {config.agent.name}")
        print(f"Description: {config.agent.agent_description}")
        print(f"Export version: {config.metadata.export_version}")
        print(f"Models: {[model.display_name for model in config.models or []]}")
        ```

    Note:
        This method only retrieves configuration information and does not
        execute the pipeline. Use execute_pipeline() to run the pipeline.
    """
    request_data = self._pre_get_pipeline_config(
        pipeline_id=pipeline_id,
        correlation_id=correlation_id,
        api_version=api_version,
    )
    resp = self._make_request("GET", request_data)

    # The payload uses camelCase keys; the model maps them through field aliases.
    return GetPipelineConfigResponse(**resp)
|
airia/types/__init__.py
CHANGED
|
@@ -1,12 +1,39 @@
|
|
|
1
|
-
from .
|
|
2
|
-
from .pipeline_execution import (
|
|
1
|
+
from .api.get_pipeline_config import GetPipelineConfigResponse
|
|
2
|
+
from .api.pipeline_execution import (
|
|
3
3
|
PipelineExecutionDebugResponse,
|
|
4
4
|
PipelineExecutionResponse,
|
|
5
5
|
PipelineExecutionV1StreamedResponse,
|
|
6
6
|
PipelineExecutionV2AsyncStreamedResponse,
|
|
7
7
|
PipelineExecutionV2StreamedResponse,
|
|
8
8
|
)
|
|
9
|
+
from .api_version import ApiVersion
|
|
9
10
|
from .request_data import RequestData
|
|
11
|
+
from .sse_messages import (
|
|
12
|
+
AgentAgentCardMessage,
|
|
13
|
+
AgentAgentCardStreamEndMessage,
|
|
14
|
+
AgentAgentCardStreamErrorMessage,
|
|
15
|
+
AgentAgentCardStreamFragmentMessage,
|
|
16
|
+
AgentAgentCardStreamStartMessage,
|
|
17
|
+
AgentDatasearchMessage,
|
|
18
|
+
AgentEndMessage,
|
|
19
|
+
AgentInvocationMessage,
|
|
20
|
+
AgentModelMessage,
|
|
21
|
+
AgentModelStreamEndMessage,
|
|
22
|
+
AgentModelStreamErrorMessage,
|
|
23
|
+
AgentModelStreamFragmentMessage,
|
|
24
|
+
AgentModelStreamStartMessage,
|
|
25
|
+
AgentModelStreamUsageMessage,
|
|
26
|
+
AgentOutputMessage,
|
|
27
|
+
AgentPingMessage,
|
|
28
|
+
AgentPythonCodeMessage,
|
|
29
|
+
AgentStartMessage,
|
|
30
|
+
AgentStepEndMessage,
|
|
31
|
+
AgentStepHaltMessage,
|
|
32
|
+
AgentStepStartMessage,
|
|
33
|
+
AgentToolActionMessage,
|
|
34
|
+
AgentToolRequestMessage,
|
|
35
|
+
AgentToolResponseMessage,
|
|
36
|
+
)
|
|
10
37
|
|
|
11
38
|
__all__ = [
|
|
12
39
|
"ApiVersion",
|
|
@@ -15,5 +42,30 @@ __all__ = [
|
|
|
15
42
|
"PipelineExecutionV1StreamedResponse",
|
|
16
43
|
"PipelineExecutionV2AsyncStreamedResponse",
|
|
17
44
|
"PipelineExecutionV2StreamedResponse",
|
|
45
|
+
"GetPipelineConfigResponse",
|
|
18
46
|
"RequestData",
|
|
47
|
+
"AgentPingMessage",
|
|
48
|
+
"AgentStartMessage",
|
|
49
|
+
"AgentEndMessage",
|
|
50
|
+
"AgentStepStartMessage",
|
|
51
|
+
"AgentStepHaltMessage",
|
|
52
|
+
"AgentStepEndMessage",
|
|
53
|
+
"AgentOutputMessage",
|
|
54
|
+
"AgentAgentCardMessage",
|
|
55
|
+
"AgentDatasearchMessage",
|
|
56
|
+
"AgentInvocationMessage",
|
|
57
|
+
"AgentModelMessage",
|
|
58
|
+
"AgentPythonCodeMessage",
|
|
59
|
+
"AgentToolActionMessage",
|
|
60
|
+
"AgentModelStreamStartMessage",
|
|
61
|
+
"AgentModelStreamEndMessage",
|
|
62
|
+
"AgentModelStreamErrorMessage",
|
|
63
|
+
"AgentModelStreamUsageMessage",
|
|
64
|
+
"AgentModelStreamFragmentMessage",
|
|
65
|
+
"AgentAgentCardStreamStartMessage",
|
|
66
|
+
"AgentAgentCardStreamErrorMessage",
|
|
67
|
+
"AgentAgentCardStreamFragmentMessage",
|
|
68
|
+
"AgentAgentCardStreamEndMessage",
|
|
69
|
+
"AgentToolRequestMessage",
|
|
70
|
+
"AgentToolResponseMessage",
|
|
19
71
|
]
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Export metadata of a pipeline configuration (the response's "metadata" object).
# Fields are populated from the camelCase JSON keys given as aliases.
class Metadata(BaseModel):
    id: str
    export_version: str = Field(alias="exportVersion")
    tagline: Optional[str] = Field(default=None)
    agent_description: Optional[str] = Field(default=None, alias="agentDescription")
    industry: Optional[str] = Field(default=None)
    tasks: Optional[str] = Field(default=None)
    credential_export_option: str = Field(alias="credentialExportOption")
    data_source_export_option: str = Field(alias="dataSourceExportOption")
    version_information: str = Field(alias="versionInformation")
    state: str


# The agent (pipeline) definition: descriptive fields plus its raw step list.
class Agent(BaseModel):
    name: str
    execution_name: str = Field(alias="executionName")
    agent_description: Optional[str] = Field(default=None, alias="agentDescription")
    video_link: Optional[str] = Field(default=None, alias="videoLink")
    industry: Optional[str] = Field(default=None)
    sub_industries: List[str] = Field(default_factory=list, alias="subIndustries")
    agent_details: Dict[str, Any] = Field(default_factory=dict, alias="agentDetails")
    id: str
    agent_icon: Optional[str] = Field(default=None, alias="agentIcon")
    # Steps are kept as untyped dicts; their schema is not modelled here.
    steps: List[Dict[str, Any]]


# One message of a prompt; ``order`` gives its position in the prompt.
class PromptMessage(BaseModel):
    text: str
    order: int


# A named prompt and its ordered messages.
class Prompt(BaseModel):
    name: str
    version_change_description: str = Field(alias="versionChangeDescription")
    prompt_message_list: List[PromptMessage] = Field(alias="promptMessageList")
    id: str


# A single key/value entry of a credential.
class CredentialData(BaseModel):
    key: str
    value: str


# Credential definition referenced by tools and models.
class CredentialsDefinition(BaseModel):
    name: str
    credential_type: str = Field(alias="credentialType")
    source_type: str = Field(alias="sourceType")
    credential_data_list: List[CredentialData] = Field(alias="credentialDataList")
    id: str


# A single HTTP header (key/value) sent by a tool.
class HeaderDefinition(BaseModel):
    key: str
    value: str


# A parameter accepted by a tool.
class ParameterDefinition(BaseModel):
    name: str
    parameter_type: str = Field(alias="parameterType")
    parameter_description: str = Field(alias="parameterDescription")
    default: str
    valid_options: List[str] = Field(default_factory=list, alias="validOptions")
    id: str


# An API tool available to the pipeline: endpoint, method, headers, body,
# parameters and credentials.
class Tool(BaseModel):
    tool_type: str = Field(alias="toolType")
    name: str
    standardized_name: str = Field(alias="standardizedName")
    tool_description: str = Field(alias="toolDescription")
    purpose: str
    api_endpoint: str = Field(alias="apiEndpoint")
    credentials_definition: CredentialsDefinition = Field(alias="credentialsDefinition")
    headers_definition: List[HeaderDefinition] = Field(alias="headersDefinition")
    body: str
    parameters_definition: List[ParameterDefinition] = Field(
        alias="parametersDefinition"
    )
    method_type: str = Field(alias="methodType")
    route_through_acc: bool = Field(alias="routeThroughACC")
    use_user_credentials: bool = Field(alias="useUserCredentials")
    use_user_credentials_type: str = Field(alias="useUserCredentialsType")
    id: str


# A model used by the pipeline, with deployment, pricing and credential settings.
class Model(BaseModel):
    # ``model_name`` starts with pydantic v2's protected "model_" prefix, which
    # makes pydantic < 2.10 emit a UserWarning at import. Clearing the
    # protected namespaces keeps the public field name and silences it.
    model_config = {"protected_namespaces": ()}

    id: str
    display_name: str = Field(alias="displayName")
    model_name: str = Field(alias="modelName")
    prompt_id: Optional[str] = Field(default=None, alias="promptId")
    system_prompt_definition: Optional[Any] = Field(
        default=None, alias="systemPromptDefinition"
    )
    url: str
    input_type: str = Field(alias="inputType")
    provider: str
    credentials_definition: CredentialsDefinition = Field(alias="credentialsDefinition")
    deployment_type: str = Field(alias="deploymentType")
    source_type: str = Field(alias="sourceType")
    connection_string: Optional[str] = Field(default=None, alias="connectionString")
    container_name: Optional[str] = Field(default=None, alias="containerName")
    deployed_key: Optional[str] = Field(default=None, alias="deployedKey")
    deployed_url: Optional[str] = Field(default=None, alias="deployedUrl")
    state: Optional[str] = Field(default=None)
    uploaded_container_id: Optional[str] = Field(
        default=None, alias="uploadedContainerId"
    )
    library_model_id: str = Field(alias="libraryModelId")
    # Prices are kept as the strings received from the API.
    input_token_price: str = Field(alias="inputTokenPrice")
    output_token_price: str = Field(alias="outputTokenPrice")
    token_units: int = Field(alias="tokenUnits")
    has_tool_support: bool = Field(alias="hasToolSupport")
    allow_airia_credentials: bool = Field(alias="allowAiriaCredentials")
    allow_byok_credentials: bool = Field(alias="allowBYOKCredentials")
    author: str
    price_type: str = Field(alias="priceType")


# A Python code block referenced by the pipeline.
class PythonCodeBlock(BaseModel):
    id: str
    code: str


# A router step configuration.
class Router(BaseModel):
    # ``model_id`` starts with the protected "model_" prefix; see Model.
    model_config = {"protected_namespaces": ()}

    id: str
    model_id: str = Field(alias="modelId")
    model: Optional[Any] = Field(default=None)
    router_config: Dict[str, Dict[str, Any]] = Field(alias="routerConfig")


# Top-level body returned by the pipeline-config export endpoint.
class GetPipelineConfigResponse(BaseModel):
    metadata: Metadata
    agent: Agent
    data_sources: Optional[List[Any]] = Field(default_factory=list, alias="dataSources")
    prompts: Optional[List[Prompt]] = Field(default_factory=list)
    tools: Optional[List[Tool]] = Field(default_factory=list)
    models: Optional[List[Model]] = Field(default_factory=list)
    memories: Optional[Any] = Field(default=None)
    python_code_blocks: Optional[List[PythonCodeBlock]] = Field(
        default_factory=list, alias="pythonCodeBlocks"
    )
    routers: Optional[List[Router]] = Field(default_factory=list)
    deployment: Optional[Any] = Field(default=None)
|
|
@@ -2,6 +2,8 @@ from typing import Any, AsyncIterator, Dict, Iterator
|
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel, ConfigDict
|
|
4
4
|
|
|
5
|
+
from ..sse_messages import SSEMessage
|
|
6
|
+
|
|
5
7
|
|
|
6
8
|
class PipelineExecutionResponse(BaseModel):
|
|
7
9
|
result: str
|
|
@@ -22,10 +24,10 @@ class PipelineExecutionV1StreamedResponse(BaseModel):
|
|
|
22
24
|
class PipelineExecutionV2StreamedResponse(BaseModel):
    # Result of a synchronous streamed v2 execution. ``stream`` yields parsed
    # SSE messages one at a time. arbitrary_types_allowed is enabled for the
    # iterator-typed field.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    stream: Iterator[SSEMessage]


class PipelineExecutionV2AsyncStreamedResponse(BaseModel):
    # Async counterpart of the class above. ``stream`` is an async iterator of
    # parsed SSE messages, consumed with ``async for``.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    stream: AsyncIterator[SSEMessage]
|
airia/types/request_data.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
|
1
|
-
from typing import Any, Dict
|
|
1
|
+
from typing import Any, Dict, Optional
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class RequestData(BaseModel):
    """Everything needed to issue one HTTP request to the Airia API."""

    url: str
    # JSON request body, or None for body-less requests such as the GET
    # endpoints. In pydantic v2 an Optional annotation alone is still a
    # *required* field, so an explicit None default is needed for callers
    # that build a request without a body.
    payload: Optional[Dict[str, Any]] = None
    headers: Dict[str, Any]
    correlation_id: str
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
from datetime import datetime, time
|
|
2
|
+
from enum import Enum
|
|
3
|
+
from typing import Optional, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MessageType(str, Enum):
|
|
9
|
+
AGENT_PING = "AgentPingMessage"
|
|
10
|
+
AGENT_START = "AgentStartMessage"
|
|
11
|
+
AGENT_END = "AgentEndMessage"
|
|
12
|
+
AGENT_STEP_START = "AgentStepStartMessage"
|
|
13
|
+
AGENT_STEP_HALT = "AgentStepHaltMessage"
|
|
14
|
+
AGENT_STEP_END = "AgentStepEndMessage"
|
|
15
|
+
AGENT_OUTPUT = "AgentOutputMessage"
|
|
16
|
+
AGENT_AGENT_CARD = "AgentAgentCardMessage"
|
|
17
|
+
AGENT_DATASEARCH = "AgentDatasearchMessage"
|
|
18
|
+
AGENT_INVOCATION = "AgentInvocationMessage"
|
|
19
|
+
AGENT_MODEL = "AgentModelMessage"
|
|
20
|
+
AGENT_PYTHON_CODE = "AgentPythonCodeMessage"
|
|
21
|
+
AGENT_TOOL_ACTION = "AgentToolActionMessage"
|
|
22
|
+
AGENT_MODEL_STREAM_START = "AgentModelStreamStartMessage"
|
|
23
|
+
AGENT_MODEL_STREAM_END = "AgentModelStreamEndMessage"
|
|
24
|
+
AGENT_MODEL_STREAM_ERROR = "AgentModelStreamErrorMessage"
|
|
25
|
+
AGENT_MODEL_STREAM_USAGE = "AgentModelStreamUsageMessage"
|
|
26
|
+
AGENT_MODEL_STREAM_FRAGMENT = "AgentModelStreamFragmentMessage"
|
|
27
|
+
MODEL_STREAM_FRAGMENT = "ModelStreamFragment"
|
|
28
|
+
AGENT_AGENT_CARD_STREAM_START = "AgentAgentCardStreamStartMessage"
|
|
29
|
+
AGENT_AGENT_CARD_STREAM_ERROR = "AgentAgentCardStreamErrorMessage"
|
|
30
|
+
AGENT_AGENT_CARD_STREAM_FRAGMENT = "AgentAgentCardStreamFragmentMessage"
|
|
31
|
+
AGENT_AGENT_CARD_STREAM_END = "AgentAgentCardStreamEndMessage"
|
|
32
|
+
AGENT_TOOL_REQUEST = "AgentToolRequestMessage"
|
|
33
|
+
AGENT_TOOL_RESPONSE = "AgentToolResponseMessage"
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class BaseSSEMessage(BaseModel):
|
|
37
|
+
model_config = ConfigDict(use_enum_values=True)
|
|
38
|
+
message_type: MessageType
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class AgentPingMessage(BaseSSEMessage):
|
|
42
|
+
message_type: MessageType = MessageType.AGENT_PING
|
|
43
|
+
timestamp: datetime
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
### Agent Messages ###
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class BaseAgentMessage(BaseSSEMessage):
|
|
50
|
+
agent_id: str
|
|
51
|
+
execution_id: str
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class AgentStartMessage(BaseAgentMessage):
|
|
55
|
+
message_type: MessageType = MessageType.AGENT_START
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class AgentEndMessage(BaseAgentMessage):
|
|
59
|
+
message_type: MessageType = MessageType.AGENT_END
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
### Step Messages ###


# Base for events scoped to a single step of an agent execution.
class BaseStepMessage(BaseAgentMessage):
    step_id: str
    step_type: str
    step_title: Optional[str] = None  # optional: may be missing from the payload


# A pipeline step has started; ``start_time`` is parsed into a datetime.
class AgentStepStartMessage(BaseStepMessage):
    message_type: MessageType = MessageType.AGENT_STEP_START
    start_time: datetime


# A step halted; ``approval_id`` presumably identifies the approval the step
# is waiting on — confirm against the API documentation.
class AgentStepHaltMessage(BaseStepMessage):
    message_type: MessageType = MessageType.AGENT_STEP_HALT
    approval_id: str


# A pipeline step has ended, with its end time, duration and status string.
class AgentStepEndMessage(BaseStepMessage):
    message_type: MessageType = MessageType.AGENT_STEP_END
    end_time: datetime
    # NOTE(review): ``time`` is presumably ``datetime.time`` (a time of day),
    # not a timedelta; a span of 24h or more could not be represented —
    # confirm the server's duration format.
    duration: time
    status: str


# Output produced by a step, delivered as a string.
class AgentOutputMessage(BaseStepMessage):
    message_type: MessageType = MessageType.AGENT_OUTPUT
    step_result: str
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
### Status Messages ###


# Base for status events. It adds nothing to BaseStepMessage; each subclass
# adds its event-specific fields.
class BaseStatusMessage(BaseStepMessage):
    pass


# AGENT_AGENT_CARD status event; carries a ``step_name``.
class AgentAgentCardMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD
    step_name: str


# AGENT_DATASEARCH status event; identifies a datastore by id, type and name.
class AgentDatasearchMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_DATASEARCH
    datastore_id: str
    datastore_type: str
    datastore_name: str


# AGENT_INVOCATION status event; carries an ``agent_name``.
class AgentInvocationMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_INVOCATION
    agent_name: str


# AGENT_MODEL status event; carries a ``model_name``.
class AgentModelMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_MODEL
    model_name: str


# AGENT_PYTHON_CODE status event; carries a ``step_name``.
class AgentPythonCodeMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_PYTHON_CODE
    step_name: str


# AGENT_TOOL_ACTION status event; carries a ``step_name`` and a ``tool_name``.
class AgentToolActionMessage(BaseStatusMessage):
    message_type: MessageType = MessageType.AGENT_TOOL_ACTION
    step_name: str
    tool_name: str
|
|
130
|
+
tool_name: str
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
### Model Stream Messages ###


# Base for LLM streaming events. It derives from BaseAgentMessage rather than
# BaseStepMessage, so it declares ``step_id`` itself and has no
# ``step_type``/``step_title``. ``stream_id`` presumably ties together the
# events of one stream — confirm against the API documentation.
class BaseModelStreamMessage(BaseAgentMessage):
    step_id: str
    stream_id: str


# LLM streaming has started; names the model producing the stream.
class AgentModelStreamStartMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_START
    model_name: str


# An error occurred on an LLM stream.
class AgentModelStreamErrorMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_ERROR
    error_message: str


# One chunk of LLM output; ``content`` may be absent or null.
class AgentModelStreamFragmentMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_FRAGMENT
    index: int
    content: Optional[str] = None


# LLM streaming has ended.
class AgentModelStreamEndMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_END
    content_id: str
    # NOTE(review): unit not visible here (seconds?); unlike the
    # ``duration: time`` fields on step/tool messages this one is a float.
    duration: Optional[float] = None


# Token usage and cost for an LLM stream; both fields are optional.
class AgentModelStreamUsageMessage(BaseModelStreamMessage):
    message_type: MessageType = MessageType.AGENT_MODEL_STREAM_USAGE
    token: Optional[int] = None
    tokens_cost: Optional[float] = None
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
### Agent Card Messages ###


# Base for agent-card streaming events. Like BaseModelStreamMessage it adds
# ``step_id`` and ``stream_id`` directly on top of BaseAgentMessage.
class BaseAgentAgentCardStreamMessage(BaseAgentMessage):
    step_id: str
    stream_id: str


# Start of an agent-card stream; may carry ``content``.
class AgentAgentCardStreamStartMessage(BaseAgentAgentCardStreamMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD_STREAM_START
    content: Optional[str] = None


# An error occurred on an agent-card stream.
class AgentAgentCardStreamErrorMessage(BaseAgentAgentCardStreamMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD_STREAM_ERROR
    error_message: str
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
# One fragment of an agent-card stream.
class AgentAgentCardStreamFragmentMessage(BaseAgentAgentCardStreamMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD_STREAM_FRAGMENT
    index: int
    # Defaults to None like AgentModelStreamFragmentMessage and the other
    # agent-card messages. Without a default, pydantic v2 treats an Optional
    # field as required, so a fragment omitting ``content`` would fail
    # validation and abort the stream.
    content: Optional[str] = None
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
# End of an agent-card stream; may carry ``content``.
class AgentAgentCardStreamEndMessage(BaseAgentAgentCardStreamMessage):
    message_type: MessageType = MessageType.AGENT_AGENT_CARD_STREAM_END
    content: Optional[str] = None
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
### Tool Messages ###


# Base for tool events: the step fields plus ``id`` and ``name`` (presumably
# identifying the tool call and the tool — confirm against the API docs).
class BaseAgentToolMessage(BaseStepMessage):
    id: str
    name: str


# AGENT_TOOL_REQUEST event; adds no fields beyond BaseAgentToolMessage.
class AgentToolRequestMessage(BaseAgentToolMessage):
    message_type: MessageType = MessageType.AGENT_TOOL_REQUEST


# AGENT_TOOL_RESPONSE event: how long the tool took and whether it succeeded.
class AgentToolResponseMessage(BaseAgentToolMessage):
    message_type: MessageType = MessageType.AGENT_TOOL_RESPONSE
    # NOTE(review): as in AgentStepEndMessage, ``time`` is presumably
    # ``datetime.time`` (a time of day) rather than a timedelta — confirm.
    duration: time
    success: bool
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
# Union of every concrete SSE message model. When adding a message type, add
# it both here and to SSEDict below so the two stay in sync.
SSEMessage = Union[
    AgentPingMessage,
    AgentStartMessage,
    AgentEndMessage,
    AgentStepStartMessage,
    AgentStepHaltMessage,
    AgentStepEndMessage,
    AgentOutputMessage,
    AgentAgentCardMessage,
    AgentDatasearchMessage,
    AgentInvocationMessage,
    AgentModelMessage,
    AgentPythonCodeMessage,
    AgentToolActionMessage,
    AgentModelStreamStartMessage,
    AgentModelStreamEndMessage,
    AgentModelStreamErrorMessage,
    AgentModelStreamUsageMessage,
    AgentModelStreamFragmentMessage,
    AgentAgentCardStreamStartMessage,
    AgentAgentCardStreamErrorMessage,
    AgentAgentCardStreamFragmentMessage,
    AgentAgentCardStreamEndMessage,
    AgentToolRequestMessage,
    AgentToolResponseMessage,
]
|
|
243
|
+
|
|
244
|
+
# Maps each SSE ``event:`` name (the MessageType value string) to the model
# class used to build that event's message from its JSON data. It must list
# the same classes as the SSEMessage union above.
SSEDict = {
    MessageType.AGENT_PING.value: AgentPingMessage,
    MessageType.AGENT_START.value: AgentStartMessage,
    MessageType.AGENT_END.value: AgentEndMessage,
    MessageType.AGENT_STEP_START.value: AgentStepStartMessage,
    MessageType.AGENT_STEP_HALT.value: AgentStepHaltMessage,
    MessageType.AGENT_STEP_END.value: AgentStepEndMessage,
    MessageType.AGENT_OUTPUT.value: AgentOutputMessage,
    MessageType.AGENT_AGENT_CARD.value: AgentAgentCardMessage,
    MessageType.AGENT_DATASEARCH.value: AgentDatasearchMessage,
    MessageType.AGENT_INVOCATION.value: AgentInvocationMessage,
    MessageType.AGENT_MODEL.value: AgentModelMessage,
    MessageType.AGENT_PYTHON_CODE.value: AgentPythonCodeMessage,
    MessageType.AGENT_TOOL_ACTION.value: AgentToolActionMessage,
    MessageType.AGENT_MODEL_STREAM_START.value: AgentModelStreamStartMessage,
    MessageType.AGENT_MODEL_STREAM_END.value: AgentModelStreamEndMessage,
    MessageType.AGENT_MODEL_STREAM_ERROR.value: AgentModelStreamErrorMessage,
    MessageType.AGENT_MODEL_STREAM_USAGE.value: AgentModelStreamUsageMessage,
    MessageType.AGENT_MODEL_STREAM_FRAGMENT.value: AgentModelStreamFragmentMessage,
    MessageType.AGENT_AGENT_CARD_STREAM_START.value: AgentAgentCardStreamStartMessage,
    MessageType.AGENT_AGENT_CARD_STREAM_ERROR.value: AgentAgentCardStreamErrorMessage,
    MessageType.AGENT_AGENT_CARD_STREAM_FRAGMENT.value: AgentAgentCardStreamFragmentMessage,
    MessageType.AGENT_AGENT_CARD_STREAM_END.value: AgentAgentCardStreamEndMessage,
    MessageType.AGENT_TOOL_REQUEST.value: AgentToolRequestMessage,
    MessageType.AGENT_TOOL_RESPONSE.value: AgentToolResponseMessage,
}
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import codecs
import json
import re
from typing import (
    AsyncIterable,
    AsyncIterator,
    Iterable,
    Iterator,
    List,
    Optional,
    Union,
)

from ..types.sse_messages import SSEDict, SSEMessage
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _to_snake_case(name: str) -> str:
    """Convert a camelCase/PascalCase JSON key to snake_case.

    An underscore is inserted before every uppercase letter except a leading
    one, then the result is lower-cased (``"agentId"`` -> ``"agent_id"``).
    """
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


class _SSEBlockSplitter:
    """Accumulate raw stream chunks and return complete SSE event blocks.

    An event block is the text between blank-line delimiters. Text after the
    last delimiter is buffered until more data arrives; an unterminated block
    still buffered when the stream ends is never returned.
    """

    def __init__(self) -> None:
        # The incremental decoder holds back a multi-byte UTF-8 sequence that
        # is split across two chunks instead of raising UnicodeDecodeError.
        self._decoder = codecs.getincrementaldecoder("utf-8")()
        self._buffer = ""

    def feed(self, chunk: Union[bytes, str]) -> List[str]:
        """Add *chunk* and return the event blocks it completed (possibly none)."""
        text = chunk if isinstance(chunk, str) else self._decoder.decode(chunk)
        # Normalise CRLF so "\r\n\r\n" delimiters are recognised. Applying it
        # to the whole buffer also catches a "\r" + "\n" pair split across chunks.
        self._buffer = (self._buffer + text).replace("\r\n", "\n")
        *blocks, self._buffer = self._buffer.split("\n\n")
        return blocks


def _parse_event_block(event_block: str) -> Optional["SSEMessage"]:
    """Build the typed message model for one SSE event block.

    Returns ``None``, so the caller skips the block, when:

    - the block has no ``event:`` name or no non-empty ``data:`` payload;
    - the event name has no entry in ``SSEDict``;
    - the payload is not a non-empty JSON object.

    Raises:
        json.JSONDecodeError: If the ``data:`` payload is not valid JSON.
        Any exception raised by the model constructor (e.g. validation errors)
        propagates unchanged.
    """
    event_name = None
    data_lines: List[str] = []
    for raw_line in event_block.strip().split("\n"):
        line = raw_line.strip()
        if line.startswith("event:"):
            event_name = line[len("event:"):].strip()
        elif line.startswith("data:"):
            data_lines.append(line[len("data:"):].strip())
        # Comment lines (":...") and other fields such as "id:" are ignored.

    # Per the SSE spec, several "data:" lines form one payload joined by "\n".
    payload = "\n".join(data_lines)
    if not event_name or not payload.strip():
        return None

    message_cls = SSEDict.get(event_name)
    if message_cls is None:
        # Event type this SDK version has no model for: skip it instead of
        # aborting the whole stream with a KeyError.
        return None

    event_data = json.loads(payload)
    if not isinstance(event_data, dict) or not event_data:
        return None

    # Convert keys to snake_case to match the model field names.
    return message_cls(**{_to_snake_case(k): v for k, v in event_data.items()})


def parse_sse_stream_chunked(
    stream_chunks: Iterable[Union[bytes, str]],
) -> Iterator["SSEMessage"]:
    """
    Parse an SSE stream delivered as an iterable of chunks (e.g. from a streaming response).

    Chunk boundaries may fall anywhere, including inside an event, a line, or
    a multi-byte UTF-8 character. Both LF and CRLF line endings are accepted.

    Args:
        stream_chunks: Iterable of raw UTF-8 ``bytes`` (or already-decoded ``str``) chunks.

    Yields:
        SSEMessage: The typed message model for each complete, recognised event.
        Events with an unknown name, no data, or a non-object payload are
        skipped; an unterminated event at the end of the stream is dropped.

    Raises:
        json.JSONDecodeError: If an event's data is not valid JSON.
        UnicodeDecodeError: If the byte stream is not valid UTF-8.
    """
    splitter = _SSEBlockSplitter()
    for chunk in stream_chunks:
        for event_block in splitter.feed(chunk):
            message = _parse_event_block(event_block)
            if message is not None:
                yield message


async def async_parse_sse_stream_chunked(
    stream_chunks: AsyncIterable[Union[bytes, str]],
) -> AsyncIterator["SSEMessage"]:
    """
    Parse an SSE stream delivered as an async iterable of chunks (e.g. from a streaming response).

    Async counterpart of ``parse_sse_stream_chunked``, with identical parsing
    rules.

    Args:
        stream_chunks: Async iterable of raw UTF-8 ``bytes`` (or ``str``) chunks.

    Yields:
        SSEMessage: The typed message model for each complete, recognised event.
        Events with an unknown name, no data, or a non-object payload are
        skipped; an unterminated event at the end of the stream is dropped.

    Raises:
        json.JSONDecodeError: If an event's data is not valid JSON.
        UnicodeDecodeError: If the byte stream is not valid UTF-8.
    """
    splitter = _SSEBlockSplitter()
    async for chunk in stream_chunks:
        for event_block in splitter.feed(chunk):
            message = _parse_event_block(event_block)
            if message is not None:
                yield message
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: airia
|
|
3
|
-
Version: 0.1.5
|
|
3
|
+
Version: 0.1.7
|
|
4
4
|
Summary: Python SDK for Airia API
|
|
5
5
|
Author-email: Airia LLC <support@airia.com>
|
|
6
6
|
License: MIT
|
|
@@ -182,7 +182,7 @@ This will create both wheel and source distribution in the `dist/` directory.
|
|
|
182
182
|
from airia import AiriaClient
|
|
183
183
|
|
|
184
184
|
client = AiriaClient(
|
|
185
|
-
base_url="https://api.airia.
|
|
185
|
+
base_url="https://api.airia.ai", # Default: "https://api.airia.ai"
|
|
186
186
|
api_key=None, # Or set AIRIA_API_KEY environment variable
|
|
187
187
|
timeout=30.0, # Request timeout in seconds (default: 30.0)
|
|
188
188
|
log_requests=False, # Enable request/response logging (default: False)
|
|
@@ -223,7 +223,7 @@ response = client.execute_pipeline(
|
|
|
223
223
|
)
|
|
224
224
|
|
|
225
225
|
for c in response.stream:
|
|
226
|
-
print(c
|
|
226
|
+
print(c)
|
|
227
227
|
```
|
|
228
228
|
|
|
229
229
|
### Asynchronous Usage
|
|
@@ -233,13 +233,12 @@ import asyncio
|
|
|
233
233
|
from airia import AiriaAsyncClient
|
|
234
234
|
|
|
235
235
|
async def main():
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
print(response.result)
|
|
236
|
+
client = AiriaAsyncClient(api_key="your_api_key")
|
|
237
|
+
response = await client.execute_pipeline(
|
|
238
|
+
pipeline_id="your_pipeline_id",
|
|
239
|
+
user_input="Tell me about quantum computing"
|
|
240
|
+
)
|
|
241
|
+
print(response.result)
|
|
243
242
|
|
|
244
243
|
asyncio.run(main())
|
|
245
244
|
```
|
|
@@ -251,19 +250,107 @@ import asyncio
|
|
|
251
250
|
from airia import AiriaAsyncClient
|
|
252
251
|
|
|
253
252
|
async def main():
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
print(c, end="")
|
|
253
|
+
client = AiriaAsyncClient(api_key="your_api_key")
|
|
254
|
+
response = await client.execute_pipeline(
|
|
255
|
+
pipeline_id="your_pipeline_id",
|
|
256
|
+
user_input="Tell me about quantum computing",
|
|
257
|
+
async_output=True
|
|
258
|
+
)
|
|
259
|
+
async for c in response.stream:
|
|
260
|
+
print(c)
|
|
263
261
|
|
|
264
262
|
asyncio.run(main())
|
|
265
263
|
```
|
|
266
264
|
|
|
265
|
+
## Streaming Event Parsing
|
|
266
|
+
|
|
267
|
+
When using streaming mode (`async_output=True`), the API returns Server-Sent Events (SSE) that contain different types of messages throughout the pipeline execution. You can parse and filter these events to extract specific information.
|
|
268
|
+
|
|
269
|
+
### Available Message Types
|
|
270
|
+
|
|
271
|
+
The streaming response includes various message types defined in `airia.types`. Here are the key ones:
|
|
272
|
+
|
|
273
|
+
- `AgentModelStreamFragmentMessage` - Contains actual LLM output chunks
|
|
274
|
+
- `AgentModelStreamStartMessage` - Indicates LLM streaming has started
|
|
275
|
+
- `AgentModelStreamEndMessage` - Indicates LLM streaming has ended
|
|
276
|
+
- `AgentStepStartMessage` - Indicates a pipeline step has started
|
|
277
|
+
- `AgentStepEndMessage` - Indicates a pipeline step has ended
|
|
278
|
+
- `AgentOutputMessage` - Contains step output
|
|
279
|
+
|
|
280
|
+
<details>
|
|
281
|
+
<summary>Click to expand the full list of message types</summary>
|
|
282
|
+
|
|
283
|
+
```python
|
|
284
|
+
[
|
|
285
|
+
AgentPingMessage,
|
|
286
|
+
AgentStartMessage,
|
|
287
|
+
AgentEndMessage,
|
|
288
|
+
AgentStepStartMessage,
|
|
289
|
+
AgentStepHaltMessage,
|
|
290
|
+
AgentStepEndMessage,
|
|
291
|
+
AgentOutputMessage,
|
|
292
|
+
AgentAgentCardMessage,
|
|
293
|
+
AgentDatasearchMessage,
|
|
294
|
+
AgentInvocationMessage,
|
|
295
|
+
AgentModelMessage,
|
|
296
|
+
AgentPythonCodeMessage,
|
|
297
|
+
AgentToolActionMessage,
|
|
298
|
+
AgentModelStreamStartMessage,
|
|
299
|
+
AgentModelStreamEndMessage,
|
|
300
|
+
AgentModelStreamErrorMessage,
|
|
301
|
+
AgentModelStreamUsageMessage,
|
|
302
|
+
AgentModelStreamFragmentMessage,
|
|
303
|
+
AgentAgentCardStreamStartMessage,
|
|
304
|
+
AgentAgentCardStreamErrorMessage,
|
|
305
|
+
AgentAgentCardStreamFragmentMessage,
|
|
306
|
+
AgentAgentCardStreamEndMessage,
|
|
307
|
+
AgentToolRequestMessage,
|
|
308
|
+
AgentToolResponseMessage,
|
|
309
|
+
]
|
|
310
|
+
```
|
|
311
|
+
|
|
312
|
+
</details>
|
|
313
|
+
|
|
314
|
+
### Filtering LLM Output
|
|
315
|
+
|
|
316
|
+
To extract only the actual LLM output text from the stream:
|
|
317
|
+
|
|
318
|
+
```python
|
|
319
|
+
from airia import AiriaClient
|
|
320
|
+
from airia.types import AgentModelStreamFragmentMessage
|
|
321
|
+
|
|
322
|
+
client = AiriaClient(api_key="your_api_key")
|
|
323
|
+
|
|
324
|
+
response = client.execute_pipeline(
|
|
325
|
+
pipeline_id="your_pipeline_id",
|
|
326
|
+
user_input="Tell me about quantum computing",
|
|
327
|
+
async_output=True
|
|
328
|
+
)
|
|
329
|
+
|
|
330
|
+
# Filter and display only LLM output
|
|
331
|
+
for event in response.stream:
|
|
332
|
+
if isinstance(event, AgentModelStreamFragmentMessage) and event.index != -1:
|
|
333
|
+
print(event.content, end="", flush=True)
|
|
334
|
+
```
|
|
335
|
+
|
|
336
|
+
## Pipeline Configuration Retrieval
|
|
337
|
+
|
|
338
|
+
You can retrieve detailed configuration information about a pipeline using the `get_pipeline_config` method:
|
|
339
|
+
|
|
340
|
+
> To get a list of all active pipeline ids, run the `get_active_pipelines_ids` method.
|
|
341
|
+
|
|
342
|
+
```python
|
|
343
|
+
from airia import AiriaClient
|
|
344
|
+
|
|
345
|
+
client = AiriaClient(api_key="your_api_key")
|
|
346
|
+
|
|
347
|
+
# Get pipeline configuration
|
|
348
|
+
config = client.get_pipeline_config(pipeline_id="your_pipeline_id")
|
|
349
|
+
|
|
350
|
+
# Access configuration details
|
|
351
|
+
print(f"Pipeline Name: {config.agent.name}")
|
|
352
|
+
```
|
|
353
|
+
|
|
267
354
|
## Gateway Usage
|
|
268
355
|
|
|
269
356
|
Airia provides gateway capabilities for popular AI services like OpenAI and Anthropic, allowing you to use your Airia API key with these services.
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
airia/__init__.py,sha256=T39gO8E5T5zxlw-JP78ruxOu7-LeKOJCJzz6t40kdQo,231
|
|
2
|
+
airia/exceptions.py,sha256=4Z55n-cRJrtTa5-pZBIK2oZD4-Z99aUtKx_kfTFYY5o,1146
|
|
3
|
+
airia/logs.py,sha256=17YZ4IuzOF0m5bgofj9-QYlJ2BYR2kRZbBVQfFSLFEk,5441
|
|
4
|
+
airia/client/__init__.py,sha256=6gSQ9bl7j79q1HPE0o5py3IRdkwWWuU_7J4h05Dd2o8,127
|
|
5
|
+
airia/client/async_client.py,sha256=QtnH1MNfF_ONx2YHsTz7Y4pFBTrggHWz0LLJiITREwc,23694
|
|
6
|
+
airia/client/base_client.py,sha256=YWhQle7cmtiuc5BwiavW-NYjTC6kXyclpse4irwkMuI,7969
|
|
7
|
+
airia/client/sync_client.py,sha256=iUqRn7Bv5hi-Qcg3Vgfc-Z1gx5T42f9Vq4pUTPsyG5M,22750
|
|
8
|
+
airia/types/__init__.py,sha256=s6_uMrrKzhk2wg1lpaRRIl4AQ8hUxKbCNLfz-VK7tAs,2166
|
|
9
|
+
airia/types/api_version.py,sha256=Uzom6O2ZG92HN_Z2h-lTydmO2XYX9RVs4Yi4DJmXytE,255
|
|
10
|
+
airia/types/request_data.py,sha256=m8lBwFliK2_kZU2TgeLdhiO_7v826lVRhA5_8NYB_NM,206
|
|
11
|
+
airia/types/sse_messages.py,sha256=wSdowY07AjEO8R73SJrFPJtkfIBS4satUpNytjKQq2U,8305
|
|
12
|
+
airia/types/api/get_pipeline_config.py,sha256=TKfqB705dYFdCPMNNgun35WcUx_XfYYOLp5sNhk0Pu8,5357
|
|
13
|
+
airia/types/api/pipeline_execution.py,sha256=4zGM5W4aXiAibAdgWQCIZBC_blW6QYAlAHoVUMbnCF0,752
|
|
14
|
+
airia/utils/sse_parser.py,sha256=h3TcBvXqUIngTqgY6yYxZCGLnC1eI6meQzYr13aFAb8,2791
|
|
15
|
+
airia-0.1.7.dist-info/licenses/LICENSE,sha256=R3ClUMMKPRItIcZ0svzyj2taZZnFYw568YDNzN9KQ1Q,1066
|
|
16
|
+
airia-0.1.7.dist-info/METADATA,sha256=c3-ME4h5ftSfH0as3VFsE5rkX6jBx6pbK2vXZx6R4eA,13218
|
|
17
|
+
airia-0.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
18
|
+
airia-0.1.7.dist-info/top_level.txt,sha256=qUQEKfs_hdOYTwjKj1JZbRhS5YeXDNaKQaVTrzabS6w,6
|
|
19
|
+
airia-0.1.7.dist-info/RECORD,,
|
airia-0.1.5.dist-info/RECORD
DELETED
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
airia/__init__.py,sha256=T39gO8E5T5zxlw-JP78ruxOu7-LeKOJCJzz6t40kdQo,231
|
|
2
|
-
airia/exceptions.py,sha256=4Z55n-cRJrtTa5-pZBIK2oZD4-Z99aUtKx_kfTFYY5o,1146
|
|
3
|
-
airia/logs.py,sha256=17YZ4IuzOF0m5bgofj9-QYlJ2BYR2kRZbBVQfFSLFEk,5441
|
|
4
|
-
airia/client/__init__.py,sha256=6gSQ9bl7j79q1HPE0o5py3IRdkwWWuU_7J4h05Dd2o8,127
|
|
5
|
-
airia/client/async_client.py,sha256=4JU6RePXjznulnQXkfo9ZVX3zLeUyclOgxwlaM7-5Js,18816
|
|
6
|
-
airia/client/base_client.py,sha256=IEVsskGpiEL3jw7YLc96YWUZ5rpUZula52IZHe2D8mE,6844
|
|
7
|
-
airia/client/sync_client.py,sha256=mCLViWkW-z8QhCy8CmlEsT8dssBb1g6V2B3fAYgYU1A,18336
|
|
8
|
-
airia/types/__init__.py,sha256=OOFtJ0VIbO98px89u7cq645iL7CYDOel43g85IAjRRw,562
|
|
9
|
-
airia/types/api_version.py,sha256=Uzom6O2ZG92HN_Z2h-lTydmO2XYX9RVs4Yi4DJmXytE,255
|
|
10
|
-
airia/types/pipeline_execution.py,sha256=obp8KOz-ShNxEciF1YQCSpHXPoJUyEa_9uPv-VHf6YA,699
|
|
11
|
-
airia/types/request_data.py,sha256=RbVPWPRIYuT7FaolJKn19IVzuQKbLM4VhXM5r7UXTUY,186
|
|
12
|
-
airia-0.1.5.dist-info/licenses/LICENSE,sha256=R3ClUMMKPRItIcZ0svzyj2taZZnFYw568YDNzN9KQ1Q,1066
|
|
13
|
-
airia-0.1.5.dist-info/METADATA,sha256=WIPLholUrToVlvO4K269qOSYHT4fD7XHpVSrjowOF6A,10745
|
|
14
|
-
airia-0.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
-
airia-0.1.5.dist-info/top_level.txt,sha256=qUQEKfs_hdOYTwjKj1JZbRhS5YeXDNaKQaVTrzabS6w,6
|
|
16
|
-
airia-0.1.5.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|