airia 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,19 +1,24 @@
1
1
  from typing import Any, Dict, List, Literal, Optional, overload
2
- from urllib.parse import urljoin
3
2
 
4
3
  import loguru
5
4
  import requests
6
5
 
6
+ from ..constants import (
7
+ DEFAULT_ANTHROPIC_GATEWAY_URL,
8
+ DEFAULT_BASE_URL,
9
+ DEFAULT_OPENAI_GATEWAY_URL,
10
+ DEFAULT_TIMEOUT,
11
+ )
7
12
  from ..exceptions import AiriaAPIError
8
- from ..types import (
9
- ApiVersion,
13
+ from ..types._api_version import ApiVersion
14
+ from ..types._request_data import RequestData
15
+ from ..types.api import (
16
+ CreateConversationResponse,
10
17
  GetPipelineConfigResponse,
11
18
  PipelineExecutionDebugResponse,
12
19
  PipelineExecutionResponse,
13
- PipelineExecutionV1StreamedResponse,
14
- PipelineExecutionV2StreamedResponse,
20
+ PipelineExecutionStreamedResponse,
15
21
  ProjectItem,
16
- RequestData,
17
22
  )
18
23
  from ..utils.sse_parser import parse_sse_stream_chunked
19
24
  from .base_client import AiriaBaseClient
@@ -24,9 +29,10 @@ class AiriaClient(AiriaBaseClient):
24
29
 
25
30
  def __init__(
26
31
  self,
27
- base_url: str = "https://api.airia.ai/",
32
+ base_url: str = DEFAULT_BASE_URL,
28
33
  api_key: Optional[str] = None,
29
- timeout: float = 30.0,
34
+ bearer_token: Optional[str] = None,
35
+ timeout: float = DEFAULT_TIMEOUT,
30
36
  log_requests: bool = False,
31
37
  custom_logger: Optional["loguru.Logger"] = None,
32
38
  ):
@@ -36,6 +42,7 @@ class AiriaClient(AiriaBaseClient):
36
42
  Args:
37
43
  base_url: Base URL of the Airia API.
38
44
  api_key: API key for authentication. If not provided, will attempt to use AIRIA_API_KEY environment variable.
45
+ bearer_token: Bearer token for authentication. Must be provided explicitly (no environment variable fallback).
39
46
  timeout: Request timeout in seconds.
40
47
  log_requests: Whether to log API requests and responses. Default is False.
41
48
  custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
@@ -43,6 +50,7 @@ class AiriaClient(AiriaBaseClient):
43
50
  super().__init__(
44
51
  base_url=base_url,
45
52
  api_key=api_key,
53
+ bearer_token=bearer_token,
46
54
  timeout=timeout,
47
55
  log_requests=log_requests,
48
56
  custom_logger=custom_logger,
@@ -55,10 +63,10 @@ class AiriaClient(AiriaBaseClient):
55
63
  @classmethod
56
64
  def with_openai_gateway(
57
65
  cls,
58
- base_url: str = "https://api.airia.ai/",
59
- gateway_url: str = "https://gateway.airia.ai/openai/v1",
66
+ base_url: str = DEFAULT_BASE_URL,
67
+ gateway_url: str = DEFAULT_OPENAI_GATEWAY_URL,
60
68
  api_key: Optional[str] = None,
61
- timeout: float = 30.0,
69
+ timeout: float = DEFAULT_TIMEOUT,
62
70
  log_requests: bool = False,
63
71
  custom_logger: Optional["loguru.Logger"] = None,
64
72
  **kwargs,
@@ -77,22 +85,28 @@ class AiriaClient(AiriaBaseClient):
77
85
  """
78
86
  from openai import OpenAI
79
87
 
80
- api_key = cls._get_api_key(api_key)
81
- cls.openai = OpenAI(
88
+ client = cls(
89
+ base_url=base_url,
82
90
  api_key=api_key,
91
+ timeout=timeout,
92
+ log_requests=log_requests,
93
+ custom_logger=custom_logger,
94
+ )
95
+ cls.openai = OpenAI(
96
+ api_key=client.api_key,
83
97
  base_url=gateway_url,
84
98
  **kwargs,
85
99
  )
86
100
 
87
- return cls(base_url, api_key, timeout, log_requests, custom_logger)
101
+ return client
88
102
 
89
103
  @classmethod
90
104
  def with_anthropic_gateway(
91
105
  cls,
92
- base_url: str = "https://api.airia.ai/",
93
- gateway_url: str = "https://gateway.airia.ai/anthropic",
106
+ base_url: str = DEFAULT_BASE_URL,
107
+ gateway_url: str = DEFAULT_ANTHROPIC_GATEWAY_URL,
94
108
  api_key: Optional[str] = None,
95
- timeout: float = 30.0,
109
+ timeout: float = DEFAULT_TIMEOUT,
96
110
  log_requests: bool = False,
97
111
  custom_logger: Optional["loguru.Logger"] = None,
98
112
  **kwargs,
@@ -111,14 +125,47 @@ class AiriaClient(AiriaBaseClient):
111
125
  """
112
126
  from anthropic import Anthropic
113
127
 
114
- api_key = cls._get_api_key(api_key)
115
- cls.anthropic = Anthropic(
128
+ client = cls(
129
+ base_url=base_url,
116
130
  api_key=api_key,
131
+ timeout=timeout,
132
+ log_requests=log_requests,
133
+ custom_logger=custom_logger,
134
+ )
135
+ cls.anthropic = Anthropic(
136
+ api_key=client.api_key,
117
137
  base_url=gateway_url,
118
138
  **kwargs,
119
139
  )
120
140
 
121
- return cls(base_url, api_key, timeout, log_requests, custom_logger)
141
+ return client
142
+
143
+ @classmethod
144
+ def with_bearer_token(
145
+ cls,
146
+ bearer_token: str,
147
+ base_url: str = DEFAULT_BASE_URL,
148
+ timeout: float = DEFAULT_TIMEOUT,
149
+ log_requests: bool = False,
150
+ custom_logger: Optional["loguru.Logger"] = None,
151
+ ):
152
+ """
153
+ Initialize the synchronous Airia API client with bearer token authentication.
154
+
155
+ Args:
156
+ bearer_token: Bearer token for authentication.
157
+ base_url: Base URL of the Airia API.
158
+ timeout: Request timeout in seconds.
159
+ log_requests: Whether to log API requests and responses. Default is False.
160
+ custom_logger: Optional custom logger object to use for logging. If not provided, will use a default logger when `log_requests` is True.
161
+ """
162
+ return cls(
163
+ base_url=base_url,
164
+ bearer_token=bearer_token,
165
+ timeout=timeout,
166
+ log_requests=log_requests,
167
+ custom_logger=custom_logger,
168
+ )
122
169
 
123
170
  def _handle_exception(self, e: requests.HTTPError, url: str, correlation_id: str):
124
171
  # Log the error response if enabled
@@ -141,12 +188,14 @@ class AiriaClient(AiriaBaseClient):
141
188
  # If JSON parsing fails or expected keys are missing
142
189
  error_message = f"API request failed: {str(e)}"
143
190
 
144
- # Make sure API key is not included in error messages
145
- sanitized_message = (
146
- error_message.replace(self.api_key, "[REDACTED]")
147
- if self.api_key in error_message
148
- else error_message
149
- )
191
+ # Make sure sensitive auth information is not included in error messages
192
+ sanitized_message = error_message
193
+ if self.api_key and self.api_key in sanitized_message:
194
+ sanitized_message = sanitized_message.replace(self.api_key, "[REDACTED]")
195
+ if self.bearer_token and self.bearer_token in sanitized_message:
196
+ sanitized_message = sanitized_message.replace(
197
+ self.bearer_token, "[REDACTED]"
198
+ )
150
199
 
151
200
  # Raise custom exception with status code and sanitized message
152
201
  raise AiriaAPIError(
@@ -278,7 +327,6 @@ class AiriaClient(AiriaBaseClient):
278
327
  additional_info: Optional[List[Any]] = None,
279
328
  prompt_variables: Optional[Dict[str, Any]] = None,
280
329
  correlation_id: Optional[str] = None,
281
- api_version: str = ApiVersion.V2.value,
282
330
  ) -> PipelineExecutionResponse: ...
283
331
 
284
332
  @overload
@@ -301,7 +349,6 @@ class AiriaClient(AiriaBaseClient):
301
349
  additional_info: Optional[List[Any]] = None,
302
350
  prompt_variables: Optional[Dict[str, Any]] = None,
303
351
  correlation_id: Optional[str] = None,
304
- api_version: str = ApiVersion.V2.value,
305
352
  ) -> PipelineExecutionDebugResponse: ...
306
353
 
307
354
  @overload
@@ -324,31 +371,7 @@ class AiriaClient(AiriaBaseClient):
324
371
  additional_info: Optional[List[Any]] = None,
325
372
  prompt_variables: Optional[Dict[str, Any]] = None,
326
373
  correlation_id: Optional[str] = None,
327
- api_version: Literal["v2"] = ApiVersion.V2.value,
328
- ) -> PipelineExecutionV2StreamedResponse: ...
329
-
330
- @overload
331
- def execute_pipeline(
332
- self,
333
- pipeline_id: str,
334
- user_input: str,
335
- debug: bool = False,
336
- user_id: Optional[str] = None,
337
- conversation_id: Optional[str] = None,
338
- async_output: Literal[True] = True,
339
- include_tools_response: bool = False,
340
- images: Optional[List[str]] = None,
341
- files: Optional[List[str]] = None,
342
- data_source_folders: Optional[Dict[str, Any]] = None,
343
- data_source_files: Optional[Dict[str, Any]] = None,
344
- in_memory_messages: Optional[List[Dict[str, str]]] = None,
345
- current_date_time: Optional[str] = None,
346
- save_history: bool = True,
347
- additional_info: Optional[List[Any]] = None,
348
- prompt_variables: Optional[Dict[str, Any]] = None,
349
- correlation_id: Optional[str] = None,
350
- api_version: Literal["v1"] = ApiVersion.V1.value,
351
- ) -> PipelineExecutionV1StreamedResponse: ...
374
+ ) -> PipelineExecutionStreamedResponse: ...
352
375
 
353
376
  def execute_pipeline(
354
377
  self,
@@ -369,7 +392,6 @@ class AiriaClient(AiriaBaseClient):
369
392
  additional_info: Optional[List[Any]] = None,
370
393
  prompt_variables: Optional[Dict[str, Any]] = None,
371
394
  correlation_id: Optional[str] = None,
372
- api_version: str = ApiVersion.V2.value,
373
395
  ):
374
396
  """
375
397
  Execute a pipeline with the provided input.
@@ -393,7 +415,6 @@ class AiriaClient(AiriaBaseClient):
393
415
  prompt_variables: Optional variables to be used in the prompt.
394
416
  correlation_id: Optional correlation ID for request tracing. If not provided,
395
417
  one will be generated automatically.
396
- api_version: API version to use. Default is `v2`
397
418
 
398
419
  Returns:
399
420
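  The API response as a typed response model: PipelineExecutionResponse, PipelineExecutionDebugResponse, or PipelineExecutionStreamedResponse.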
@@ -428,12 +449,11 @@ class AiriaClient(AiriaBaseClient):
428
449
  additional_info=additional_info,
429
450
  prompt_variables=prompt_variables,
430
451
  correlation_id=correlation_id,
431
- api_version=api_version,
452
+ api_version=ApiVersion.V2.value,
432
453
  )
433
- stream = async_output and api_version == ApiVersion.V2.value
434
454
  resp = (
435
455
  self._make_request_stream("POST", request_data)
436
- if stream
456
+ if async_output
437
457
  else self._make_request("POST", request_data)
438
458
  )
439
459
 
@@ -442,26 +462,9 @@ class AiriaClient(AiriaBaseClient):
442
462
  return PipelineExecutionResponse(**resp)
443
463
  return PipelineExecutionDebugResponse(**resp)
444
464
 
445
- if api_version == ApiVersion.V1.value:
446
- url = urljoin(
447
- self.base_url, f"{api_version}/StreamSocketConfig/GenerateUrl"
448
- )
449
- request_data = self._prepare_request(
450
- url,
451
- payload={"socketIdentifier": resp},
452
- correlation_id=request_data.headers["X-Correlation-ID"],
453
- )
454
- resp = self._make_request("POST", request_data)
455
-
456
- return PipelineExecutionV1StreamedResponse(**resp)
457
-
458
- return PipelineExecutionV2StreamedResponse(stream=resp)
465
+ return PipelineExecutionStreamedResponse(stream=resp)
459
466
 
460
- def get_projects(
461
- self,
462
- correlation_id: Optional[str] = None,
463
- api_version: str = ApiVersion.V1.value,
464
- ) -> List[ProjectItem]:
467
+ def get_projects(self, correlation_id: Optional[str] = None) -> List[ProjectItem]:
465
468
  """
466
469
  Retrieve a list of all projects accessible to the authenticated user.
467
470
 
@@ -472,8 +475,6 @@ class AiriaClient(AiriaBaseClient):
472
475
  Args:
473
476
  correlation_id (str, optional): A unique identifier for request tracing
474
477
  and logging. If not provided, one will be automatically generated.
475
- api_version (str, optional): The API version to use for the request.
476
- Defaults to "v1". Valid versions are defined in ApiVersion enum.
477
478
 
478
479
  Returns:
479
480
  List[ProjectItem]: A list of ProjectItem objects containing project
@@ -481,7 +482,6 @@ class AiriaClient(AiriaBaseClient):
481
482
  or found.
482
483
 
483
484
  Raises:
484
- ValueError: If the provided api_version is not valid.
485
485
  AiriaAPIError: If the API request fails, including cases where:
486
486
  - Authentication fails (401)
487
487
  - Access is forbidden (403)
@@ -510,7 +510,7 @@ class AiriaClient(AiriaBaseClient):
510
510
  access to.
511
511
  """
512
512
  request_data = self._pre_get_projects(
513
- correlation_id=correlation_id, api_version=api_version
513
+ correlation_id=correlation_id, api_version=ApiVersion.V1.value
514
514
  )
515
515
  resp = self._make_request("GET", request_data)
516
516
 
@@ -520,10 +520,7 @@ class AiriaClient(AiriaBaseClient):
520
520
  return [ProjectItem(**item) for item in resp["items"]]
521
521
 
522
522
  def get_active_pipelines_ids(
523
- self,
524
- project_id: Optional[str] = None,
525
- correlation_id: Optional[str] = None,
526
- api_version: str = ApiVersion.V1.value,
523
+ self, project_id: Optional[str] = None, correlation_id: Optional[str] = None
527
524
  ) -> List[str]:
528
525
  """
529
526
  Retrieve a list of active pipeline IDs.
@@ -537,15 +534,12 @@ class AiriaClient(AiriaBaseClient):
537
534
  accessible to the authenticated user.
538
535
  correlation_id (str, optional): A unique identifier for request tracing
539
536
  and logging. If not provided, one will be automatically generated.
540
- api_version (str, optional): The API version to use for the request.
541
- Defaults to "v1". Valid versions are defined in ApiVersion enum.
542
537
 
543
538
  Returns:
544
539
  List[str]: A list of pipeline IDs that are currently active. Returns an
545
540
  empty list if no active pipelines are found.
546
541
 
547
542
  Raises:
548
- ValueError: If the provided api_version is not valid.
549
543
  AiriaAPIError: If the API request fails, including cases where:
550
544
  - The project_id doesn't exist (404)
551
545
  - Authentication fails (401)
@@ -576,7 +570,7 @@ class AiriaClient(AiriaBaseClient):
576
570
  request_data = self._pre_get_active_pipelines_ids(
577
571
  project_id=project_id,
578
572
  correlation_id=correlation_id,
579
- api_version=api_version,
573
+ api_version=ApiVersion.V1.value,
580
574
  )
581
575
  resp = self._make_request("GET", request_data)
582
576
 
@@ -588,10 +582,7 @@ class AiriaClient(AiriaBaseClient):
588
582
  return pipeline_ids
589
583
 
590
584
  def get_pipeline_config(
591
- self,
592
- pipeline_id: str,
593
- correlation_id: Optional[str] = None,
594
- api_version: str = ApiVersion.V1.value,
585
+ self, pipeline_id: str, correlation_id: Optional[str] = None
595
586
  ) -> GetPipelineConfigResponse:
596
587
  """
597
588
  Retrieve configuration details for a specific pipeline.
@@ -602,8 +593,6 @@ class AiriaClient(AiriaBaseClient):
602
593
  Args:
603
594
  pipeline_id (str): The unique identifier of the pipeline to retrieve
604
595
  configuration for.
605
- api_version (str, optional): The API version to use for the request.
606
- Defaults to "v1". Valid versions are defined in ApiVersion enum.
607
596
  correlation_id (str, optional): A unique identifier for request tracing
608
597
  and logging. If not provided, one will be automatically generated.
609
598
 
@@ -612,7 +601,6 @@ class AiriaClient(AiriaBaseClient):
612
601
  configuration.
613
602
 
614
603
  Raises:
615
- ValueError: If the provided api_version is not valid.
616
604
  AiriaAPIError: If the API request fails, including cases where:
617
605
  - The pipeline_id doesn't exist (404)
618
606
  - Authentication fails (401)
@@ -643,8 +631,88 @@ class AiriaClient(AiriaBaseClient):
643
631
  request_data = self._pre_get_pipeline_config(
644
632
  pipeline_id=pipeline_id,
645
633
  correlation_id=correlation_id,
646
- api_version=api_version,
634
+ api_version=ApiVersion.V1.value,
647
635
  )
648
636
  resp = self._make_request("GET", request_data)
649
637
 
650
638
  return GetPipelineConfigResponse(**resp)
639
+
640
+ def create_conversation(
641
+ self,
642
+ user_id: str,
643
+ title: Optional[str] = None,
644
+ deployment_id: Optional[str] = None,
645
+ data_source_files: Dict[str, Any] = {},
646
+ is_bookmarked: bool = False,
647
+ correlation_id: Optional[str] = None,
648
+ ) -> CreateConversationResponse:
649
+ """
650
+ Create a new conversation.
651
+
652
+ Args:
653
+ user_id (str): The unique identifier of the user creating the conversation.
654
+ title (str, optional): The title for the conversation. If not provided,
655
+ the conversation will be created without a title.
656
+ deployment_id (str, optional): The unique identifier of the deployment
657
+ to associate with the conversation. If not provided, the conversation
658
+ will not be associated with any specific deployment.
659
+ data_source_files (dict): Configuration for data source files
660
+ to be associated with the conversation. If not provided, no data
661
+ source files will be associated.
662
+ is_bookmarked (bool): Whether the conversation should be bookmarked.
663
+ Defaults to False.
664
+ correlation_id (str, optional): A unique identifier for request tracing
665
+ and logging. If not provided, one will be automatically generated.
666
+
667
+ Returns:
668
+ CreateConversationResponse: A response object containing the created
669
+ conversation details, including the conversation ID, the websocket URL,
670
+ and the associated user and deployment IDs.
671
+
672
+ Raises:
673
+ AiriaAPIError: If the API request fails, including cases where:
674
+ - The user_id doesn't exist (404)
675
+ - The deployment_id is invalid (404)
676
+ - Authentication fails (401)
677
+ - Access is forbidden (403)
678
+ - Server errors (5xx)
679
+
680
+ Example:
681
+ ```python
682
+ from airia import AiriaClient
683
+
684
+ client = AiriaClient(api_key="your_api_key")
685
+
686
+ # Create a basic conversation
687
+ conversation = client.create_conversation(
688
+ user_id="user_123"
689
+ )
690
+ print(f"Created conversation: {conversation.conversation_id}")
691
+
692
+ # Create a conversation with all options
693
+ conversation = client.create_conversation(
694
+ user_id="user_123",
695
+ title="My Research Session",
696
+ deployment_id="deployment_456",
697
+ data_source_files={"documents": ["doc1.pdf", "doc2.txt"]},
698
+ is_bookmarked=True
699
+ )
700
+ print(f"Created bookmarked conversation: {conversation.conversation_id}")
701
+ ```
702
+
703
+ Note:
704
+ The user_id is required and must correspond to a valid user in the system.
705
+ All other parameters are optional and may be omitted to use their default values.
706
+ """
707
+ request_data = self._pre_create_conversation(
708
+ user_id=user_id,
709
+ title=title,
710
+ deployment_id=deployment_id,
711
+ data_source_files=data_source_files,
712
+ is_bookmarked=is_bookmarked,
713
+ correlation_id=correlation_id,
714
+ api_version=ApiVersion.V1.value,
715
+ )
716
+ resp = self._make_request("POST", request_data)
717
+
718
+ return CreateConversationResponse(**resp)
airia/constants.py ADDED
@@ -0,0 +1,9 @@
1
+ """Constants used throughout the Airia SDK."""
2
+
3
+ # Default API endpoints
4
+ DEFAULT_BASE_URL = "https://api.airia.ai/"
5
+ DEFAULT_OPENAI_GATEWAY_URL = "https://gateway.airia.ai/openai/v1"
6
+ DEFAULT_ANTHROPIC_GATEWAY_URL = "https://gateway.airia.ai/anthropic"
7
+
8
+ # Default timeouts
9
+ DEFAULT_TIMEOUT = 30.0
@@ -0,0 +1,19 @@
1
+ from .get_projects import ProjectItem
2
+ from .get_pipeline_config import GetPipelineConfigResponse
3
+ from .pipeline_execution import (
4
+ PipelineExecutionDebugResponse,
5
+ PipelineExecutionResponse,
6
+ PipelineExecutionAsyncStreamedResponse,
7
+ PipelineExecutionStreamedResponse,
8
+ )
9
+ from .conversations import CreateConversationResponse
10
+
11
+ __all__ = [
12
+ "PipelineExecutionDebugResponse",
13
+ "PipelineExecutionResponse",
14
+ "PipelineExecutionStreamedResponse",
15
+ "PipelineExecutionAsyncStreamedResponse",
16
+ "GetPipelineConfigResponse",
17
+ "ProjectItem",
18
+ "CreateConversationResponse",
19
+ ]
@@ -0,0 +1,14 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+
6
+ class CreateConversationResponse(BaseModel):
7
+ user_id: str = Field(alias="userId")
8
+ conversation_id: str = Field(alias="conversationId")
9
+ websocket_url: str = Field(alias="websocketUrl")
10
+ deployment_id: str = Field(alias="deploymentId")
11
+ icon_id: Optional[str] = Field(None, alias="iconId")
12
+ icon_url: Optional[str] = Field(None, alias="iconUrl")
13
+ description: Optional[str] = None
14
+ space_name: Optional[str] = Field(None, alias="spaceName")
@@ -1,33 +1,29 @@
1
1
  from typing import Any, AsyncIterator, Dict, Iterator
2
2
 
3
- from pydantic import BaseModel, ConfigDict
3
+ from pydantic import BaseModel, ConfigDict, Field
4
4
 
5
- from ..sse_messages import SSEMessage
5
+ from ..sse import SSEMessage
6
6
 
7
7
 
8
8
  class PipelineExecutionResponse(BaseModel):
9
9
  result: str
10
10
  report: None
11
- isBackupPipeline: bool
11
+ is_backup_pipeline: bool = Field(alias="isBackupPipeline")
12
12
 
13
13
 
14
14
  class PipelineExecutionDebugResponse(BaseModel):
15
15
  result: str
16
16
  report: Dict[str, Any]
17
- isBackupPipeline: bool
17
+ is_backup_pipeline: bool = Field(alias="isBackupPipeline")
18
18
 
19
19
 
20
- class PipelineExecutionV1StreamedResponse(BaseModel):
21
- webSocketUrl: str
22
-
23
-
24
- class PipelineExecutionV2StreamedResponse(BaseModel):
20
+ class PipelineExecutionStreamedResponse(BaseModel):
25
21
  model_config = ConfigDict(arbitrary_types_allowed=True)
26
22
 
27
23
  stream: Iterator[SSEMessage]
28
24
 
29
25
 
30
- class PipelineExecutionV2AsyncStreamedResponse(BaseModel):
26
+ class PipelineExecutionAsyncStreamedResponse(BaseModel):
31
27
  model_config = ConfigDict(arbitrary_types_allowed=True)
32
28
 
33
29
  stream: AsyncIterator[SSEMessage]
@@ -1,15 +1,6 @@
1
- from .api.get_projects import ProjectItem
2
- from .api.get_pipeline_config import GetPipelineConfigResponse
3
- from .api.pipeline_execution import (
4
- PipelineExecutionDebugResponse,
5
- PipelineExecutionResponse,
6
- PipelineExecutionV1StreamedResponse,
7
- PipelineExecutionV2AsyncStreamedResponse,
8
- PipelineExecutionV2StreamedResponse,
9
- )
10
- from .api_version import ApiVersion
11
- from .request_data import RequestData
12
1
  from .sse_messages import (
2
+ SSEDict,
3
+ SSEMessage,
13
4
  AgentAgentCardMessage,
14
5
  AgentAgentCardStreamEndMessage,
15
6
  AgentAgentCardStreamErrorMessage,
@@ -37,15 +28,8 @@ from .sse_messages import (
37
28
  )
38
29
 
39
30
  __all__ = [
40
- "ApiVersion",
41
- "PipelineExecutionDebugResponse",
42
- "PipelineExecutionResponse",
43
- "PipelineExecutionV1StreamedResponse",
44
- "PipelineExecutionV2AsyncStreamedResponse",
45
- "PipelineExecutionV2StreamedResponse",
46
- "GetPipelineConfigResponse",
47
- "ProjectItem",
48
- "RequestData",
31
+ "SSEDict",
32
+ "SSEMessage",
49
33
  "AgentPingMessage",
50
34
  "AgentStartMessage",
51
35
  "AgentEndMessage",
airia/utils/sse_parser.py CHANGED
@@ -2,7 +2,7 @@ import json
2
2
  import re
3
3
  from typing import AsyncIterable, AsyncIterator, Iterable, Iterator
4
4
 
5
- from ..types.sse_messages import SSEDict, SSEMessage
5
+ from ..types.sse import SSEDict, SSEMessage
6
6
 
7
7
 
8
8
  def _to_snake_case(name: str):