airia 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- airia/client/async_client.py +169 -98
- airia/client/base_client.py +77 -16
- airia/client/sync_client.py +165 -97
- airia/constants.py +9 -0
- airia/types/api/__init__.py +19 -0
- airia/types/api/conversations.py +14 -0
- airia/types/api/pipeline_execution.py +6 -10
- airia/types/{__init__.py → sse/__init__.py} +4 -20
- airia/utils/sse_parser.py +1 -1
- {airia-0.1.9.dist-info → airia-0.1.11.dist-info}/METADATA +182 -8
- airia-0.1.11.dist-info/RECORD +23 -0
- airia-0.1.9.dist-info/RECORD +0 -20
- /airia/types/{api_version.py → _api_version.py} +0 -0
- /airia/types/{request_data.py → _request_data.py} +0 -0
- /airia/types/{sse_messages.py → sse/sse_messages.py} +0 -0
- {airia-0.1.9.dist-info → airia-0.1.11.dist-info}/WHEEL +0 -0
- {airia-0.1.9.dist-info → airia-0.1.11.dist-info}/licenses/LICENSE +0 -0
- {airia-0.1.9.dist-info → airia-0.1.11.dist-info}/top_level.txt +0 -0
airia/client/async_client.py
CHANGED
|
@@ -1,21 +1,26 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import weakref
|
|
3
3
|
from typing import Any, AsyncIterator, Dict, List, Literal, Optional, overload
|
|
4
|
-
from urllib.parse import urljoin
|
|
5
4
|
|
|
6
5
|
import aiohttp
|
|
7
6
|
import loguru
|
|
8
7
|
|
|
8
|
+
from ..constants import (
|
|
9
|
+
DEFAULT_ANTHROPIC_GATEWAY_URL,
|
|
10
|
+
DEFAULT_BASE_URL,
|
|
11
|
+
DEFAULT_OPENAI_GATEWAY_URL,
|
|
12
|
+
DEFAULT_TIMEOUT,
|
|
13
|
+
)
|
|
9
14
|
from ..exceptions import AiriaAPIError
|
|
10
|
-
from ..types import
|
|
11
|
-
|
|
15
|
+
from ..types._api_version import ApiVersion
|
|
16
|
+
from ..types._request_data import RequestData
|
|
17
|
+
from ..types.api import (
|
|
18
|
+
CreateConversationResponse,
|
|
12
19
|
GetPipelineConfigResponse,
|
|
20
|
+
PipelineExecutionAsyncStreamedResponse,
|
|
13
21
|
PipelineExecutionDebugResponse,
|
|
14
22
|
PipelineExecutionResponse,
|
|
15
|
-
PipelineExecutionV1StreamedResponse,
|
|
16
|
-
PipelineExecutionV2AsyncStreamedResponse,
|
|
17
23
|
ProjectItem,
|
|
18
|
-
RequestData,
|
|
19
24
|
)
|
|
20
25
|
from ..utils.sse_parser import async_parse_sse_stream_chunked
|
|
21
26
|
from .base_client import AiriaBaseClient
|
|
@@ -26,9 +31,10 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
26
31
|
|
|
27
32
|
def __init__(
|
|
28
33
|
self,
|
|
29
|
-
base_url: str =
|
|
34
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
30
35
|
api_key: Optional[str] = None,
|
|
31
|
-
|
|
36
|
+
bearer_token: Optional[str] = None,
|
|
37
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
32
38
|
log_requests: bool = False,
|
|
33
39
|
custom_logger: Optional["loguru.Logger"] = None,
|
|
34
40
|
):
|
|
@@ -38,6 +44,7 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
38
44
|
Args:
|
|
39
45
|
base_url: Base URL of the Airia API.
|
|
40
46
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
47
|
+
bearer_token: Bearer token for authentication. Must be provided explicitly (no environment variable fallback).
|
|
41
48
|
timeout: Request timeout in seconds.
|
|
42
49
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
43
50
|
custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
|
|
@@ -45,6 +52,7 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
45
52
|
super().__init__(
|
|
46
53
|
base_url=base_url,
|
|
47
54
|
api_key=api_key,
|
|
55
|
+
bearer_token=bearer_token,
|
|
48
56
|
timeout=timeout,
|
|
49
57
|
log_requests=log_requests,
|
|
50
58
|
custom_logger=custom_logger,
|
|
@@ -80,10 +88,10 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
80
88
|
@classmethod
|
|
81
89
|
def with_openai_gateway(
|
|
82
90
|
cls,
|
|
83
|
-
base_url: str =
|
|
84
|
-
gateway_url: str =
|
|
91
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
92
|
+
gateway_url: str = DEFAULT_OPENAI_GATEWAY_URL,
|
|
85
93
|
api_key: Optional[str] = None,
|
|
86
|
-
timeout: float =
|
|
94
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
87
95
|
log_requests: bool = False,
|
|
88
96
|
custom_logger: Optional["loguru.Logger"] = None,
|
|
89
97
|
**kwargs,
|
|
@@ -102,22 +110,28 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
102
110
|
"""
|
|
103
111
|
from openai import AsyncOpenAI
|
|
104
112
|
|
|
105
|
-
|
|
106
|
-
|
|
113
|
+
client = cls(
|
|
114
|
+
base_url=base_url,
|
|
107
115
|
api_key=api_key,
|
|
116
|
+
timeout=timeout,
|
|
117
|
+
log_requests=log_requests,
|
|
118
|
+
custom_logger=custom_logger,
|
|
119
|
+
)
|
|
120
|
+
cls.openai = AsyncOpenAI(
|
|
121
|
+
api_key=client.api_key,
|
|
108
122
|
base_url=gateway_url,
|
|
109
123
|
**kwargs,
|
|
110
124
|
)
|
|
111
125
|
|
|
112
|
-
return
|
|
126
|
+
return client
|
|
113
127
|
|
|
114
128
|
@classmethod
|
|
115
129
|
def with_anthropic_gateway(
|
|
116
130
|
cls,
|
|
117
|
-
base_url: str =
|
|
118
|
-
gateway_url: str =
|
|
131
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
132
|
+
gateway_url: str = DEFAULT_ANTHROPIC_GATEWAY_URL,
|
|
119
133
|
api_key: Optional[str] = None,
|
|
120
|
-
timeout: float =
|
|
134
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
121
135
|
log_requests: bool = False,
|
|
122
136
|
custom_logger: Optional["loguru.Logger"] = None,
|
|
123
137
|
**kwargs,
|
|
@@ -136,14 +150,47 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
136
150
|
"""
|
|
137
151
|
from anthropic import AsyncAnthropic
|
|
138
152
|
|
|
139
|
-
|
|
140
|
-
|
|
153
|
+
client = cls(
|
|
154
|
+
base_url=base_url,
|
|
141
155
|
api_key=api_key,
|
|
156
|
+
timeout=timeout,
|
|
157
|
+
log_requests=log_requests,
|
|
158
|
+
custom_logger=custom_logger,
|
|
159
|
+
)
|
|
160
|
+
cls.anthropic = AsyncAnthropic(
|
|
161
|
+
api_key=client.api_key,
|
|
142
162
|
base_url=gateway_url,
|
|
143
163
|
**kwargs,
|
|
144
164
|
)
|
|
145
165
|
|
|
146
|
-
return
|
|
166
|
+
return client
|
|
167
|
+
|
|
168
|
+
@classmethod
|
|
169
|
+
def with_bearer_token(
|
|
170
|
+
cls,
|
|
171
|
+
bearer_token: str,
|
|
172
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
173
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
174
|
+
log_requests: bool = False,
|
|
175
|
+
custom_logger: Optional["loguru.Logger"] = None,
|
|
176
|
+
):
|
|
177
|
+
"""
|
|
178
|
+
Initialize the asynchronous Airia API client with bearer token authentication.
|
|
179
|
+
|
|
180
|
+
Args:
|
|
181
|
+
bearer_token: Bearer token for authentication.
|
|
182
|
+
base_url: Base URL of the Airia API.
|
|
183
|
+
timeout: Request timeout in seconds.
|
|
184
|
+
log_requests: Whether to log API requests and responses. Default is False.
|
|
185
|
+
custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
|
|
186
|
+
"""
|
|
187
|
+
return cls(
|
|
188
|
+
base_url=base_url,
|
|
189
|
+
bearer_token=bearer_token,
|
|
190
|
+
timeout=timeout,
|
|
191
|
+
log_requests=log_requests,
|
|
192
|
+
custom_logger=custom_logger,
|
|
193
|
+
)
|
|
147
194
|
|
|
148
195
|
def _handle_exception(
|
|
149
196
|
self, e: aiohttp.ClientResponseError, url: str, correlation_id: str
|
|
@@ -159,12 +206,14 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
159
206
|
# Extract error details from response
|
|
160
207
|
error_message = e.message
|
|
161
208
|
|
|
162
|
-
# Make sure
|
|
163
|
-
sanitized_message =
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
209
|
+
# Make sure sensitive auth information is not included in error messages
|
|
210
|
+
sanitized_message = error_message
|
|
211
|
+
if self.api_key and self.api_key in sanitized_message:
|
|
212
|
+
sanitized_message = sanitized_message.replace(self.api_key, "[REDACTED]")
|
|
213
|
+
if self.bearer_token and self.bearer_token in sanitized_message:
|
|
214
|
+
sanitized_message = sanitized_message.replace(
|
|
215
|
+
self.bearer_token, "[REDACTED]"
|
|
216
|
+
)
|
|
168
217
|
|
|
169
218
|
# Raise custom exception with status code and sanitized message
|
|
170
219
|
raise AiriaAPIError(status_code=e.status, message=sanitized_message) from e
|
|
@@ -297,7 +346,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
297
346
|
additional_info: Optional[List[Any]] = None,
|
|
298
347
|
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
299
348
|
correlation_id: Optional[str] = None,
|
|
300
|
-
api_version: str = ApiVersion.V2.value,
|
|
301
349
|
) -> PipelineExecutionResponse: ...
|
|
302
350
|
|
|
303
351
|
@overload
|
|
@@ -320,7 +368,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
320
368
|
additional_info: Optional[List[Any]] = None,
|
|
321
369
|
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
322
370
|
correlation_id: Optional[str] = None,
|
|
323
|
-
api_version: str = ApiVersion.V2.value,
|
|
324
371
|
) -> PipelineExecutionDebugResponse: ...
|
|
325
372
|
|
|
326
373
|
@overload
|
|
@@ -343,31 +390,7 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
343
390
|
additional_info: Optional[List[Any]] = None,
|
|
344
391
|
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
345
392
|
correlation_id: Optional[str] = None,
|
|
346
|
-
|
|
347
|
-
) -> PipelineExecutionV2AsyncStreamedResponse: ...
|
|
348
|
-
|
|
349
|
-
@overload
|
|
350
|
-
async def execute_pipeline(
|
|
351
|
-
self,
|
|
352
|
-
pipeline_id: str,
|
|
353
|
-
user_input: str,
|
|
354
|
-
debug: bool = False,
|
|
355
|
-
user_id: Optional[str] = None,
|
|
356
|
-
conversation_id: Optional[str] = None,
|
|
357
|
-
async_output: Literal[True] = True,
|
|
358
|
-
include_tools_response: bool = False,
|
|
359
|
-
images: Optional[List[str]] = None,
|
|
360
|
-
files: Optional[List[str]] = None,
|
|
361
|
-
data_source_folders: Optional[Dict[str, Any]] = None,
|
|
362
|
-
data_source_files: Optional[Dict[str, Any]] = None,
|
|
363
|
-
in_memory_messages: Optional[List[Dict[str, str]]] = None,
|
|
364
|
-
current_date_time: Optional[str] = None,
|
|
365
|
-
save_history: bool = True,
|
|
366
|
-
additional_info: Optional[List[Any]] = None,
|
|
367
|
-
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
368
|
-
correlation_id: Optional[str] = None,
|
|
369
|
-
api_version: Literal["v1"] = ApiVersion.V1.value,
|
|
370
|
-
) -> PipelineExecutionV1StreamedResponse: ...
|
|
393
|
+
) -> PipelineExecutionAsyncStreamedResponse: ...
|
|
371
394
|
|
|
372
395
|
async def execute_pipeline(
|
|
373
396
|
self,
|
|
@@ -388,7 +411,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
388
411
|
additional_info: Optional[List[Any]] = None,
|
|
389
412
|
prompt_variables: Optional[Dict[str, Any]] = None,
|
|
390
413
|
correlation_id: Optional[str] = None,
|
|
391
|
-
api_version: str = ApiVersion.V2.value,
|
|
392
414
|
) -> Dict[str, Any]:
|
|
393
415
|
"""
|
|
394
416
|
Execute a pipeline with the provided input asynchronously.
|
|
@@ -412,7 +434,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
412
434
|
prompt_variables: Optional variables to be used in the prompt.
|
|
413
435
|
correlation_id: Optional correlation ID for request tracing. If not provided,
|
|
414
436
|
one will be generated automatically.
|
|
415
|
-
api_version: API version to use. Default is `v2`
|
|
416
437
|
|
|
417
438
|
Returns:
|
|
418
439
|
The API response as a dictionary.
|
|
@@ -447,38 +468,23 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
447
468
|
additional_info=additional_info,
|
|
448
469
|
prompt_variables=prompt_variables,
|
|
449
470
|
correlation_id=correlation_id,
|
|
450
|
-
api_version=
|
|
471
|
+
api_version=ApiVersion.V2.value,
|
|
472
|
+
)
|
|
473
|
+
resp = (
|
|
474
|
+
self._make_request_stream(method="POST", request_data=request_data)
|
|
475
|
+
if async_output
|
|
476
|
+
else await self._make_request("POST", request_data=request_data)
|
|
451
477
|
)
|
|
452
|
-
stream = async_output and api_version == ApiVersion.V2.value
|
|
453
|
-
if stream:
|
|
454
|
-
resp = self._make_request_stream(method="POST", request_data=request_data)
|
|
455
|
-
else:
|
|
456
|
-
resp = await self._make_request("POST", request_data)
|
|
457
478
|
|
|
458
479
|
if not async_output:
|
|
459
480
|
if not debug:
|
|
460
481
|
return PipelineExecutionResponse(**resp)
|
|
461
482
|
return PipelineExecutionDebugResponse(**resp)
|
|
462
483
|
|
|
463
|
-
|
|
464
|
-
url = urljoin(
|
|
465
|
-
self.base_url, f"{api_version}/StreamSocketConfig/GenerateUrl"
|
|
466
|
-
)
|
|
467
|
-
request_data = self._prepare_request(
|
|
468
|
-
url=url,
|
|
469
|
-
payload={"socketIdentifier": resp},
|
|
470
|
-
correlation_id=request_data.headers["X-Correlation-ID"],
|
|
471
|
-
)
|
|
472
|
-
resp = await self._make_request("POST", request_data)
|
|
473
|
-
|
|
474
|
-
return PipelineExecutionV1StreamedResponse(**resp)
|
|
475
|
-
|
|
476
|
-
return PipelineExecutionV2AsyncStreamedResponse(stream=resp)
|
|
484
|
+
return PipelineExecutionAsyncStreamedResponse(stream=resp)
|
|
477
485
|
|
|
478
486
|
async def get_projects(
|
|
479
|
-
self,
|
|
480
|
-
correlation_id: Optional[str] = None,
|
|
481
|
-
api_version: str = ApiVersion.V1.value,
|
|
487
|
+
self, correlation_id: Optional[str] = None
|
|
482
488
|
) -> List[ProjectItem]:
|
|
483
489
|
"""
|
|
484
490
|
Retrieve a list of all projects accessible to the authenticated user.
|
|
@@ -490,8 +496,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
490
496
|
Args:
|
|
491
497
|
correlation_id (str, optional): A unique identifier for request tracing
|
|
492
498
|
and logging. If not provided, one will be automatically generated.
|
|
493
|
-
api_version (str, optional): The API version to use for the request.
|
|
494
|
-
Defaults to "v1". Valid versions are defined in ApiVersion enum.
|
|
495
499
|
|
|
496
500
|
Returns:
|
|
497
501
|
List[ProjectItem]: A list of ProjectItem objects containing project
|
|
@@ -499,7 +503,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
499
503
|
or found.
|
|
500
504
|
|
|
501
505
|
Raises:
|
|
502
|
-
ValueError: If the provided api_version is not valid.
|
|
503
506
|
AiriaAPIError: If the API request fails, including cases where:
|
|
504
507
|
- Authentication fails (401)
|
|
505
508
|
- Access is forbidden (403)
|
|
@@ -528,7 +531,7 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
528
531
|
access to.
|
|
529
532
|
"""
|
|
530
533
|
request_data = self._pre_get_projects(
|
|
531
|
-
correlation_id=correlation_id, api_version=
|
|
534
|
+
correlation_id=correlation_id, api_version=ApiVersion.V1.value
|
|
532
535
|
)
|
|
533
536
|
resp = await self._make_request("GET", request_data)
|
|
534
537
|
|
|
@@ -538,10 +541,7 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
538
541
|
return [ProjectItem(**item) for item in resp["items"]]
|
|
539
542
|
|
|
540
543
|
async def get_active_pipelines_ids(
|
|
541
|
-
self,
|
|
542
|
-
project_id: Optional[str] = None,
|
|
543
|
-
correlation_id: Optional[str] = None,
|
|
544
|
-
api_version: str = ApiVersion.V1.value,
|
|
544
|
+
self, project_id: Optional[str] = None, correlation_id: Optional[str] = None
|
|
545
545
|
) -> List[str]:
|
|
546
546
|
"""
|
|
547
547
|
Retrieve a list of active pipeline IDs.
|
|
@@ -555,15 +555,12 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
555
555
|
accessible to the authenticated user.
|
|
556
556
|
correlation_id (str, optional): A unique identifier for request tracing
|
|
557
557
|
and logging. If not provided, one will be automatically generated.
|
|
558
|
-
api_version (str, optional): The API version to use for the request.
|
|
559
|
-
Defaults to "v1". Valid versions are defined in ApiVersion enum.
|
|
560
558
|
|
|
561
559
|
Returns:
|
|
562
560
|
List[str]: A list of pipeline IDs that are currently active. Returns an
|
|
563
561
|
empty list if no active pipelines are found.
|
|
564
562
|
|
|
565
563
|
Raises:
|
|
566
|
-
ValueError: If the provided api_version is not valid.
|
|
567
564
|
AiriaAPIError: If the API request fails, including cases where:
|
|
568
565
|
- The project_id doesn't exist (404)
|
|
569
566
|
- Authentication fails (401)
|
|
@@ -594,7 +591,7 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
594
591
|
request_data = self._pre_get_active_pipelines_ids(
|
|
595
592
|
project_id=project_id,
|
|
596
593
|
correlation_id=correlation_id,
|
|
597
|
-
api_version=
|
|
594
|
+
api_version=ApiVersion.V1.value,
|
|
598
595
|
)
|
|
599
596
|
resp = await self._make_request("GET", request_data)
|
|
600
597
|
|
|
@@ -606,10 +603,7 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
606
603
|
return pipeline_ids
|
|
607
604
|
|
|
608
605
|
async def get_pipeline_config(
|
|
609
|
-
self,
|
|
610
|
-
pipeline_id: str,
|
|
611
|
-
correlation_id: Optional[str] = None,
|
|
612
|
-
api_version: str = ApiVersion.V1.value,
|
|
606
|
+
self, pipeline_id: str, correlation_id: Optional[str] = None
|
|
613
607
|
) -> GetPipelineConfigResponse:
|
|
614
608
|
"""
|
|
615
609
|
Retrieve configuration details for a specific pipeline.
|
|
@@ -620,8 +614,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
620
614
|
Args:
|
|
621
615
|
pipeline_id (str): The unique identifier of the pipeline to retrieve
|
|
622
616
|
configuration for.
|
|
623
|
-
api_version (str, optional): The API version to use for the request.
|
|
624
|
-
Defaults to "v1". Valid versions are defined in ApiVersion enum.
|
|
625
617
|
correlation_id (str, optional): A unique identifier for request tracing
|
|
626
618
|
and logging. If not provided, one will be automatically generated.
|
|
627
619
|
|
|
@@ -630,7 +622,6 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
630
622
|
configuration.
|
|
631
623
|
|
|
632
624
|
Raises:
|
|
633
|
-
ValueError: If the provided api_version is not valid.
|
|
634
625
|
AiriaAPIError: If the API request fails, including cases where:
|
|
635
626
|
- The pipeline_id doesn't exist (404)
|
|
636
627
|
- Authentication fails (401)
|
|
@@ -661,8 +652,88 @@ class AiriaAsyncClient(AiriaBaseClient):
|
|
|
661
652
|
request_data = self._pre_get_pipeline_config(
|
|
662
653
|
pipeline_id=pipeline_id,
|
|
663
654
|
correlation_id=correlation_id,
|
|
664
|
-
api_version=
|
|
655
|
+
api_version=ApiVersion.V1.value,
|
|
665
656
|
)
|
|
666
657
|
resp = await self._make_request("GET", request_data)
|
|
667
658
|
|
|
668
659
|
return GetPipelineConfigResponse(**resp)
|
|
660
|
+
|
|
661
|
+
async def create_conversation(
|
|
662
|
+
self,
|
|
663
|
+
user_id: str,
|
|
664
|
+
title: Optional[str] = None,
|
|
665
|
+
deployment_id: Optional[str] = None,
|
|
666
|
+
data_source_files: Dict[str, Any] = {},
|
|
667
|
+
is_bookmarked: bool = False,
|
|
668
|
+
correlation_id: Optional[str] = None,
|
|
669
|
+
) -> CreateConversationResponse:
|
|
670
|
+
"""
|
|
671
|
+
Create a new conversation.
|
|
672
|
+
|
|
673
|
+
Args:
|
|
674
|
+
user_id (str): The unique identifier of the user creating the conversation.
|
|
675
|
+
title (str, optional): The title for the conversation. If not provided,
|
|
676
|
+
the conversation will be created without a title.
|
|
677
|
+
deployment_id (str, optional): The unique identifier of the deployment
|
|
678
|
+
to associate with the conversation. If not provided, the conversation
|
|
679
|
+
will not be associated with any specific deployment.
|
|
680
|
+
data_source_files (dict): Configuration for data source files
|
|
681
|
+
to be associated with the conversation. If not provided, no data
|
|
682
|
+
source files will be associated.
|
|
683
|
+
is_bookmarked (bool): Whether the conversation should be bookmarked.
|
|
684
|
+
Defaults to False.
|
|
685
|
+
correlation_id (str, optional): A unique identifier for request tracing
|
|
686
|
+
and logging. If not provided, one will be automatically generated.
|
|
687
|
+
|
|
688
|
+
Returns:
|
|
689
|
+
CreateConversationResponse: A response object containing the created
|
|
690
|
+
conversation details including its ID, creation timestamp, and
|
|
691
|
+
all provided parameters.
|
|
692
|
+
|
|
693
|
+
Raises:
|
|
694
|
+
AiriaAPIError: If the API request fails, including cases where:
|
|
695
|
+
- The user_id doesn't exist (404)
|
|
696
|
+
- The deployment_id is invalid (404)
|
|
697
|
+
- Authentication fails (401)
|
|
698
|
+
- Access is forbidden (403)
|
|
699
|
+
- Server errors (5xx)
|
|
700
|
+
|
|
701
|
+
Example:
|
|
702
|
+
```python
|
|
703
|
+
from airia import AiriaAsyncClient
|
|
704
|
+
|
|
705
|
+
client = AiriaAsyncClient(api_key="your_api_key")
|
|
706
|
+
|
|
707
|
+
# Create a basic conversation
|
|
708
|
+
conversation = await client.create_conversation(
|
|
709
|
+
user_id="user_123"
|
|
710
|
+
)
|
|
711
|
+
print(f"Created conversation: {conversation.conversation_id}")
|
|
712
|
+
|
|
713
|
+
# Create a conversation with all options
|
|
714
|
+
conversation = await client.create_conversation(
|
|
715
|
+
user_id="user_123",
|
|
716
|
+
title="My Research Session",
|
|
717
|
+
deployment_id="deployment_456",
|
|
718
|
+
data_source_files={"documents": ["doc1.pdf", "doc2.txt"]},
|
|
719
|
+
is_bookmarked=True
|
|
720
|
+
)
|
|
721
|
+
print(f"Created bookmarked conversation: {conversation.conversation_id}")
|
|
722
|
+
```
|
|
723
|
+
|
|
724
|
+
Note:
|
|
725
|
+
The user_id is required and must correspond to a valid user in the system.
|
|
726
|
+
All other parameters are optional and can be set to None or their default values.
|
|
727
|
+
"""
|
|
728
|
+
request_data = self._pre_create_conversation(
|
|
729
|
+
user_id=user_id,
|
|
730
|
+
title=title,
|
|
731
|
+
deployment_id=deployment_id,
|
|
732
|
+
data_source_files=data_source_files,
|
|
733
|
+
is_bookmarked=is_bookmarked,
|
|
734
|
+
correlation_id=correlation_id,
|
|
735
|
+
api_version=ApiVersion.V1.value,
|
|
736
|
+
)
|
|
737
|
+
resp = await self._make_request("POST", request_data)
|
|
738
|
+
|
|
739
|
+
return CreateConversationResponse(**resp)
|
airia/client/base_client.py
CHANGED
|
@@ -5,8 +5,10 @@ from urllib.parse import urljoin
|
|
|
5
5
|
|
|
6
6
|
import loguru
|
|
7
7
|
|
|
8
|
+
from ..constants import DEFAULT_BASE_URL, DEFAULT_TIMEOUT
|
|
8
9
|
from ..logs import configure_logging, set_correlation_id
|
|
9
|
-
from ..types import ApiVersion
|
|
10
|
+
from ..types._api_version import ApiVersion
|
|
11
|
+
from ..types._request_data import RequestData
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
class AiriaBaseClient:
|
|
@@ -17,9 +19,10 @@ class AiriaBaseClient:
|
|
|
17
19
|
|
|
18
20
|
def __init__(
|
|
19
21
|
self,
|
|
20
|
-
base_url: str =
|
|
22
|
+
base_url: str = DEFAULT_BASE_URL,
|
|
21
23
|
api_key: Optional[str] = None,
|
|
22
|
-
|
|
24
|
+
bearer_token: Optional[str] = None,
|
|
25
|
+
timeout: float = DEFAULT_TIMEOUT,
|
|
23
26
|
log_requests: bool = False,
|
|
24
27
|
custom_logger: Optional["loguru.Logger"] = None,
|
|
25
28
|
):
|
|
@@ -28,12 +31,15 @@ class AiriaBaseClient:
|
|
|
28
31
|
|
|
29
32
|
Args:
|
|
30
33
|
api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
|
|
34
|
+
bearer_token: Bearer token for authentication. Must be provided explicitly (no environment variable fallback).
|
|
31
35
|
timeout: Request timeout in seconds.
|
|
32
36
|
log_requests: Whether to log API requests and responses. Default is False.
|
|
33
37
|
custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
|
|
34
38
|
"""
|
|
35
|
-
# Resolve
|
|
36
|
-
self.api_key = self.__class__.
|
|
39
|
+
# Resolve authentication credentials
|
|
40
|
+
self.api_key, self.bearer_token = self.__class__._resolve_auth_credentials(
|
|
41
|
+
api_key, bearer_token
|
|
42
|
+
)
|
|
37
43
|
|
|
38
44
|
# Store configuration
|
|
39
45
|
self.base_url = base_url
|
|
@@ -44,27 +50,45 @@ class AiriaBaseClient:
|
|
|
44
50
|
self.logger = configure_logging() if custom_logger is None else custom_logger
|
|
45
51
|
|
|
46
52
|
@staticmethod
|
|
47
|
-
def
|
|
53
|
+
def _resolve_auth_credentials(
|
|
54
|
+
api_key: Optional[str] = None, bearer_token: Optional[str] = None
|
|
55
|
+
):
|
|
48
56
|
"""
|
|
49
|
-
|
|
57
|
+
Resolve authentication credentials from parameters and environment variables.
|
|
50
58
|
|
|
51
59
|
Args:
|
|
52
60
|
api_key (Optional[str]): The API key provided as a parameter. Defaults to None.
|
|
61
|
+
bearer_token (Optional[str]): The bearer token provided as a parameter. Defaults to None.
|
|
53
62
|
|
|
54
63
|
Returns:
|
|
55
|
-
|
|
64
|
+
tuple: (api_key, bearer_token) - exactly one will be non-None
|
|
56
65
|
|
|
57
66
|
Raises:
|
|
58
|
-
ValueError: If no
|
|
67
|
+
ValueError: If no authentication method is provided or if both are provided.
|
|
59
68
|
"""
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
if not api_key:
|
|
69
|
+
# Check for explicit conflict first
|
|
70
|
+
if api_key and bearer_token:
|
|
63
71
|
raise ValueError(
|
|
64
|
-
"
|
|
72
|
+
"Cannot provide both api_key and bearer_token. Please use only one authentication method."
|
|
65
73
|
)
|
|
66
74
|
|
|
67
|
-
|
|
75
|
+
# If bearer token is explicitly provided, use it exclusively
|
|
76
|
+
if bearer_token:
|
|
77
|
+
return None, bearer_token
|
|
78
|
+
|
|
79
|
+
# If API key is explicitly provided, use it exclusively
|
|
80
|
+
if api_key:
|
|
81
|
+
return api_key, None
|
|
82
|
+
|
|
83
|
+
# If neither is provided explicitly, fall back to environment variable
|
|
84
|
+
resolved_api_key = os.environ.get("AIRIA_API_KEY")
|
|
85
|
+
if resolved_api_key:
|
|
86
|
+
return resolved_api_key, None
|
|
87
|
+
|
|
88
|
+
# No authentication method found
|
|
89
|
+
raise ValueError(
|
|
90
|
+
"Authentication required. Provide either api_key (or set AIRIA_API_KEY environment variable) or bearer_token."
|
|
91
|
+
)
|
|
68
92
|
|
|
69
93
|
def _prepare_request(
|
|
70
94
|
self,
|
|
@@ -76,13 +100,18 @@ class AiriaBaseClient:
|
|
|
76
100
|
# Set correlation ID if provided or generate a new one
|
|
77
101
|
correlation_id = set_correlation_id(correlation_id)
|
|
78
102
|
|
|
79
|
-
#
|
|
103
|
+
# Set up base headers
|
|
80
104
|
headers = {
|
|
81
|
-
"X-API-KEY": self.api_key,
|
|
82
105
|
"X-Correlation-ID": correlation_id,
|
|
83
106
|
"Content-Type": "application/json",
|
|
84
107
|
}
|
|
85
108
|
|
|
109
|
+
# Add authentication header based on the method used
|
|
110
|
+
if self.api_key:
|
|
111
|
+
headers["X-API-KEY"] = self.api_key
|
|
112
|
+
elif self.bearer_token:
|
|
113
|
+
headers["Authorization"] = f"Bearer {self.bearer_token}"
|
|
114
|
+
|
|
86
115
|
# Log the request if enabled
|
|
87
116
|
if self.log_requests:
|
|
88
117
|
# Create a sanitized copy of headers and params for logging
|
|
@@ -92,6 +121,8 @@ class AiriaBaseClient:
|
|
|
92
121
|
# Filter out sensitive headers
|
|
93
122
|
if "X-API-KEY" in log_headers:
|
|
94
123
|
log_headers["X-API-KEY"] = "[REDACTED]"
|
|
124
|
+
if "Authorization" in log_headers:
|
|
125
|
+
log_headers["Authorization"] = "[REDACTED]"
|
|
95
126
|
|
|
96
127
|
# Process payload for logging
|
|
97
128
|
log_payload = payload.copy() if payload is not None else {}
|
|
@@ -218,6 +249,36 @@ class AiriaBaseClient:
|
|
|
218
249
|
|
|
219
250
|
return request_data
|
|
220
251
|
|
|
252
|
+
def _pre_create_conversation(
|
|
253
|
+
self,
|
|
254
|
+
user_id: str,
|
|
255
|
+
title: Optional[str] = None,
|
|
256
|
+
deployment_id: Optional[str] = None,
|
|
257
|
+
data_source_files: Dict[str, Any] = {},
|
|
258
|
+
is_bookmarked: bool = False,
|
|
259
|
+
correlation_id: Optional[str] = None,
|
|
260
|
+
api_version: str = ApiVersion.V1.value,
|
|
261
|
+
):
|
|
262
|
+
if api_version not in ApiVersion.as_list():
|
|
263
|
+
raise ValueError(
|
|
264
|
+
f"Invalid API version: {api_version}. Valid versions are: {', '.join(ApiVersion.as_list())}"
|
|
265
|
+
)
|
|
266
|
+
url = urljoin(self.base_url, f"{api_version}/Conversations")
|
|
267
|
+
|
|
268
|
+
payload = {
|
|
269
|
+
"userId": user_id,
|
|
270
|
+
"title": title,
|
|
271
|
+
"deploymentId": deployment_id,
|
|
272
|
+
"dataSourceFiles": data_source_files,
|
|
273
|
+
"isBookmarked": is_bookmarked,
|
|
274
|
+
}
|
|
275
|
+
|
|
276
|
+
request_data = self._prepare_request(
|
|
277
|
+
url=url, payload=payload, correlation_id=correlation_id
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
return request_data
|
|
281
|
+
|
|
221
282
|
def _pre_get_projects(
|
|
222
283
|
self,
|
|
223
284
|
correlation_id: Optional[str] = None,
|