airia 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,21 +1,26 @@
1
1
  import asyncio
2
2
  import weakref
3
3
  from typing import Any, AsyncIterator, Dict, List, Literal, Optional, overload
4
- from urllib.parse import urljoin
5
4
 
6
5
  import aiohttp
7
6
  import loguru
8
7
 
8
+ from ..constants import (
9
+ DEFAULT_ANTHROPIC_GATEWAY_URL,
10
+ DEFAULT_BASE_URL,
11
+ DEFAULT_OPENAI_GATEWAY_URL,
12
+ DEFAULT_TIMEOUT,
13
+ )
9
14
  from ..exceptions import AiriaAPIError
10
- from ..types import (
11
- ApiVersion,
15
+ from ..types._api_version import ApiVersion
16
+ from ..types._request_data import RequestData
17
+ from ..types.api import (
18
+ CreateConversationResponse,
12
19
  GetPipelineConfigResponse,
20
+ PipelineExecutionAsyncStreamedResponse,
13
21
  PipelineExecutionDebugResponse,
14
22
  PipelineExecutionResponse,
15
- PipelineExecutionV1StreamedResponse,
16
- PipelineExecutionV2AsyncStreamedResponse,
17
23
  ProjectItem,
18
- RequestData,
19
24
  )
20
25
  from ..utils.sse_parser import async_parse_sse_stream_chunked
21
26
  from .base_client import AiriaBaseClient
@@ -26,9 +31,10 @@ class AiriaAsyncClient(AiriaBaseClient):
26
31
 
27
32
  def __init__(
28
33
  self,
29
- base_url: str = "https://api.airia.ai/",
34
+ base_url: str = DEFAULT_BASE_URL,
30
35
  api_key: Optional[str] = None,
31
- timeout: float = 30.0,
36
+ bearer_token: Optional[str] = None,
37
+ timeout: float = DEFAULT_TIMEOUT,
32
38
  log_requests: bool = False,
33
39
  custom_logger: Optional["loguru.Logger"] = None,
34
40
  ):
@@ -38,6 +44,7 @@ class AiriaAsyncClient(AiriaBaseClient):
38
44
  Args:
39
45
  base_url: Base URL of the Airia API.
40
46
  api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
47
+ bearer_token: Bearer token for authentication. Must be provided explicitly (no environment variable fallback).
41
48
  timeout: Request timeout in seconds.
42
49
  log_requests: Whether to log API requests and responses. Default is False.
43
50
  custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
@@ -45,6 +52,7 @@ class AiriaAsyncClient(AiriaBaseClient):
45
52
  super().__init__(
46
53
  base_url=base_url,
47
54
  api_key=api_key,
55
+ bearer_token=bearer_token,
48
56
  timeout=timeout,
49
57
  log_requests=log_requests,
50
58
  custom_logger=custom_logger,
@@ -80,10 +88,10 @@ class AiriaAsyncClient(AiriaBaseClient):
80
88
  @classmethod
81
89
  def with_openai_gateway(
82
90
  cls,
83
- base_url: str = "https://api.airia.ai/",
84
- gateway_url: str = "https://gateway.airia.ai/openai/v1",
91
+ base_url: str = DEFAULT_BASE_URL,
92
+ gateway_url: str = DEFAULT_OPENAI_GATEWAY_URL,
85
93
  api_key: Optional[str] = None,
86
- timeout: float = 30.0,
94
+ timeout: float = DEFAULT_TIMEOUT,
87
95
  log_requests: bool = False,
88
96
  custom_logger: Optional["loguru.Logger"] = None,
89
97
  **kwargs,
@@ -102,22 +110,28 @@ class AiriaAsyncClient(AiriaBaseClient):
102
110
  """
103
111
  from openai import AsyncOpenAI
104
112
 
105
- api_key = cls._get_api_key(api_key)
106
- cls.openai = AsyncOpenAI(
113
+ client = cls(
114
+ base_url=base_url,
107
115
  api_key=api_key,
116
+ timeout=timeout,
117
+ log_requests=log_requests,
118
+ custom_logger=custom_logger,
119
+ )
120
+ cls.openai = AsyncOpenAI(
121
+ api_key=client.api_key,
108
122
  base_url=gateway_url,
109
123
  **kwargs,
110
124
  )
111
125
 
112
- return cls(base_url, api_key, timeout, log_requests, custom_logger)
126
+ return client
113
127
 
114
128
  @classmethod
115
129
  def with_anthropic_gateway(
116
130
  cls,
117
- base_url: str = "https://api.airia.ai/",
118
- gateway_url: str = "https://gateway.airia.ai/anthropic",
131
+ base_url: str = DEFAULT_BASE_URL,
132
+ gateway_url: str = DEFAULT_ANTHROPIC_GATEWAY_URL,
119
133
  api_key: Optional[str] = None,
120
- timeout: float = 30.0,
134
+ timeout: float = DEFAULT_TIMEOUT,
121
135
  log_requests: bool = False,
122
136
  custom_logger: Optional["loguru.Logger"] = None,
123
137
  **kwargs,
@@ -136,14 +150,47 @@ class AiriaAsyncClient(AiriaBaseClient):
136
150
  """
137
151
  from anthropic import AsyncAnthropic
138
152
 
139
- api_key = cls._get_api_key(api_key)
140
- cls.anthropic = AsyncAnthropic(
153
+ client = cls(
154
+ base_url=base_url,
141
155
  api_key=api_key,
156
+ timeout=timeout,
157
+ log_requests=log_requests,
158
+ custom_logger=custom_logger,
159
+ )
160
+ cls.anthropic = AsyncAnthropic(
161
+ api_key=client.api_key,
142
162
  base_url=gateway_url,
143
163
  **kwargs,
144
164
  )
145
165
 
146
- return cls(base_url, api_key, timeout, log_requests, custom_logger)
166
+ return client
167
+
168
+ @classmethod
169
+ def with_bearer_token(
170
+ cls,
171
+ bearer_token: str,
172
+ base_url: str = DEFAULT_BASE_URL,
173
+ timeout: float = DEFAULT_TIMEOUT,
174
+ log_requests: bool = False,
175
+ custom_logger: Optional["loguru.Logger"] = None,
176
+ ):
177
+ """
178
+ Initialize the asynchronous Airia API client with bearer token authentication.
179
+
180
+ Args:
181
+ bearer_token: Bearer token for authentication.
182
+ base_url: Base URL of the Airia API.
183
+ timeout: Request timeout in seconds.
184
+ log_requests: Whether to log API requests and responses. Default is False.
185
+ custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
186
+ """
187
+ return cls(
188
+ base_url=base_url,
189
+ bearer_token=bearer_token,
190
+ timeout=timeout,
191
+ log_requests=log_requests,
192
+ custom_logger=custom_logger,
193
+ )
147
194
 
148
195
  def _handle_exception(
149
196
  self, e: aiohttp.ClientResponseError, url: str, correlation_id: str
@@ -159,12 +206,14 @@ class AiriaAsyncClient(AiriaBaseClient):
159
206
  # Extract error details from response
160
207
  error_message = e.message
161
208
 
162
- # Make sure API key is not included in error messages
163
- sanitized_message = (
164
- error_message.replace(self.api_key, "[REDACTED]")
165
- if self.api_key in error_message
166
- else error_message
167
- )
209
+ # Make sure sensitive auth information is not included in error messages
210
+ sanitized_message = error_message
211
+ if self.api_key and self.api_key in sanitized_message:
212
+ sanitized_message = sanitized_message.replace(self.api_key, "[REDACTED]")
213
+ if self.bearer_token and self.bearer_token in sanitized_message:
214
+ sanitized_message = sanitized_message.replace(
215
+ self.bearer_token, "[REDACTED]"
216
+ )
168
217
 
169
218
  # Raise custom exception with status code and sanitized message
170
219
  raise AiriaAPIError(status_code=e.status, message=sanitized_message) from e
@@ -297,7 +346,6 @@ class AiriaAsyncClient(AiriaBaseClient):
297
346
  additional_info: Optional[List[Any]] = None,
298
347
  prompt_variables: Optional[Dict[str, Any]] = None,
299
348
  correlation_id: Optional[str] = None,
300
- api_version: str = ApiVersion.V2.value,
301
349
  ) -> PipelineExecutionResponse: ...
302
350
 
303
351
  @overload
@@ -320,7 +368,6 @@ class AiriaAsyncClient(AiriaBaseClient):
320
368
  additional_info: Optional[List[Any]] = None,
321
369
  prompt_variables: Optional[Dict[str, Any]] = None,
322
370
  correlation_id: Optional[str] = None,
323
- api_version: str = ApiVersion.V2.value,
324
371
  ) -> PipelineExecutionDebugResponse: ...
325
372
 
326
373
  @overload
@@ -343,31 +390,7 @@ class AiriaAsyncClient(AiriaBaseClient):
343
390
  additional_info: Optional[List[Any]] = None,
344
391
  prompt_variables: Optional[Dict[str, Any]] = None,
345
392
  correlation_id: Optional[str] = None,
346
- api_version: Literal["v2"] = ApiVersion.V2.value,
347
- ) -> PipelineExecutionV2AsyncStreamedResponse: ...
348
-
349
- @overload
350
- async def execute_pipeline(
351
- self,
352
- pipeline_id: str,
353
- user_input: str,
354
- debug: bool = False,
355
- user_id: Optional[str] = None,
356
- conversation_id: Optional[str] = None,
357
- async_output: Literal[True] = True,
358
- include_tools_response: bool = False,
359
- images: Optional[List[str]] = None,
360
- files: Optional[List[str]] = None,
361
- data_source_folders: Optional[Dict[str, Any]] = None,
362
- data_source_files: Optional[Dict[str, Any]] = None,
363
- in_memory_messages: Optional[List[Dict[str, str]]] = None,
364
- current_date_time: Optional[str] = None,
365
- save_history: bool = True,
366
- additional_info: Optional[List[Any]] = None,
367
- prompt_variables: Optional[Dict[str, Any]] = None,
368
- correlation_id: Optional[str] = None,
369
- api_version: Literal["v1"] = ApiVersion.V1.value,
370
- ) -> PipelineExecutionV1StreamedResponse: ...
393
+ ) -> PipelineExecutionAsyncStreamedResponse: ...
371
394
 
372
395
  async def execute_pipeline(
373
396
  self,
@@ -388,7 +411,6 @@ class AiriaAsyncClient(AiriaBaseClient):
388
411
  additional_info: Optional[List[Any]] = None,
389
412
  prompt_variables: Optional[Dict[str, Any]] = None,
390
413
  correlation_id: Optional[str] = None,
391
- api_version: str = ApiVersion.V2.value,
392
414
  ) -> Dict[str, Any]:
393
415
  """
394
416
  Execute a pipeline with the provided input asynchronously.
@@ -412,7 +434,6 @@ class AiriaAsyncClient(AiriaBaseClient):
412
434
  prompt_variables: Optional variables to be used in the prompt.
413
435
  correlation_id: Optional correlation ID for request tracing. If not provided,
414
436
  one will be generated automatically.
415
- api_version: API version to use. Default is `v2`
416
437
 
417
438
  Returns:
418
439
  The API response as a dictionary.
@@ -447,38 +468,23 @@ class AiriaAsyncClient(AiriaBaseClient):
447
468
  additional_info=additional_info,
448
469
  prompt_variables=prompt_variables,
449
470
  correlation_id=correlation_id,
450
- api_version=api_version,
471
+ api_version=ApiVersion.V2.value,
472
+ )
473
+ resp = (
474
+ self._make_request_stream(method="POST", request_data=request_data)
475
+ if async_output
476
+ else await self._make_request("POST", request_data=request_data)
451
477
  )
452
- stream = async_output and api_version == ApiVersion.V2.value
453
- if stream:
454
- resp = self._make_request_stream(method="POST", request_data=request_data)
455
- else:
456
- resp = await self._make_request("POST", request_data)
457
478
 
458
479
  if not async_output:
459
480
  if not debug:
460
481
  return PipelineExecutionResponse(**resp)
461
482
  return PipelineExecutionDebugResponse(**resp)
462
483
 
463
- if api_version == ApiVersion.V1.value:
464
- url = urljoin(
465
- self.base_url, f"{api_version}/StreamSocketConfig/GenerateUrl"
466
- )
467
- request_data = self._prepare_request(
468
- url=url,
469
- payload={"socketIdentifier": resp},
470
- correlation_id=request_data.headers["X-Correlation-ID"],
471
- )
472
- resp = await self._make_request("POST", request_data)
473
-
474
- return PipelineExecutionV1StreamedResponse(**resp)
475
-
476
- return PipelineExecutionV2AsyncStreamedResponse(stream=resp)
484
+ return PipelineExecutionAsyncStreamedResponse(stream=resp)
477
485
 
478
486
  async def get_projects(
479
- self,
480
- correlation_id: Optional[str] = None,
481
- api_version: str = ApiVersion.V1.value,
487
+ self, correlation_id: Optional[str] = None
482
488
  ) -> List[ProjectItem]:
483
489
  """
484
490
  Retrieve a list of all projects accessible to the authenticated user.
@@ -490,8 +496,6 @@ class AiriaAsyncClient(AiriaBaseClient):
490
496
  Args:
491
497
  correlation_id (str, optional): A unique identifier for request tracing
492
498
  and logging. If not provided, one will be automatically generated.
493
- api_version (str, optional): The API version to use for the request.
494
- Defaults to "v1". Valid versions are defined in ApiVersion enum.
495
499
 
496
500
  Returns:
497
501
  List[ProjectItem]: A list of ProjectItem objects containing project
@@ -499,7 +503,6 @@ class AiriaAsyncClient(AiriaBaseClient):
499
503
  or found.
500
504
 
501
505
  Raises:
502
- ValueError: If the provided api_version is not valid.
503
506
  AiriaAPIError: If the API request fails, including cases where:
504
507
  - Authentication fails (401)
505
508
  - Access is forbidden (403)
@@ -528,7 +531,7 @@ class AiriaAsyncClient(AiriaBaseClient):
528
531
  access to.
529
532
  """
530
533
  request_data = self._pre_get_projects(
531
- correlation_id=correlation_id, api_version=api_version
534
+ correlation_id=correlation_id, api_version=ApiVersion.V1.value
532
535
  )
533
536
  resp = await self._make_request("GET", request_data)
534
537
 
@@ -538,10 +541,7 @@ class AiriaAsyncClient(AiriaBaseClient):
538
541
  return [ProjectItem(**item) for item in resp["items"]]
539
542
 
540
543
  async def get_active_pipelines_ids(
541
- self,
542
- project_id: Optional[str] = None,
543
- correlation_id: Optional[str] = None,
544
- api_version: str = ApiVersion.V1.value,
544
+ self, project_id: Optional[str] = None, correlation_id: Optional[str] = None
545
545
  ) -> List[str]:
546
546
  """
547
547
  Retrieve a list of active pipeline IDs.
@@ -555,15 +555,12 @@ class AiriaAsyncClient(AiriaBaseClient):
555
555
  accessible to the authenticated user.
556
556
  correlation_id (str, optional): A unique identifier for request tracing
557
557
  and logging. If not provided, one will be automatically generated.
558
- api_version (str, optional): The API version to use for the request.
559
- Defaults to "v1". Valid versions are defined in ApiVersion enum.
560
558
 
561
559
  Returns:
562
560
  List[str]: A list of pipeline IDs that are currently active. Returns an
563
561
  empty list if no active pipelines are found.
564
562
 
565
563
  Raises:
566
- ValueError: If the provided api_version is not valid.
567
564
  AiriaAPIError: If the API request fails, including cases where:
568
565
  - The project_id doesn't exist (404)
569
566
  - Authentication fails (401)
@@ -594,7 +591,7 @@ class AiriaAsyncClient(AiriaBaseClient):
594
591
  request_data = self._pre_get_active_pipelines_ids(
595
592
  project_id=project_id,
596
593
  correlation_id=correlation_id,
597
- api_version=api_version,
594
+ api_version=ApiVersion.V1.value,
598
595
  )
599
596
  resp = await self._make_request("GET", request_data)
600
597
 
@@ -606,10 +603,7 @@ class AiriaAsyncClient(AiriaBaseClient):
606
603
  return pipeline_ids
607
604
 
608
605
  async def get_pipeline_config(
609
- self,
610
- pipeline_id: str,
611
- correlation_id: Optional[str] = None,
612
- api_version: str = ApiVersion.V1.value,
606
+ self, pipeline_id: str, correlation_id: Optional[str] = None
613
607
  ) -> GetPipelineConfigResponse:
614
608
  """
615
609
  Retrieve configuration details for a specific pipeline.
@@ -620,8 +614,6 @@ class AiriaAsyncClient(AiriaBaseClient):
620
614
  Args:
621
615
  pipeline_id (str): The unique identifier of the pipeline to retrieve
622
616
  configuration for.
623
- api_version (str, optional): The API version to use for the request.
624
- Defaults to "v1". Valid versions are defined in ApiVersion enum.
625
617
  correlation_id (str, optional): A unique identifier for request tracing
626
618
  and logging. If not provided, one will be automatically generated.
627
619
 
@@ -630,7 +622,6 @@ class AiriaAsyncClient(AiriaBaseClient):
630
622
  configuration.
631
623
 
632
624
  Raises:
633
- ValueError: If the provided api_version is not valid.
634
625
  AiriaAPIError: If the API request fails, including cases where:
635
626
  - The pipeline_id doesn't exist (404)
636
627
  - Authentication fails (401)
@@ -661,8 +652,88 @@ class AiriaAsyncClient(AiriaBaseClient):
661
652
  request_data = self._pre_get_pipeline_config(
662
653
  pipeline_id=pipeline_id,
663
654
  correlation_id=correlation_id,
664
- api_version=api_version,
655
+ api_version=ApiVersion.V1.value,
665
656
  )
666
657
  resp = await self._make_request("GET", request_data)
667
658
 
668
659
  return GetPipelineConfigResponse(**resp)
660
+
661
+ async def create_conversation(
662
+ self,
663
+ user_id: str,
664
+ title: Optional[str] = None,
665
+ deployment_id: Optional[str] = None,
666
+ data_source_files: Dict[str, Any] = {},
667
+ is_bookmarked: bool = False,
668
+ correlation_id: Optional[str] = None,
669
+ ) -> CreateConversationResponse:
670
+ """
671
+ Create a new conversation.
672
+
673
+ Args:
674
+ user_id (str): The unique identifier of the user creating the conversation.
675
+ title (str, optional): The title for the conversation. If not provided,
676
+ the conversation will be created without a title.
677
+ deployment_id (str, optional): The unique identifier of the deployment
678
+ to associate with the conversation. If not provided, the conversation
679
+ will not be associated with any specific deployment.
680
+ data_source_files (dict): Configuration for data source files
681
+ to be associated with the conversation. If not provided, no data
682
+ source files will be associated.
683
+ is_bookmarked (bool): Whether the conversation should be bookmarked.
684
+ Defaults to False.
685
+ correlation_id (str, optional): A unique identifier for request tracing
686
+ and logging. If not provided, one will be automatically generated.
687
+
688
+ Returns:
689
+ CreateConversationResponse: A response object containing the created
690
+ conversation details including its ID, creation timestamp, and
691
+ all provided parameters.
692
+
693
+ Raises:
694
+ AiriaAPIError: If the API request fails, including cases where:
695
+ - The user_id doesn't exist (404)
696
+ - The deployment_id is invalid (404)
697
+ - Authentication fails (401)
698
+ - Access is forbidden (403)
699
+ - Server errors (5xx)
700
+
701
+ Example:
702
+ ```python
703
+ from airia import AiriaAsyncClient
704
+
705
+ client = AiriaAsyncClient(api_key="your_api_key")
706
+
707
+ # Create a basic conversation
708
+ conversation = await client.create_conversation(
709
+ user_id="user_123"
710
+ )
711
+ print(f"Created conversation: {conversation.conversation_id}")
712
+
713
+ # Create a conversation with all options
714
+ conversation = await client.create_conversation(
715
+ user_id="user_123",
716
+ title="My Research Session",
717
+ deployment_id="deployment_456",
718
+ data_source_files={"documents": ["doc1.pdf", "doc2.txt"]},
719
+ is_bookmarked=True
720
+ )
721
+ print(f"Created bookmarked conversation: {conversation.conversation_id}")
722
+ ```
723
+
724
+ Note:
725
+ The user_id is required and must correspond to a valid user in the system.
726
+ All other parameters are optional and can be set to None or their default values.
727
+ """
728
+ request_data = self._pre_create_conversation(
729
+ user_id=user_id,
730
+ title=title,
731
+ deployment_id=deployment_id,
732
+ data_source_files=data_source_files,
733
+ is_bookmarked=is_bookmarked,
734
+ correlation_id=correlation_id,
735
+ api_version=ApiVersion.V1.value,
736
+ )
737
+ resp = await self._make_request("POST", request_data)
738
+
739
+ return CreateConversationResponse(**resp)
@@ -5,8 +5,10 @@ from urllib.parse import urljoin
5
5
 
6
6
  import loguru
7
7
 
8
+ from ..constants import DEFAULT_BASE_URL, DEFAULT_TIMEOUT
8
9
  from ..logs import configure_logging, set_correlation_id
9
- from ..types import ApiVersion, RequestData
10
+ from ..types._api_version import ApiVersion
11
+ from ..types._request_data import RequestData
10
12
 
11
13
 
12
14
  class AiriaBaseClient:
@@ -17,9 +19,10 @@ class AiriaBaseClient:
17
19
 
18
20
  def __init__(
19
21
  self,
20
- base_url: str = "https://api.airia.ai/",
22
+ base_url: str = DEFAULT_BASE_URL,
21
23
  api_key: Optional[str] = None,
22
- timeout: float = 30.0,
24
+ bearer_token: Optional[str] = None,
25
+ timeout: float = DEFAULT_TIMEOUT,
23
26
  log_requests: bool = False,
24
27
  custom_logger: Optional["loguru.Logger"] = None,
25
28
  ):
@@ -28,12 +31,15 @@ class AiriaBaseClient:
28
31
 
29
32
  Args:
30
33
  api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
34
+ bearer_token: Bearer token for authentication. Must be provided explicitly (no environment variable fallback).
31
35
  timeout: Request timeout in seconds.
32
36
  log_requests: Whether to log API requests and responses. Default is False.
33
37
  custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
34
38
  """
35
- # Resolve API key: parameter takes precedence over environment variable
36
- self.api_key = self.__class__._get_api_key(api_key)
39
+ # Resolve authentication credentials
40
+ self.api_key, self.bearer_token = self.__class__._resolve_auth_credentials(
41
+ api_key, bearer_token
42
+ )
37
43
 
38
44
  # Store configuration
39
45
  self.base_url = base_url
@@ -44,27 +50,45 @@ class AiriaBaseClient:
44
50
  self.logger = configure_logging() if custom_logger is None else custom_logger
45
51
 
46
52
  @staticmethod
47
- def _get_api_key(api_key: Optional[str] = None):
53
+ def _resolve_auth_credentials(
54
+ api_key: Optional[str] = None, bearer_token: Optional[str] = None
55
+ ):
48
56
  """
49
- Get the API key from either the provided parameter or environment variable.
57
+ Resolve authentication credentials from parameters and environment variables.
50
58
 
51
59
  Args:
52
60
  api_key (Optional[str]): The API key provided as a parameter. Defaults to None.
61
+ bearer_token (Optional[str]): The bearer token provided as a parameter. Defaults to None.
53
62
 
54
63
  Returns:
55
- str: The resolved API key.
64
+ tuple: (api_key, bearer_token) - exactly one will be non-None
56
65
 
57
66
  Raises:
58
- ValueError: If no API key is provided through either method.
67
+ ValueError: If no authentication method is provided or if both are provided.
59
68
  """
60
- api_key = api_key or os.environ.get("AIRIA_API_KEY")
61
-
62
- if not api_key:
69
+ # Check for explicit conflict first
70
+ if api_key and bearer_token:
63
71
  raise ValueError(
64
- "API key must be provided either as a parameter or through the AIRIA_API_KEY environment variable."
72
+ "Cannot provide both api_key and bearer_token. Please use only one authentication method."
65
73
  )
66
74
 
67
- return api_key
75
+ # If bearer token is explicitly provided, use it exclusively
76
+ if bearer_token:
77
+ return None, bearer_token
78
+
79
+ # If API key is explicitly provided, use it exclusively
80
+ if api_key:
81
+ return api_key, None
82
+
83
+ # If neither is provided explicitly, fall back to environment variable
84
+ resolved_api_key = os.environ.get("AIRIA_API_KEY")
85
+ if resolved_api_key:
86
+ return resolved_api_key, None
87
+
88
+ # No authentication method found
89
+ raise ValueError(
90
+ "Authentication required. Provide either api_key (or set AIRIA_API_KEY environment variable) or bearer_token."
91
+ )
68
92
 
69
93
  def _prepare_request(
70
94
  self,
@@ -76,13 +100,18 @@ class AiriaBaseClient:
76
100
  # Set correlation ID if provided or generate a new one
77
101
  correlation_id = set_correlation_id(correlation_id)
78
102
 
79
- # Add the X-API-KEY header and correlation ID
103
+ # Set up base headers
80
104
  headers = {
81
- "X-API-KEY": self.api_key,
82
105
  "X-Correlation-ID": correlation_id,
83
106
  "Content-Type": "application/json",
84
107
  }
85
108
 
109
+ # Add authentication header based on the method used
110
+ if self.api_key:
111
+ headers["X-API-KEY"] = self.api_key
112
+ elif self.bearer_token:
113
+ headers["Authorization"] = f"Bearer {self.bearer_token}"
114
+
86
115
  # Log the request if enabled
87
116
  if self.log_requests:
88
117
  # Create a sanitized copy of headers and params for logging
@@ -92,6 +121,8 @@ class AiriaBaseClient:
92
121
  # Filter out sensitive headers
93
122
  if "X-API-KEY" in log_headers:
94
123
  log_headers["X-API-KEY"] = "[REDACTED]"
124
+ if "Authorization" in log_headers:
125
+ log_headers["Authorization"] = "[REDACTED]"
95
126
 
96
127
  # Process payload for logging
97
128
  log_payload = payload.copy() if payload is not None else {}
@@ -218,6 +249,36 @@ class AiriaBaseClient:
218
249
 
219
250
  return request_data
220
251
 
252
+ def _pre_create_conversation(
253
+ self,
254
+ user_id: str,
255
+ title: Optional[str] = None,
256
+ deployment_id: Optional[str] = None,
257
+ data_source_files: Dict[str, Any] = {},
258
+ is_bookmarked: bool = False,
259
+ correlation_id: Optional[str] = None,
260
+ api_version: str = ApiVersion.V1.value,
261
+ ):
262
+ if api_version not in ApiVersion.as_list():
263
+ raise ValueError(
264
+ f"Invalid API version: {api_version}. Valid versions are: {', '.join(ApiVersion.as_list())}"
265
+ )
266
+ url = urljoin(self.base_url, f"{api_version}/Conversations")
267
+
268
+ payload = {
269
+ "userId": user_id,
270
+ "title": title,
271
+ "deploymentId": deployment_id,
272
+ "dataSourceFiles": data_source_files,
273
+ "isBookmarked": is_bookmarked,
274
+ }
275
+
276
+ request_data = self._prepare_request(
277
+ url=url, payload=payload, correlation_id=correlation_id
278
+ )
279
+
280
+ return request_data
281
+
221
282
  def _pre_get_projects(
222
283
  self,
223
284
  correlation_id: Optional[str] = None,