airia 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airia/client/async_client.py +169 -98
- airia/client/base_client.py +77 -16
- airia/client/sync_client.py +165 -97
- airia/constants.py +9 -0
- airia/types/api/__init__.py +19 -0
- airia/types/api/conversations.py +14 -0
- airia/types/api/pipeline_execution.py +6 -10
- airia/types/{__init__.py → sse/__init__.py} +4 -20
- airia/utils/sse_parser.py +1 -1
- {airia-0.1.9.dist-info → airia-0.1.11.dist-info}/METADATA +182 -8
- airia-0.1.11.dist-info/RECORD +23 -0
- airia-0.1.9.dist-info/RECORD +0 -20
- /airia/types/{api_version.py → _api_version.py} +0 -0
- /airia/types/{request_data.py → _request_data.py} +0 -0
- /airia/types/{sse_messages.py → sse/sse_messages.py} +0 -0
- {airia-0.1.9.dist-info → airia-0.1.11.dist-info}/WHEEL +0 -0
- {airia-0.1.9.dist-info → airia-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {airia-0.1.9.dist-info → airia-0.1.11.dist-info}/top_level.txt +0 -0
airia/client/sync_client.py
CHANGED
|
@@ -1,19 +1,24 @@
|
|
|
1
1
|
from typing import Any, Dict, List, Literal, Optional, overload
|
|
2
|
-
from urllib.parse import urljoin
|
|
3
2
|
|
|
4
3
|
import loguru
|
|
5
4
|
import requests
|
|
6
5
|
|
|
6
|
+
from ..constants import (
|
|
7
|
+
DEFAULT_ANTHROPIC_GATEWAY_URL,
|
|
8
|
+
DEFAULT_BASE_URL,
|
|
9
|
+
DEFAULT_OPENAI_GATEWAY_URL,
|
|
10
|
+
DEFAULT_TIMEOUT,
|
|
11
|
+
)
|
|
7
12
|
from ..exceptions import AiriaAPIError
|
|
8
|
-
from ..types import
|
|
9
|
-
|
|
13
|
+
from ..types._api_version import ApiVersion
|
|
14
|
+
from ..types._request_data import RequestData
|
|
15
|
+
from ..types.api import (
|
|
16
|
+
CreateConversationResponse,
|
|
10
17
|
GetPipelineConfigResponse,
|
|
11
18
|
PipelineExecutionDebugResponse,
|
|
12
19
|
PipelineExecutionResponse,
|
|
13
|
-
|
|
14
|
-
PipelineExecutionV2StreamedResponse,
|
|
20
|
+
PipelineExecutionStreamedResponse,
|
|
15
21
|
ProjectItem,
|
|
16
|
-
RequestData,
|
|
17
22
|
)
|
|
18
23
|
from ..utils.sse_parser import parse_sse_stream_chunked
|
|
19
24
|
from .base_client import AiriaBaseClient
|
|
@@ -24,9 +29,10 @@ class AiriaClient(AiriaBaseClient):
|
|
|
24
29
|
|
|
25
30
|
def __init__(
|
|
26
31
|
self,
|
|
27
|
-
base_url: str =
|
|
32
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
28
33
|
api_key: Optional[str] = None,
|
|
29
|
-
|
|
34
|
+
bearer_token: Optional[str] = None,
|
|
35
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
30
36
|
log_requests: bool = False,
|
|
31
37
|
custom_logger: Optional["loguru.Logger"] = None,
|
|
32
38
|
):
|
|
@@ -36,6 +42,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
36
42
|
Args:
|
|
37
43
|
base_url: Base URL of the Airia API.
|
|
38
44
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
45
|
+
bearer_token: Bearer token for authentication. Must be provided explicitly (no environment variable fallback).
|
|
39
46
|
timeout: Request timeout in seconds.
|
|
40
47
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
41
48
|
custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
|
|
@@ -43,6 +50,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
43
50
|
super().__init__(
|
|
44
51
|
base_url=base_url,
|
|
45
52
|
api_key=api_key,
|
|
53
|
+
bearer_token=bearer_token,
|
|
46
54
|
timeout=timeout,
|
|
47
55
|
log_requests=log_requests,
|
|
48
56
|
custom_logger=custom_logger,
|
|
@@ -55,10 +63,10 @@ class AiriaClient(AiriaBaseClient):
|
|
|
55
63
|
@classmethod
|
|
56
64
|
def with_openai_gateway(
|
|
57
65
|
cls,
|
|
58
|
-
base_url: str =
|
|
59
|
-
gateway_url: str =
|
|
66
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
67
|
+
gateway_url: str = DEFAULT_OPENAI_GATEWAY_URL,
|
|
60
68
|
api_key: Optional[str] = None,
|
|
61
|
-
timeout: float =
|
|
69
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
62
70
|
log_requests: bool = False,
|
|
63
71
|
custom_logger: Optional["loguru.Logger"] = None,
|
|
64
72
|
**kwargs,
|
|
@@ -77,22 +85,28 @@ class AiriaClient(AiriaBaseClient):
|
|
|
77
85
|
"""
|
|
78
86
|
from openai import OpenAI
|
|
79
87
|
|
|
80
|
-
|
|
81
|
-
|
|
88
|
+
client = cls(
|
|
89
|
+
base_url=base_url,
|
|
82
90
|
api_key=api_key,
|
|
91
|
+
timeout=timeout,
|
|
92
|
+
log_requests=log_requests,
|
|
93
|
+
custom_logger=custom_logger,
|
|
94
|
+
)
|
|
95
|
+
cls.openai = OpenAI(
|
|
96
|
+
api_key=client.api_key,
|
|
83
97
|
base_url=gateway_url,
|
|
84
98
|
**kwargs,
|
|
85
99
|
)
|
|
86
100
|
|
|
87
|
-
return
|
|
101
|
+
return client
|
|
88
102
|
|
|
89
103
|
@classmethod
|
|
90
104
|
def with_anthropic_gateway(
|
|
91
105
|
cls,
|
|
92
|
-
base_url: str =
|
|
93
|
-
gateway_url: str =
|
|
106
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
107
|
+
gateway_url: str = DEFAULT_ANTHROPIC_GATEWAY_URL,
|
|
94
108
|
api_key: Optional[str] = None,
|
|
95
|
-
timeout: float =
|
|
109
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
96
110
|
log_requests: bool = False,
|
|
97
111
|
custom_logger: Optional["loguru.Logger"] = None,
|
|
98
112
|
**kwargs,
|
|
@@ -111,14 +125,47 @@ class AiriaClient(AiriaBaseClient):
|
|
|
111
125
|
"""
|
|
112
126
|
from anthropic import Anthropic
|
|
113
127
|
|
|
114
|
-
|
|
115
|
-
|
|
128
|
+
client = cls(
|
|
129
|
+
base_url=base_url,
|
|
116
130
|
api_key=api_key,
|
|
131
|
+
timeout=timeout,
|
|
132
|
+
log_requests=log_requests,
|
|
133
|
+
custom_logger=custom_logger,
|
|
134
|
+
)
|
|
135
|
+
cls.anthropic = Anthropic(
|
|
136
|
+
api_key=client.api_key,
|
|
117
137
|
base_url=gateway_url,
|
|
118
138
|
**kwargs,
|
|
119
139
|
)
|
|
120
140
|
|
|
121
|
-
return
|
|
141
|
+
return client
|
|
142
|
+
|
|
143
|
+
@classmethod
|
|
144
|
+
def with_bearer_token(
|
|
145
|
+
cls,
|
|
146
|
+
bearer_token: str,
|
|
147
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
148
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
149
|
+
log_requests: bool = False,
|
|
150
|
+
custom_logger: Optional["loguru.Logger"] = None,
|
|
151
|
+
):
|
|
152
|
+
"""
|
|
153
|
+
Initialize the synchronous Airia API client with bearer token authentication.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
bearer_token: Bearer token for authentication.
|
|
157
|
+
base_url: Base URL of the Airia API.
|
|
158
|
+
timeout: Request timeout in seconds.
|
|
159
|
+
log_requests: Whether to log API requests and responses. Default is False.
|
|
160
|
+
custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
|
|
161
|
+
"""
|
|
162
|
+
return cls(
|
|
163
|
+
base_url=base_url,
|
|
164
|
+
bearer_token=bearer_token,
|
|
165
|
+
timeout=timeout,
|
|
166
|
+
log_requests=log_requests,
|
|
167
|
+
custom_logger=custom_logger,
|
|
168
|
+
)
|
|
122
169
|
|
|
123
170
|
def _handle_exception(self, e: requests.HTTPError, url: str, correlation_id: str):
|
|
124
171
|
# Log the error response if enabled
|
|
@@ -141,12 +188,14 @@ class AiriaClient(AiriaBaseClient):
|
|
|
141
188
|
# If JSON parsing fails or expected keys are missing
|
|
142
189
|
error_message = f"API request failed: {str(e)}"
|
|
143
190
|
|
|
144
|
-
# Make sure
|
|
145
|
-
sanitized_message =
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
191
|
+
# Make sure sensitive auth information is not included in error messages
|
|
192
|
+
sanitized_message = error_message
|
|
193
|
+
if self.api_key and self.api_key in sanitized_message:
|
|
194
|
+
sanitized_message = sanitized_message.replace(self.api_key, "[REDACTED]")
|
|
195
|
+
if self.bearer_token and self.bearer_token in sanitized_message:
|
|
196
|
+
sanitized_message = sanitized_message.replace(
|
|
197
|
+
self.bearer_token, "[REDACTED]"
|
|
198
|
+
)
|
|
150
199
|
|
|
151
200
|
# Raise custom exception with status code and sanitized message
|
|
152
201
|
raise AiriaAPIError(
|
|
@@ -278,7 +327,6 @@ class AiriaClient(AiriaBaseClient):
|
|
|
278
327
|
additional_info: Optional[List[Any]] = None,
|
|
279
328
|
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
280
329
|
correlation_id: Optional[str] = None,
|
|
281
|
-
api_version: str = ApiVersion.V2.value,
|
|
282
330
|
) -> PipelineExecutionResponse: ...
|
|
283
331
|
|
|
284
332
|
@overload
|
|
@@ -301,7 +349,6 @@ class AiriaClient(AiriaBaseClient):
|
|
|
301
349
|
additional_info: Optional[List[Any]] = None,
|
|
302
350
|
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
303
351
|
correlation_id: Optional[str] = None,
|
|
304
|
-
api_version: str = ApiVersion.V2.value,
|
|
305
352
|
) -> PipelineExecutionDebugResponse: ...
|
|
306
353
|
|
|
307
354
|
@overload
|
|
@@ -324,31 +371,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
324
371
|
additional_info: Optional[List[Any]] = None,
|
|
325
372
|
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
326
373
|
correlation_id: Optional[str] = None,
|
|
327
|
-
|
|
328
|
-
) -> PipelineExecutionV2StreamedResponse: ...
|
|
329
|
-
|
|
330
|
-
@overload
|
|
331
|
-
def execute_pipeline(
|
|
332
|
-
self,
|
|
333
|
-
pipeline_id: str,
|
|
334
|
-
user_input: str,
|
|
335
|
-
debug: bool = False,
|
|
336
|
-
user_id: Optional[str] = None,
|
|
337
|
-
conversation_id: Optional[str] = None,
|
|
338
|
-
async_output: Literal[True] = True,
|
|
339
|
-
include_tools_response: bool = False,
|
|
340
|
-
images: Optional[List[str]] = None,
|
|
341
|
-
files: Optional[List[str]] = None,
|
|
342
|
-
data_source_folders: Optional[Dict[str, Any]] = None,
|
|
343
|
-
data_source_files: Optional[Dict[str, Any]] = None,
|
|
344
|
-
in_memory_messages: Optional[List[Dict[str, str]]] = None,
|
|
345
|
-
current_date_time: Optional[str] = None,
|
|
346
|
-
save_history: bool = True,
|
|
347
|
-
additional_info: Optional[List[Any]] = None,
|
|
348
|
-
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
349
|
-
correlation_id: Optional[str] = None,
|
|
350
|
-
api_version: Literal["v1"] = ApiVersion.V1.value,
|
|
351
|
-
) -> PipelineExecutionV1StreamedResponse: ...
|
|
374
|
+
) -> PipelineExecutionStreamedResponse: ...
|
|
352
375
|
|
|
353
376
|
def execute_pipeline(
|
|
354
377
|
self,
|
|
@@ -369,7 +392,6 @@ class AiriaClient(AiriaBaseClient):
|
|
|
369
392
|
additional_info: Optional[List[Any]] = None,
|
|
370
393
|
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
371
394
|
correlation_id: Optional[str] = None,
|
|
372
|
-
api_version: str = ApiVersion.V2.value,
|
|
373
395
|
):
|
|
374
396
|
"""
|
|
375
397
|
Execute a pipeline with the provided input.
|
|
@@ -393,7 +415,6 @@ class AiriaClient(AiriaBaseClient):
|
|
|
393
415
|
prompt_variables: Optional variables to be used in the prompt.
|
|
394
416
|
correlation_id: Optional correlation ID for request tracing. If not provided,
|
|
395
417
|
one will be generated automatically.
|
|
396
|
-
api_version: API version to use. Default is `v2`
|
|
397
418
|
|
|
398
419
|
Returns:
|
|
399
420
|
The API response as a dictionary.
|
|
@@ -428,12 +449,11 @@ class AiriaClient(AiriaBaseClient):
|
|
|
428
449
|
additional_info=additional_info,
|
|
429
450
|
prompt_variables=prompt_variables,
|
|
430
451
|
correlation_id=correlation_id,
|
|
431
|
-
api_version=
|
|
452
|
+
api_version=ApiVersion.V2.value,
|
|
432
453
|
)
|
|
433
|
-
stream = async_output and api_version == ApiVersion.V2.value
|
|
434
454
|
resp = (
|
|
435
455
|
self._make_request_stream("POST", request_data)
|
|
436
|
-
if
|
|
456
|
+
if async_output
|
|
437
457
|
else self._make_request("POST", request_data)
|
|
438
458
|
)
|
|
439
459
|
|
|
@@ -442,26 +462,9 @@ class AiriaClient(AiriaBaseClient):
|
|
|
442
462
|
return PipelineExecutionResponse(**resp)
|
|
443
463
|
return PipelineExecutionDebugResponse(**resp)
|
|
444
464
|
|
|
445
|
-
|
|
446
|
-
url = urljoin(
|
|
447
|
-
self.base_url, f"{api_version}/StreamSocketConfig/GenerateUrl"
|
|
448
|
-
)
|
|
449
|
-
request_data = self._prepare_request(
|
|
450
|
-
url,
|
|
451
|
-
payload={"socketIdentifier": resp},
|
|
452
|
-
correlation_id=request_data.headers["X-Correlation-ID"],
|
|
453
|
-
)
|
|
454
|
-
resp = self._make_request("POST", request_data)
|
|
455
|
-
|
|
456
|
-
return PipelineExecutionV1StreamedResponse(**resp)
|
|
457
|
-
|
|
458
|
-
return PipelineExecutionV2StreamedResponse(stream=resp)
|
|
465
|
+
return PipelineExecutionStreamedResponse(stream=resp)
|
|
459
466
|
|
|
460
|
-
def get_projects(
|
|
461
|
-
self,
|
|
462
|
-
correlation_id: Optional[str] = None,
|
|
463
|
-
api_version: str = ApiVersion.V1.value,
|
|
464
|
-
) -> List[ProjectItem]:
|
|
467
|
+
def get_projects(self, correlation_id: Optional[str] = None) -> List[ProjectItem]:
|
|
465
468
|
"""
|
|
466
469
|
Retrieve a list of all projects accessible to the authenticated user.
|
|
467
470
|
|
|
@@ -472,8 +475,6 @@ class AiriaClient(AiriaBaseClient):
|
|
|
472
475
|
Args:
|
|
473
476
|
correlation_id (str, optional): A unique identifier for request tracing
|
|
474
477
|
and logging. If not provided, one will be automatically generated.
|
|
475
|
-
api_version (str, optional): The API version to use for the request.
|
|
476
|
-
Defaults to "v1". Valid versions are defined in ApiVersion enum.
|
|
477
478
|
|
|
478
479
|
Returns:
|
|
479
480
|
List[ProjectItem]: A list of ProjectItem objects containing project
|
|
@@ -481,7 +482,6 @@ class AiriaClient(AiriaBaseClient):
|
|
|
481
482
|
or found.
|
|
482
483
|
|
|
483
484
|
Raises:
|
|
484
|
-
ValueError: If the provided api_version is not valid.
|
|
485
485
|
AiriaAPIError: If the API request fails, including cases where:
|
|
486
486
|
- Authentication fails (401)
|
|
487
487
|
- Access is forbidden (403)
|
|
@@ -510,7 +510,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
510
510
|
access to.
|
|
511
511
|
"""
|
|
512
512
|
request_data = self._pre_get_projects(
|
|
513
|
-
correlation_id=correlation_id, api_version=
|
|
513
|
+
correlation_id=correlation_id, api_version=ApiVersion.V1.value
|
|
514
514
|
)
|
|
515
515
|
resp = self._make_request("GET", request_data)
|
|
516
516
|
|
|
@@ -520,10 +520,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
520
520
|
return [ProjectItem(**item) for item in resp["items"]]
|
|
521
521
|
|
|
522
522
|
def get_active_pipelines_ids(
|
|
523
|
-
self,
|
|
524
|
-
project_id: Optional[str] = None,
|
|
525
|
-
correlation_id: Optional[str] = None,
|
|
526
|
-
api_version: str = ApiVersion.V1.value,
|
|
523
|
+
self, project_id: Optional[str] = None, correlation_id: Optional[str] = None
|
|
527
524
|
) -> List[str]:
|
|
528
525
|
"""
|
|
529
526
|
Retrieve a list of active pipeline IDs.
|
|
@@ -537,15 +534,12 @@ class AiriaClient(AiriaBaseClient):
|
|
|
537
534
|
accessible to the authenticated user.
|
|
538
535
|
correlation_id (str, optional): A unique identifier for request tracing
|
|
539
536
|
and logging. If not provided, one will be automatically generated.
|
|
540
|
-
api_version (str, optional): The API version to use for the request.
|
|
541
|
-
Defaults to "v1". Valid versions are defined in ApiVersion enum.
|
|
542
537
|
|
|
543
538
|
Returns:
|
|
544
539
|
List[str]: A list of pipeline IDs that are currently active. Returns an
|
|
545
540
|
empty list if no active pipelines are found.
|
|
546
541
|
|
|
547
542
|
Raises:
|
|
548
|
-
ValueError: If the provided api_version is not valid.
|
|
549
543
|
AiriaAPIError: If the API request fails, including cases where:
|
|
550
544
|
- The project_id doesn't exist (404)
|
|
551
545
|
- Authentication fails (401)
|
|
@@ -576,7 +570,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
576
570
|
request_data = self._pre_get_active_pipelines_ids(
|
|
577
571
|
project_id=project_id,
|
|
578
572
|
correlation_id=correlation_id,
|
|
579
|
-
api_version=
|
|
573
|
+
api_version=ApiVersion.V1.value,
|
|
580
574
|
)
|
|
581
575
|
resp = self._make_request("GET", request_data)
|
|
582
576
|
|
|
@@ -588,10 +582,7 @@ class AiriaClient(AiriaBaseClient):
|
|
|
588
582
|
return pipeline_ids
|
|
589
583
|
|
|
590
584
|
def get_pipeline_config(
|
|
591
|
-
self,
|
|
592
|
-
pipeline_id: str,
|
|
593
|
-
correlation_id: Optional[str] = None,
|
|
594
|
-
api_version: str = ApiVersion.V1.value,
|
|
585
|
+
self, pipeline_id: str, correlation_id: Optional[str] = None
|
|
595
586
|
) -> GetPipelineConfigResponse:
|
|
596
587
|
"""
|
|
597
588
|
Retrieve configuration details for a specific pipeline.
|
|
@@ -602,8 +593,6 @@ class AiriaClient(AiriaBaseClient):
|
|
|
602
593
|
Args:
|
|
603
594
|
pipeline_id (str): The unique identifier of the pipeline to retrieve
|
|
604
595
|
configuration for.
|
|
605
|
-
api_version (str, optional): The API version to use for the request.
|
|
606
|
-
Defaults to "v1". Valid versions are defined in ApiVersion enum.
|
|
607
596
|
correlation_id (str, optional): A unique identifier for request tracing
|
|
608
597
|
and logging. If not provided, one will be automatically generated.
|
|
609
598
|
|
|
@@ -612,7 +601,6 @@ class AiriaClient(AiriaBaseClient):
|
|
|
612
601
|
configuration.
|
|
613
602
|
|
|
614
603
|
Raises:
|
|
615
|
-
ValueError: If the provided api_version is not valid.
|
|
616
604
|
AiriaAPIError: If the API request fails, including cases where:
|
|
617
605
|
- The pipeline_id doesn't exist (404)
|
|
618
606
|
- Authentication fails (401)
|
|
@@ -643,8 +631,88 @@ class AiriaClient(AiriaBaseClient):
|
|
|
643
631
|
request_data = self._pre_get_pipeline_config(
|
|
644
632
|
pipeline_id=pipeline_id,
|
|
645
633
|
correlation_id=correlation_id,
|
|
646
|
-
api_version=
|
|
634
|
+
api_version=ApiVersion.V1.value,
|
|
647
635
|
)
|
|
648
636
|
resp = self._make_request("GET", request_data)
|
|
649
637
|
|
|
650
638
|
return GetPipelineConfigResponse(**resp)
|
|
639
|
+
|
|
640
|
+
def create_conversation(
|
|
641
|
+
self,
|
|
642
|
+
user_id: str,
|
|
643
|
+
title: Optional[str] = None,
|
|
644
|
+
deployment_id: Optional[str] = None,
|
|
645
|
+
data_source_files: Dict[str, Any] = {},
|
|
646
|
+
is_bookmarked: bool = False,
|
|
647
|
+
correlation_id: Optional[str] = None,
|
|
648
|
+
) -> CreateConversationResponse:
|
|
649
|
+
"""
|
|
650
|
+
Create a new conversation.
|
|
651
|
+
|
|
652
|
+
Args:
|
|
653
|
+
user_id (str): The unique identifier of the user creating the conversation.
|
|
654
|
+
title (str, optional): The title for the conversation. If not provided,
|
|
655
|
+
the conversation will be created without a title.
|
|
656
|
+
deployment_id (str, optional): The unique identifier of the deployment
|
|
657
|
+
to associate with the conversation. If not provided, the conversation
|
|
658
|
+
will not be associated with any specific deployment.
|
|
659
|
+
data_source_files (dict): Configuration for data source files
|
|
660
|
+
to be associated with the conversation. If not provided, no data
|
|
661
|
+
source files will be associated.
|
|
662
|
+
is_bookmarked (bool): Whether the conversation should be bookmarked.
|
|
663
|
+
Defaults to False.
|
|
664
|
+
correlation_id (str, optional): A unique identifier for request tracing
|
|
665
|
+
and logging. If not provided, one will be automatically generated.
|
|
666
|
+
|
|
667
|
+
Returns:
|
|
668
|
+
CreateConversationResponse: A response object containing the created
|
|
669
|
+
conversation details including its ID, creation timestamp, and
|
|
670
|
+
all provided parameters.
|
|
671
|
+
|
|
672
|
+
Raises:
|
|
673
|
+
AiriaAPIError: If the API request fails, including cases where:
|
|
674
|
+
- The user_id doesn't exist (404)
|
|
675
|
+
- The deployment_id is invalid (404)
|
|
676
|
+
- Authentication fails (401)
|
|
677
|
+
- Access is forbidden (403)
|
|
678
|
+
- Server errors (5xx)
|
|
679
|
+
|
|
680
|
+
Example:
|
|
681
|
+
```python
|
|
682
|
+
from airia import AiriaClient
|
|
683
|
+
|
|
684
|
+
client = AiriaClient(api_key="your_api_key")
|
|
685
|
+
|
|
686
|
+
# Create a basic conversation
|
|
687
|
+
conversation = client.create_conversation(
|
|
688
|
+
user_id="user_123"
|
|
689
|
+
)
|
|
690
|
+
print(f"Created conversation: {conversation.id}")
|
|
691
|
+
|
|
692
|
+
# Create a conversation with all options
|
|
693
|
+
conversation = client.create_conversation(
|
|
694
|
+
user_id="user_123",
|
|
695
|
+
title="My Research Session",
|
|
696
|
+
deployment_id="deployment_456",
|
|
697
|
+
data_source_files={"documents": ["doc1.pdf", "doc2.txt"]},
|
|
698
|
+
is_bookmarked=True
|
|
699
|
+
)
|
|
700
|
+
print(f"Created bookmarked conversation: {conversation.id}")
|
|
701
|
+
```
|
|
702
|
+
|
|
703
|
+
Note:
|
|
704
|
+
The user_id is required and must correspond to a valid user in the system.
|
|
705
|
+
All other parameters are optional and can be set to None or their default values.
|
|
706
|
+
"""
|
|
707
|
+
request_data = self._pre_create_conversation(
|
|
708
|
+
user_id=user_id,
|
|
709
|
+
title=title,
|
|
710
|
+
deployment_id=deployment_id,
|
|
711
|
+
data_source_files=data_source_files,
|
|
712
|
+
is_bookmarked=is_bookmarked,
|
|
713
|
+
correlation_id=correlation_id,
|
|
714
|
+
api_version=ApiVersion.V1.value,
|
|
715
|
+
)
|
|
716
|
+
resp = self._make_request("POST", request_data)
|
|
717
|
+
|
|
718
|
+
return CreateConversationResponse(**resp)
|
airia/constants.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Constants used throughout the Airia SDK."""
|
|
2
|
+
|
|
3
|
+
# Default API endpoints
|
|
4
|
+
DEFAULT_BASE_URL = "https://api.airia.ai/"
|
|
5
|
+
DEFAULT_OPENAI_GATEWAY_URL = "https://gateway.airia.ai/openai/v1"
|
|
6
|
+
DEFAULT_ANTHROPIC_GATEWAY_URL = "https://gateway.airia.ai/anthropic"
|
|
7
|
+
|
|
8
|
+
# Default timeouts
|
|
9
|
+
DEFAULT_TIMEOUT = 30.0
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .get_projects import ProjectItem
|
|
2
|
+
from .get_pipeline_config import GetPipelineConfigResponse
|
|
3
|
+
from .pipeline_execution import (
|
|
4
|
+
PipelineExecutionDebugResponse,
|
|
5
|
+
PipelineExecutionResponse,
|
|
6
|
+
PipelineExecutionAsyncStreamedResponse,
|
|
7
|
+
PipelineExecutionStreamedResponse,
|
|
8
|
+
)
|
|
9
|
+
from .conversations import CreateConversationResponse
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"PipelineExecutionDebugResponse",
|
|
13
|
+
"PipelineExecutionResponse",
|
|
14
|
+
"PipelineExecutionStreamedResponse",
|
|
15
|
+
"PipelineExecutionAsyncStreamedResponse",
|
|
16
|
+
"GetPipelineConfigResponse",
|
|
17
|
+
"ProjectItem",
|
|
18
|
+
"CreateConversationResponse",
|
|
19
|
+
]
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CreateConversationResponse(BaseModel):
|
|
7
|
+
user_id: str = Field(alias="userId")
|
|
8
|
+
conversation_id: str = Field(alias="conversationId")
|
|
9
|
+
websocket_url: str = Field(alias="websocketUrl")
|
|
10
|
+
deployment_id: str = Field(alias="deploymentId")
|
|
11
|
+
icon_id: Optional[str] = Field(None, alias="iconId")
|
|
12
|
+
icon_url: Optional[str] = Field(None, alias="iconUrl")
|
|
13
|
+
description: Optional[str] = None
|
|
14
|
+
space_name: Optional[str] = Field(None, alias="spaceName")
|
|
@@ -1,33 +1,29 @@
|
|
|
1
1
|
from typing import Any, AsyncIterator, Dict, Iterator
|
|
2
2
|
|
|
3
|
-
from pydantic import BaseModel, ConfigDict
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
4
4
|
|
|
5
|
-
from ..
|
|
5
|
+
from ..sse import SSEMessage
|
|
6
6
|
|
|
7
7
|
|
|
8
8
|
class PipelineExecutionResponse(BaseModel):
|
|
9
9
|
result: str
|
|
10
10
|
report: None
|
|
11
|
-
|
|
11
|
+
is_backup_pipeline: bool = Field(alias="isBackupPipeline")
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
class PipelineExecutionDebugResponse(BaseModel):
|
|
15
15
|
result: str
|
|
16
16
|
report: Dict[str, Any]
|
|
17
|
-
|
|
17
|
+
is_backup_pipeline: bool = Field(alias="isBackupPipeline")
|
|
18
18
|
|
|
19
19
|
|
|
20
|
-
class
|
|
21
|
-
webSocketUrl: str
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
class PipelineExecutionV2StreamedResponse(BaseModel):
|
|
20
|
+
class PipelineExecutionStreamedResponse(BaseModel):
|
|
25
21
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
26
22
|
|
|
27
23
|
stream: Iterator[SSEMessage]
|
|
28
24
|
|
|
29
25
|
|
|
30
|
-
class
|
|
26
|
+
class PipelineExecutionAsyncStreamedResponse(BaseModel):
|
|
31
27
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
32
28
|
|
|
33
29
|
stream: AsyncIterator[SSEMessage]
|
|
@@ -1,15 +1,6 @@
|
|
|
1
|
-
from .api.get_projects import ProjectItem
|
|
2
|
-
from .api.get_pipeline_config import GetPipelineConfigResponse
|
|
3
|
-
from .api.pipeline_execution import (
|
|
4
|
-
PipelineExecutionDebugResponse,
|
|
5
|
-
PipelineExecutionResponse,
|
|
6
|
-
PipelineExecutionV1StreamedResponse,
|
|
7
|
-
PipelineExecutionV2AsyncStreamedResponse,
|
|
8
|
-
PipelineExecutionV2StreamedResponse,
|
|
9
|
-
)
|
|
10
|
-
from .api_version import ApiVersion
|
|
11
|
-
from .request_data import RequestData
|
|
12
1
|
from .sse_messages import (
|
|
2
|
+
SSEDict,
|
|
3
|
+
SSEMessage,
|
|
13
4
|
AgentAgentCardMessage,
|
|
14
5
|
AgentAgentCardStreamEndMessage,
|
|
15
6
|
AgentAgentCardStreamErrorMessage,
|
|
@@ -37,15 +28,8 @@ from .sse_messages import (
|
|
|
37
28
|
)
|
|
38
29
|
|
|
39
30
|
__all__ = [
|
|
40
|
-
"
|
|
41
|
-
"
|
|
42
|
-
"PipelineExecutionResponse",
|
|
43
|
-
"PipelineExecutionV1StreamedResponse",
|
|
44
|
-
"PipelineExecutionV2AsyncStreamedResponse",
|
|
45
|
-
"PipelineExecutionV2StreamedResponse",
|
|
46
|
-
"GetPipelineConfigResponse",
|
|
47
|
-
"ProjectItem",
|
|
48
|
-
"RequestData",
|
|
31
|
+
"SSEDict",
|
|
32
|
+
"SSEMessage",
|
|
49
33
|
"AgentPingMessage",
|
|
50
34
|
"AgentStartMessage",
|
|
51
35
|
"AgentEndMessage",
|