celeste-ai 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of celeste-ai might be problematic. Click here for more details.
- celeste/__init__.py +108 -17
- celeste/artifacts.py +52 -0
- celeste/client.py +150 -0
- celeste/constraints.py +180 -0
- celeste/core.py +52 -0
- celeste/credentials.py +108 -0
- celeste/http.py +201 -0
- celeste/io.py +43 -0
- celeste/mime_types.py +47 -0
- celeste/models.py +91 -0
- celeste/parameters.py +51 -0
- celeste/py.typed +1 -0
- celeste/streaming.py +114 -0
- celeste_ai-0.0.3.dist-info/METADATA +163 -0
- celeste_ai-0.0.3.dist-info/RECORD +17 -0
- celeste_ai-0.0.3.dist-info/licenses/LICENSE +201 -0
- celeste_ai-0.0.1.dist-info/METADATA +0 -59
- celeste_ai-0.0.1.dist-info/RECORD +0 -4
- {celeste_ai-0.0.1.dist-info → celeste_ai-0.0.3.dist-info}/WHEEL +0 -0
celeste/http.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
"""HTTP client with persistent connection pooling for AI provider APIs."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from collections.abc import AsyncIterator
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
from httpx_sse import aconnect_sse
|
|
10
|
+
|
|
11
|
+
from celeste.core import Capability, Provider
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)

# Default connection-pool limits and request timeout used by HTTPClient.
MAX_CONNECTIONS: int = 20  # upper bound on simultaneous connections per pool
MAX_KEEPALIVE_CONNECTIONS: int = 10  # idle connections kept open for reuse
DEFAULT_TIMEOUT: float = 60.0  # per-request timeout, in seconds
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class HTTPClient:
    """Async HTTP client backed by a lazily created, pooled ``httpx.AsyncClient``.

    The underlying ``httpx.AsyncClient`` is built on first use with the
    configured pool limits and is reused by every request until :meth:`aclose`
    is called; a later request then builds a fresh one.
    """

    def __init__(
        self,
        max_connections: int = MAX_CONNECTIONS,
        max_keepalive_connections: int = MAX_KEEPALIVE_CONNECTIONS,
    ) -> None:
        """Initialize HTTP client with connection pool limits.

        No network resources are allocated here; the pooled client is created
        on the first request.

        Args:
            max_connections: Maximum total connections in pool.
            max_keepalive_connections: Maximum idle keepalive connections.
        """
        self._client: httpx.AsyncClient | None = None
        self._max_connections = max_connections
        self._max_keepalive_connections = max_keepalive_connections

    async def _get_client(self) -> httpx.AsyncClient:
        """Return the pooled ``httpx.AsyncClient``, creating it on first use."""
        if self._client is None:
            limits = httpx.Limits(
                max_connections=self._max_connections,
                max_keepalive_connections=self._max_keepalive_connections,
            )
            # Every public method passes an explicit per-request timeout, so no
            # client-wide timeout is configured here.
            self._client = httpx.AsyncClient(limits=limits)  # nosec B113
        return self._client

    async def post(
        self,
        url: str,
        headers: dict[str, str],
        json_body: dict[str, Any],
        timeout: float = DEFAULT_TIMEOUT,
    ) -> httpx.Response:
        """Make POST request with connection pooling.

        Args:
            url: Full URL to POST to.
            headers: HTTP headers including authentication.
            json_body: JSON request body.
            timeout: Request timeout in seconds.

        Returns:
            HTTP response from the server.

        Raises:
            httpx.HTTPError: On network or timeout errors.
        """
        client = await self._get_client()
        return await client.post(
            url,
            headers=headers,
            json=json_body,
            timeout=timeout,
        )

    async def get(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float = DEFAULT_TIMEOUT,
        follow_redirects: bool = True,
    ) -> httpx.Response:
        """Make GET request with connection pooling.

        Args:
            url: Full URL to GET.
            headers: HTTP headers including authentication (optional).
            timeout: Request timeout in seconds.
            follow_redirects: Whether to follow HTTP redirects (default: True).

        Returns:
            HTTP response from the server.

        Raises:
            httpx.HTTPError: On network or timeout errors.
        """
        client = await self._get_client()
        return await client.get(
            url,
            headers=headers or {},
            timeout=timeout,
            follow_redirects=follow_redirects,
        )

    async def stream_post(
        self,
        url: str,
        headers: dict[str, str],
        json_body: dict[str, Any],
        timeout: float = DEFAULT_TIMEOUT,
    ) -> AsyncIterator[dict[str, Any]]:
        """Stream POST request using Server-Sent Events.

        Args:
            url: API endpoint URL.
            headers: HTTP headers (including authentication).
            json_body: JSON request body.
            timeout: Timeout in seconds (default: DEFAULT_TIMEOUT).

        Yields:
            Parsed JSON events from SSE stream. Events whose data is not valid
            JSON are skipped.
        """
        client = await self._get_client()

        async with aconnect_sse(
            client,
            "POST",
            url,
            json=json_body,
            headers=headers,
            timeout=timeout,
        ) as event_source:
            async for sse in event_source.aiter_sse():
                try:
                    yield json.loads(sse.data)
                except json.JSONDecodeError:
                    continue  # Skip non-JSON control messages (provider-agnostic)

    async def aclose(self) -> None:
        """Close the pooled client and release all of its connections.

        The client reference is detached *before* awaiting the close. A failing
        close therefore cannot leave a half-closed client cached for reuse, and
        requests issued while the close is in progress get a fresh client
        instead of the one being shut down. Calling this when no client exists
        is a no-op.
        """
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    async def __aenter__(self) -> "HTTPClient":
        """Enter async context manager."""
        return self

    async def __aexit__(self, *args: Any) -> None:  # noqa: ANN401
        """Exit async context manager and cleanup connections."""
        await self.aclose()
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# Shared HTTPClient instances, one per (provider, capability) pair. Entries are
# added on demand by get_http_client().
_http_clients: dict[tuple[Provider, Capability], HTTPClient] = dict()
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def get_http_client(provider: Provider, capability: Capability) -> HTTPClient:
    """Return the shared HTTPClient for a provider/capability pair.

    The first call for a given pair creates a client and stores it in the
    module registry; every later call for that pair returns the same instance.

    Args:
        provider: The AI provider.
        capability: The capability being used.

    Returns:
        Shared HTTPClient instance for this provider and capability.
    """
    registry_key = (provider, capability)
    try:
        return _http_clients[registry_key]
    except KeyError:
        fresh = _http_clients[registry_key] = HTTPClient()
        return fresh
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
async def close_all_http_clients() -> None:
    """Close every registered HTTP client and empty the registry.

    The registry is emptied *before* any client is awaited. Callers that ask
    for a client while shutdown is in progress therefore get a fresh instance
    that stays registered, rather than one that is about to be closed or
    dropped unclosed. A failure to close one client is logged and does not
    stop the remaining clients from being closed.
    """
    pending = list(_http_clients.items())
    _http_clients.clear()

    for key, client in pending:
        try:
            await client.aclose()
        except Exception as e:
            # Best-effort shutdown: report and keep closing the rest.
            logger.warning("Failed to close HTTP client for %s: %s", key, e)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def clear_http_clients() -> None:
    """Forget every registered HTTP client without closing it.

    Only the registry is emptied; the dropped clients are not closed, so any
    connections they hold remain open. Use ``close_all_http_clients`` to close
    them as well.
    """
    _http_clients.clear()
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
__all__ = [
    # configuration defaults
    "DEFAULT_TIMEOUT", "MAX_CONNECTIONS", "MAX_KEEPALIVE_CONNECTIONS",
    # client
    "HTTPClient",
    # shared-client registry helpers
    "clear_http_clients", "close_all_http_clients", "get_http_client",
]
|
celeste/io.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
"""Input and output types for generation operations."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Input(BaseModel):
    """Common base for the per-capability input models.

    Declares no fields itself; capability-specific subclasses add them.
    """
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FinishReason(BaseModel):
    """Common base for per-capability finish reasons carried by streaming chunks.

    Declares no fields itself; capability-specific subclasses add them.
    """
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Usage(BaseModel):
    """Common base for per-capability usage metrics.

    Declares no fields itself; capability-specific subclasses add them.
    """
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Output[Content](BaseModel):
|
|
27
|
+
"""Base output class with generic content type."""
|
|
28
|
+
|
|
29
|
+
content: Content
|
|
30
|
+
usage: Usage = Field(default_factory=Usage)
|
|
31
|
+
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Chunk[Content](BaseModel):
|
|
35
|
+
"""Incremental chunk from streaming response with generic content type."""
|
|
36
|
+
|
|
37
|
+
content: Content
|
|
38
|
+
finish_reason: FinishReason | None = None
|
|
39
|
+
usage: Usage | None = None
|
|
40
|
+
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
__all__ = [
    "Chunk",
    "FinishReason",
    "Input",
    "Output",
    "Usage",
]
|
celeste/mime_types.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""MIME type enumerations for Celeste."""
|
|
2
|
+
|
|
3
|
+
from enum import StrEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MimeType(StrEnum):
|
|
7
|
+
"""Base class for all MIME types."""
|
|
8
|
+
|
|
9
|
+
pass
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ImageMimeType(MimeType):
|
|
13
|
+
"""Standard MIME types for images."""
|
|
14
|
+
|
|
15
|
+
PNG = "image/png"
|
|
16
|
+
JPEG = "image/jpeg"
|
|
17
|
+
WEBP = "image/webp"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class VideoMimeType(MimeType):
|
|
21
|
+
"""Standard MIME types for videos."""
|
|
22
|
+
|
|
23
|
+
MP4 = "video/mp4"
|
|
24
|
+
AVI = "video/x-msvideo"
|
|
25
|
+
MOV = "video/quicktime"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AudioMimeType(MimeType):
|
|
29
|
+
"""Standard MIME types for audio."""
|
|
30
|
+
|
|
31
|
+
MP3 = "audio/mpeg"
|
|
32
|
+
WAV = "audio/wav"
|
|
33
|
+
OGG = "audio/ogg"
|
|
34
|
+
WEBM = "audio/webm"
|
|
35
|
+
AAC = "audio/aac"
|
|
36
|
+
FLAC = "audio/flac"
|
|
37
|
+
AIFF = "audio/aiff"
|
|
38
|
+
M4A = "audio/mp4"
|
|
39
|
+
WMA = "audio/x-ms-wma"
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
__all__ = ["AudioMimeType", "ImageMimeType", "MimeType", "VideoMimeType"]
|
celeste/models.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
"""Models and model registry for Celeste."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, Field
|
|
4
|
+
|
|
5
|
+
from celeste.constraints import Constraint
|
|
6
|
+
from celeste.core import Capability, Provider
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Model(BaseModel):
    """An AI model offered by a provider, with its capabilities and parameter constraints."""

    id: str
    provider: Provider
    capabilities: set[Capability] = Field(default_factory=set)
    display_name: str
    # Parameter name -> Constraint applied to values supplied for that parameter.
    parameter_constraints: dict[str, Constraint] = Field(default_factory=dict)

    @property
    def supported_parameters(self) -> set[str]:
        """Names of the parameters this model accepts (the keys of ``parameter_constraints``)."""
        return set(self.parameter_constraints)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Global model registry keyed by (model_id, provider), so the same model id may
# be registered once per provider.
_models: dict[tuple[str, Provider], Model] = dict()
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def register_models(models: Model | list[Model]) -> None:
    """Register one or more models in the global registry.

    Registration is all-or-nothing. The whole batch is validated before
    anything is stored, so a duplicate anywhere in it leaves the registry
    unchanged.

    Args:
        models: Single Model instance or list of Models to register.
            Each model is indexed by (model_id, provider) tuple.

    Raises:
        ValueError: If a model with the same (id, provider) is already
            registered, or appears more than once in ``models``.
    """
    batch = [models] if isinstance(models, Model) else list(models)

    # First pass: validate against the registry and within the batch itself.
    seen: set[tuple[str, Provider]] = set()
    for model in batch:
        key = (model.id, model.provider)
        if key in _models or key in seen:
            msg = f"Model '{model.id}' for provider {model.provider.value} is already registered"
            raise ValueError(msg)
        seen.add(key)

    # Second pass: commit, now that no entry can fail.
    for model in batch:
        _models[(model.id, model.provider)] = model
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_model(model_id: str, provider: Provider) -> Model | None:
    """Look up a registered model by its identifier and owning provider.

    Args:
        model_id: The model identifier.
        provider: The provider that owns the model.

    Returns:
        The registered Model, or None when no such model is registered.
    """
    lookup_key = (model_id, provider)
    return _models[lookup_key] if lookup_key in _models else None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def list_models(
    provider: Provider | None = None,
    capability: Capability | None = None,
) -> list[Model]:
    """Return registered models, optionally narrowed by provider and/or capability.

    Args:
        provider: When given, keep only models from this provider.
        capability: When given, keep only models that list this capability.

    Returns:
        Matching models, in registration order.
    """
    return [
        candidate
        for candidate in _models.values()
        if (provider is None or candidate.provider == provider)
        and (capability is None or capability in candidate.capabilities)
    ]
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def clear() -> None:
    """Remove every model from the global registry."""
    _models.clear()
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
__all__ = [
    "Model",
    "clear",
    "get_model",
    "list_models",
    "register_models",
]
|
celeste/parameters.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""Parameter system for Celeste."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from enum import StrEnum
|
|
5
|
+
from typing import Any, TypedDict
|
|
6
|
+
|
|
7
|
+
from celeste.models import Model
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Parameters(TypedDict, total=False):
    """Shared base for capability parameter dicts; every key is optional (total=False)."""

    # Intentionally empty: capability-specific subclasses declare the keys.
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ParameterMapper(ABC):
    """Translate one Celeste parameter into a provider's request format.

    Subclasses set ``name`` and implement :meth:`map`. They may override
    :meth:`parse_output` to post-process parsed content based on the
    parameter's value.
    """

    # Parameter key handled by this mapper. It must be a StrEnum member whose
    # value matches the corresponding key of the capability's TypedDict.
    name: StrEnum

    @abstractmethod
    def map(self, request: dict[str, Any], value: Any, model: Model) -> dict[str, Any]:  # noqa: ANN401
        """Write ``value`` into the provider request and return the request.

        Args:
            request: Provider request dict.
            value: Parameter value.
            model: Model instance containing parameter_constraints for validation.

        Returns:
            Updated request dict.
        """
        ...

    def parse_output(self, content: object, value: object | None) -> object:
        """Hook for post-processing parsed content; the default returns ``content`` as-is."""
        return content

    def _validate_value(self, value: Any, model: Model) -> Any:  # noqa: ANN401
        """Run ``value`` through the model's constraint for this parameter.

        None passes through untouched.

        Raises:
            ValueError: If the model defines no constraint for this parameter.
        """
        if value is None:
            return None

        constraint = model.parameter_constraints.get(self.name)
        if constraint is not None:
            return constraint(value)

        msg = f"Parameter {self.name.value} is not supported by model {model.id}"
        raise ValueError(msg)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
__all__ = [
    "ParameterMapper",
    "Parameters",
]
|
celeste/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Marker file for PEP 561 - this package supports type checking
|
celeste/streaming.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
"""Streaming support for Celeste."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from collections.abc import AsyncIterator
|
|
5
|
+
from types import TracebackType
|
|
6
|
+
from typing import Any, Self
|
|
7
|
+
|
|
8
|
+
from celeste.io import Chunk, Output
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Stream[Out: Output](ABC):
|
|
12
|
+
"""Async iterator wrapper providing final Output access after stream exhaustion."""
|
|
13
|
+
|
|
14
|
+
def __init__(
|
|
15
|
+
self,
|
|
16
|
+
sse_iterator: AsyncIterator[dict[str, Any]],
|
|
17
|
+
) -> None:
|
|
18
|
+
"""Initialize stream with SSE iterator."""
|
|
19
|
+
self._sse_iterator = sse_iterator
|
|
20
|
+
self._chunks: list[Chunk] = []
|
|
21
|
+
self._closed = False
|
|
22
|
+
self._output: Out | None = None
|
|
23
|
+
|
|
24
|
+
@abstractmethod
|
|
25
|
+
def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
|
|
26
|
+
"""Parse SSE event into Chunk (returns None to filter lifecycle events)."""
|
|
27
|
+
...
|
|
28
|
+
|
|
29
|
+
@abstractmethod
|
|
30
|
+
def _parse_output(self, chunks: list[Chunk]) -> Out:
|
|
31
|
+
"""Parse final Output from accumulated chunks."""
|
|
32
|
+
...
|
|
33
|
+
|
|
34
|
+
def __repr__(self) -> str:
|
|
35
|
+
"""Developer-friendly representation showing stream state."""
|
|
36
|
+
if self._output:
|
|
37
|
+
state = "done"
|
|
38
|
+
elif self._closed:
|
|
39
|
+
state = "closed"
|
|
40
|
+
elif self._chunks:
|
|
41
|
+
state = "streaming"
|
|
42
|
+
else:
|
|
43
|
+
state = "idle"
|
|
44
|
+
|
|
45
|
+
chunks = f", {len(self._chunks)} chunks" if self._chunks else ""
|
|
46
|
+
return f"<{self.__class__.__name__}: {state}{chunks}>"
|
|
47
|
+
|
|
48
|
+
# AsyncIterator protocol
|
|
49
|
+
def __aiter__(self) -> Self:
|
|
50
|
+
"""Return self as async iterator."""
|
|
51
|
+
return self
|
|
52
|
+
|
|
53
|
+
async def __anext__(self) -> Chunk:
|
|
54
|
+
"""Yield next chunk from stream."""
|
|
55
|
+
if self._closed:
|
|
56
|
+
raise StopAsyncIteration
|
|
57
|
+
|
|
58
|
+
try:
|
|
59
|
+
async for event in self._sse_iterator:
|
|
60
|
+
chunk = self._parse_chunk(event)
|
|
61
|
+
if chunk is not None:
|
|
62
|
+
self._chunks.append(chunk)
|
|
63
|
+
return chunk
|
|
64
|
+
|
|
65
|
+
# Stream exhausted - validate and parse final output
|
|
66
|
+
if not self._chunks:
|
|
67
|
+
msg = "Stream completed but no chunks were produced"
|
|
68
|
+
raise RuntimeError(msg)
|
|
69
|
+
|
|
70
|
+
self._output = self._parse_output(self._chunks)
|
|
71
|
+
except Exception:
|
|
72
|
+
await self.aclose()
|
|
73
|
+
raise
|
|
74
|
+
|
|
75
|
+
# Only reached on successful exhaustion
|
|
76
|
+
await self.aclose()
|
|
77
|
+
raise StopAsyncIteration
|
|
78
|
+
|
|
79
|
+
# AsyncContextManager protocol
|
|
80
|
+
async def __aenter__(self) -> Self:
|
|
81
|
+
"""Enter context - return self for iteration."""
|
|
82
|
+
return self
|
|
83
|
+
|
|
84
|
+
async def __aexit__(
|
|
85
|
+
self,
|
|
86
|
+
exc_type: type[BaseException] | None,
|
|
87
|
+
exc_val: BaseException | None,
|
|
88
|
+
exc_tb: TracebackType | None,
|
|
89
|
+
) -> bool:
|
|
90
|
+
"""Exit context - ensure cleanup even on exception."""
|
|
91
|
+
await self.aclose()
|
|
92
|
+
return False # Propagate exceptions
|
|
93
|
+
|
|
94
|
+
@property
|
|
95
|
+
def output(self) -> Out:
|
|
96
|
+
"""Access final Output after stream exhaustion (raises RuntimeError if not ready)."""
|
|
97
|
+
if self._output is None:
|
|
98
|
+
msg = "Stream not exhausted. Consume all chunks before accessing .output"
|
|
99
|
+
raise RuntimeError(msg)
|
|
100
|
+
return self._output
|
|
101
|
+
|
|
102
|
+
async def aclose(self) -> None:
|
|
103
|
+
"""Explicitly close stream and cleanup resources."""
|
|
104
|
+
if self._closed:
|
|
105
|
+
return
|
|
106
|
+
|
|
107
|
+
self._closed = True
|
|
108
|
+
|
|
109
|
+
# Close SSE iterator (httpx-sse connection)
|
|
110
|
+
if hasattr(self._sse_iterator, "aclose"):
|
|
111
|
+
await self._sse_iterator.aclose()
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
__all__ = [
    "Stream",
]
|