celeste-ai 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of celeste-ai might be problematic. Click here for more details.

celeste/http.py ADDED
@@ -0,0 +1,201 @@
1
+ """HTTP client with persistent connection pooling for AI provider APIs."""
2
+
3
+ import json
4
+ import logging
5
+ from collections.abc import AsyncIterator
6
+ from typing import Any
7
+
8
+ import httpx
9
+ from httpx_sse import aconnect_sse
10
+
11
+ from celeste.core import Capability, Provider
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ MAX_CONNECTIONS = 20
16
+ MAX_KEEPALIVE_CONNECTIONS = 10
17
+ DEFAULT_TIMEOUT = 60.0
18
+
19
+
20
+ class HTTPClient:
21
+ """Async HTTP client with persistent connection pooling."""
22
+
23
+ def __init__(
24
+ self,
25
+ max_connections: int = MAX_CONNECTIONS,
26
+ max_keepalive_connections: int = MAX_KEEPALIVE_CONNECTIONS,
27
+ ) -> None:
28
+ """Initialize HTTP client with connection pool limits.
29
+
30
+ Args:
31
+ max_connections: Maximum total connections in pool.
32
+ max_keepalive_connections: Maximum idle keepalive connections.
33
+ """
34
+ self._client: httpx.AsyncClient | None = None
35
+ self._max_connections = max_connections
36
+ self._max_keepalive_connections = max_keepalive_connections
37
+
38
+ async def _get_client(self) -> httpx.AsyncClient:
39
+ """Get or create httpx.AsyncClient with connection pooling."""
40
+ if self._client is None:
41
+ limits = httpx.Limits(
42
+ max_connections=self._max_connections,
43
+ max_keepalive_connections=self._max_keepalive_connections,
44
+ )
45
+ self._client = httpx.AsyncClient(limits=limits) # nosec B113
46
+ return self._client
47
+
48
+ async def post(
49
+ self,
50
+ url: str,
51
+ headers: dict[str, str],
52
+ json_body: dict[str, Any],
53
+ timeout: float = DEFAULT_TIMEOUT,
54
+ ) -> httpx.Response:
55
+ """Make POST request with connection pooling.
56
+
57
+ Args:
58
+ url: Full URL to POST to.
59
+ headers: HTTP headers including authentication.
60
+ json_body: JSON request body.
61
+ timeout: Request timeout in seconds.
62
+
63
+ Returns:
64
+ HTTP response from the server.
65
+
66
+ Raises:
67
+ httpx.HTTPError: On network or timeout errors.
68
+ """
69
+ client = await self._get_client()
70
+ return await client.post(
71
+ url,
72
+ headers=headers,
73
+ json=json_body,
74
+ timeout=timeout,
75
+ )
76
+
77
+ async def get(
78
+ self,
79
+ url: str,
80
+ headers: dict[str, str] | None = None,
81
+ timeout: float = DEFAULT_TIMEOUT,
82
+ follow_redirects: bool = True,
83
+ ) -> httpx.Response:
84
+ """Make GET request with connection pooling.
85
+
86
+ Args:
87
+ url: Full URL to GET.
88
+ headers: HTTP headers including authentication (optional).
89
+ timeout: Request timeout in seconds.
90
+ follow_redirects: Whether to follow HTTP redirects (default: True).
91
+
92
+ Returns:
93
+ HTTP response from the server.
94
+
95
+ Raises:
96
+ httpx.HTTPError: On network or timeout errors.
97
+ """
98
+ client = await self._get_client()
99
+ return await client.get(
100
+ url,
101
+ headers=headers or {},
102
+ timeout=timeout,
103
+ follow_redirects=follow_redirects,
104
+ )
105
+
106
+ async def stream_post(
107
+ self,
108
+ url: str,
109
+ headers: dict[str, str],
110
+ json_body: dict[str, Any],
111
+ timeout: float = DEFAULT_TIMEOUT,
112
+ ) -> AsyncIterator[dict[str, Any]]:
113
+ """Stream POST request using Server-Sent Events.
114
+
115
+ Args:
116
+ url: API endpoint URL.
117
+ headers: HTTP headers (including authentication).
118
+ json_body: JSON request body.
119
+ timeout: Timeout in seconds (default: DEFAULT_TIMEOUT).
120
+
121
+ Yields:
122
+ Parsed JSON events from SSE stream.
123
+ """
124
+ client = await self._get_client()
125
+
126
+ async with aconnect_sse(
127
+ client,
128
+ "POST",
129
+ url,
130
+ json=json_body,
131
+ headers=headers,
132
+ timeout=timeout,
133
+ ) as event_source:
134
+ async for sse in event_source.aiter_sse():
135
+ try:
136
+ yield json.loads(sse.data)
137
+ except json.JSONDecodeError:
138
+ continue # Skip non-JSON control messages (provider-agnostic)
139
+
140
+ async def aclose(self) -> None:
141
+ """Close HTTP client and cleanup all connections."""
142
+ if self._client is not None:
143
+ await self._client.aclose()
144
+ self._client = None
145
+
146
+ async def __aenter__(self) -> "HTTPClient":
147
+ """Enter async context manager."""
148
+ return self
149
+
150
+ async def __aexit__(self, *args: Any) -> None: # noqa: ANN401
151
+ """Exit async context manager and cleanup connections."""
152
+ await self.aclose()
153
+
154
+
155
+ # Module-level registry of shared HTTPClient instances
156
+ _http_clients: dict[tuple[Provider, Capability], HTTPClient] = {}
157
+
158
+
159
+ def get_http_client(provider: Provider, capability: Capability) -> HTTPClient:
160
+ """Get or create shared HTTP client for provider and capability combination.
161
+
162
+ Args:
163
+ provider: The AI provider.
164
+ capability: The capability being used.
165
+
166
+ Returns:
167
+ Shared HTTPClient instance for this provider and capability.
168
+ """
169
+ key = (provider, capability)
170
+
171
+ if key not in _http_clients:
172
+ _http_clients[key] = HTTPClient()
173
+
174
+ return _http_clients[key]
175
+
176
+
177
+ async def close_all_http_clients() -> None:
178
+ """Close all HTTP clients gracefully and clear registry."""
179
+ for key, client in list(_http_clients.items()):
180
+ try:
181
+ await client.aclose()
182
+ except Exception as e:
183
+ logger.warning(f"Failed to close HTTP client for {key}: {e}")
184
+
185
+ _http_clients.clear()
186
+
187
+
188
+ def clear_http_clients() -> None:
189
+ """Clear HTTP client registry without closing connections."""
190
+ _http_clients.clear()
191
+
192
+
193
+ __all__ = [
194
+ "DEFAULT_TIMEOUT",
195
+ "MAX_CONNECTIONS",
196
+ "MAX_KEEPALIVE_CONNECTIONS",
197
+ "HTTPClient",
198
+ "clear_http_clients",
199
+ "close_all_http_clients",
200
+ "get_http_client",
201
+ ]
celeste/io.py ADDED
@@ -0,0 +1,43 @@
1
+ """Input and output types for generation operations."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ class Input(BaseModel):
9
+ """Base class for capability-specific input types."""
10
+
11
+ pass
12
+
13
+
14
+ class FinishReason(BaseModel):
15
+ """Base class for capability-specific finish reasons (used in streaming chunks)."""
16
+
17
+ pass
18
+
19
+
20
+ class Usage(BaseModel):
21
+ """Base class for capability-specific usage metrics."""
22
+
23
+ pass
24
+
25
+
26
+ class Output[Content](BaseModel):
27
+ """Base output class with generic content type."""
28
+
29
+ content: Content
30
+ usage: Usage = Field(default_factory=Usage)
31
+ metadata: dict[str, Any] = Field(default_factory=dict)
32
+
33
+
34
+ class Chunk[Content](BaseModel):
35
+ """Incremental chunk from streaming response with generic content type."""
36
+
37
+ content: Content
38
+ finish_reason: FinishReason | None = None
39
+ usage: Usage | None = None
40
+ metadata: dict[str, Any] = Field(default_factory=dict)
41
+
42
+
43
+ __all__ = ["Chunk", "FinishReason", "Input", "Output", "Usage"]
celeste/mime_types.py ADDED
@@ -0,0 +1,47 @@
1
+ """MIME type enumerations for Celeste."""
2
+
3
+ from enum import StrEnum
4
+
5
+
6
+ class MimeType(StrEnum):
7
+ """Base class for all MIME types."""
8
+
9
+ pass
10
+
11
+
12
+ class ImageMimeType(MimeType):
13
+ """Standard MIME types for images."""
14
+
15
+ PNG = "image/png"
16
+ JPEG = "image/jpeg"
17
+ WEBP = "image/webp"
18
+
19
+
20
+ class VideoMimeType(MimeType):
21
+ """Standard MIME types for videos."""
22
+
23
+ MP4 = "video/mp4"
24
+ AVI = "video/x-msvideo"
25
+ MOV = "video/quicktime"
26
+
27
+
28
+ class AudioMimeType(MimeType):
29
+ """Standard MIME types for audio."""
30
+
31
+ MP3 = "audio/mpeg"
32
+ WAV = "audio/wav"
33
+ OGG = "audio/ogg"
34
+ WEBM = "audio/webm"
35
+ AAC = "audio/aac"
36
+ FLAC = "audio/flac"
37
+ AIFF = "audio/aiff"
38
+ M4A = "audio/mp4"
39
+ WMA = "audio/x-ms-wma"
40
+
41
+
42
+ __all__ = [
43
+ "AudioMimeType",
44
+ "ImageMimeType",
45
+ "MimeType",
46
+ "VideoMimeType",
47
+ ]
celeste/models.py ADDED
@@ -0,0 +1,91 @@
1
+ """Models and model registry for Celeste."""
2
+
3
+ from pydantic import BaseModel, Field
4
+
5
+ from celeste.constraints import Constraint
6
+ from celeste.core import Capability, Provider
7
+
8
+
9
+ class Model(BaseModel):
10
+ """Represents an AI model with its capabilities and metadata."""
11
+
12
+ id: str
13
+ provider: Provider
14
+ capabilities: set[Capability] = Field(default_factory=set)
15
+ display_name: str
16
+ parameter_constraints: dict[str, Constraint] = Field(default_factory=dict)
17
+
18
+ @property
19
+ def supported_parameters(self) -> set[str]:
20
+ """Compute supported parameter names from parameter_constraints."""
21
+ return set(self.parameter_constraints.keys())
22
+
23
+
24
+ # Module-level registry mapping (model_id, provider) to model
25
+ _models: dict[tuple[str, Provider], Model] = {}
26
+
27
+
28
+ def register_models(models: Model | list[Model]) -> None:
29
+ """Register one or more models in the global registry.
30
+
31
+ Args:
32
+ models: Single Model instance or list of Models to register.
33
+ Each model is indexed by (model_id, provider) tuple.
34
+
35
+ Raises:
36
+ ValueError: If a model with the same (id, provider) is already registered.
37
+ """
38
+ if isinstance(models, Model):
39
+ models = [models]
40
+
41
+ for model in models:
42
+ key = (model.id, model.provider)
43
+ if key in _models:
44
+ msg = f"Model '{model.id}' for provider {model.provider.value} is already registered"
45
+ raise ValueError(msg)
46
+ _models[key] = model
47
+
48
+
49
+ def get_model(model_id: str, provider: Provider) -> Model | None:
50
+ """Get a registered model by ID and provider.
51
+
52
+ Args:
53
+ model_id: The model identifier.
54
+ provider: The provider that owns the model.
55
+
56
+ Returns:
57
+ Model instance if found, None otherwise.
58
+ """
59
+ return _models.get((model_id, provider))
60
+
61
+
62
+ def list_models(
63
+ provider: Provider | None = None,
64
+ capability: Capability | None = None,
65
+ ) -> list[Model]:
66
+ """List all registered models, optionally filtered by provider and/or capability.
67
+
68
+ Args:
69
+ provider: Optional provider filter. If provided, only models from this provider are returned.
70
+ capability: Optional capability filter. If provided, only models supporting this capability are returned.
71
+
72
+ Returns:
73
+ List of Model instances matching the filters.
74
+ """
75
+ filtered = list(_models.values())
76
+
77
+ if provider is not None:
78
+ filtered = [m for m in filtered if m.provider == provider]
79
+
80
+ if capability is not None:
81
+ filtered = [m for m in filtered if capability in m.capabilities]
82
+
83
+ return filtered
84
+
85
+
86
+ def clear() -> None:
87
+ """Clear all registered models from the registry."""
88
+ _models.clear()
89
+
90
+
91
+ __all__ = ["Model", "clear", "get_model", "list_models", "register_models"]
celeste/parameters.py ADDED
@@ -0,0 +1,51 @@
1
+ """Parameter system for Celeste."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from enum import StrEnum
5
+ from typing import Any, TypedDict
6
+
7
+ from celeste.models import Model
8
+
9
+
10
+ class Parameters(TypedDict, total=False):
11
+ """Base parameters for all capabilities."""
12
+
13
+
14
+ class ParameterMapper(ABC):
15
+ """Base class for provider-specific parameter transformation."""
16
+
17
+ name: StrEnum
18
+ """Parameter name matching capability TypedDict key. Must be StrEnum for type safety."""
19
+
20
+ @abstractmethod
21
+ def map(self, request: dict[str, Any], value: Any, model: Model) -> dict[str, Any]: # noqa: ANN401
22
+ """Transform parameter value into provider's request structure.
23
+
24
+ Args:
25
+ request: Provider request dict.
26
+ value: Parameter value.
27
+ model: Model instance containing parameter_constraints for validation.
28
+
29
+ Returns:
30
+ Updated request dict.
31
+ """
32
+ ...
33
+
34
+ def parse_output(self, content: object, value: object | None) -> object:
35
+ """Optionally transform parsed content based on parameter value (default: return unchanged)."""
36
+ return content
37
+
38
+ def _validate_value(self, value: Any, model: Model) -> Any: # noqa: ANN401
39
+ """Validate parameter value using model constraint, raising ValueError if no constraint exists."""
40
+ if value is None:
41
+ return None
42
+
43
+ constraint = model.parameter_constraints.get(self.name)
44
+ if constraint is None:
45
+ msg = f"Parameter {self.name.value} is not supported by model {model.id}"
46
+ raise ValueError(msg)
47
+
48
+ return constraint(value)
49
+
50
+
51
+ __all__ = ["ParameterMapper", "Parameters"]
celeste/py.typed ADDED
@@ -0,0 +1 @@
1
+ # Marker file for PEP 561 - this package supports type checking
celeste/streaming.py ADDED
@@ -0,0 +1,114 @@
1
+ """Streaming support for Celeste."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from collections.abc import AsyncIterator
5
+ from types import TracebackType
6
+ from typing import Any, Self
7
+
8
+ from celeste.io import Chunk, Output
9
+
10
+
11
+ class Stream[Out: Output](ABC):
12
+ """Async iterator wrapper providing final Output access after stream exhaustion."""
13
+
14
+ def __init__(
15
+ self,
16
+ sse_iterator: AsyncIterator[dict[str, Any]],
17
+ ) -> None:
18
+ """Initialize stream with SSE iterator."""
19
+ self._sse_iterator = sse_iterator
20
+ self._chunks: list[Chunk] = []
21
+ self._closed = False
22
+ self._output: Out | None = None
23
+
24
+ @abstractmethod
25
+ def _parse_chunk(self, event: dict[str, Any]) -> Chunk | None:
26
+ """Parse SSE event into Chunk (returns None to filter lifecycle events)."""
27
+ ...
28
+
29
+ @abstractmethod
30
+ def _parse_output(self, chunks: list[Chunk]) -> Out:
31
+ """Parse final Output from accumulated chunks."""
32
+ ...
33
+
34
+ def __repr__(self) -> str:
35
+ """Developer-friendly representation showing stream state."""
36
+ if self._output:
37
+ state = "done"
38
+ elif self._closed:
39
+ state = "closed"
40
+ elif self._chunks:
41
+ state = "streaming"
42
+ else:
43
+ state = "idle"
44
+
45
+ chunks = f", {len(self._chunks)} chunks" if self._chunks else ""
46
+ return f"<{self.__class__.__name__}: {state}{chunks}>"
47
+
48
+ # AsyncIterator protocol
49
+ def __aiter__(self) -> Self:
50
+ """Return self as async iterator."""
51
+ return self
52
+
53
+ async def __anext__(self) -> Chunk:
54
+ """Yield next chunk from stream."""
55
+ if self._closed:
56
+ raise StopAsyncIteration
57
+
58
+ try:
59
+ async for event in self._sse_iterator:
60
+ chunk = self._parse_chunk(event)
61
+ if chunk is not None:
62
+ self._chunks.append(chunk)
63
+ return chunk
64
+
65
+ # Stream exhausted - validate and parse final output
66
+ if not self._chunks:
67
+ msg = "Stream completed but no chunks were produced"
68
+ raise RuntimeError(msg)
69
+
70
+ self._output = self._parse_output(self._chunks)
71
+ except Exception:
72
+ await self.aclose()
73
+ raise
74
+
75
+ # Only reached on successful exhaustion
76
+ await self.aclose()
77
+ raise StopAsyncIteration
78
+
79
+ # AsyncContextManager protocol
80
+ async def __aenter__(self) -> Self:
81
+ """Enter context - return self for iteration."""
82
+ return self
83
+
84
+ async def __aexit__(
85
+ self,
86
+ exc_type: type[BaseException] | None,
87
+ exc_val: BaseException | None,
88
+ exc_tb: TracebackType | None,
89
+ ) -> bool:
90
+ """Exit context - ensure cleanup even on exception."""
91
+ await self.aclose()
92
+ return False # Propagate exceptions
93
+
94
+ @property
95
+ def output(self) -> Out:
96
+ """Access final Output after stream exhaustion (raises RuntimeError if not ready)."""
97
+ if self._output is None:
98
+ msg = "Stream not exhausted. Consume all chunks before accessing .output"
99
+ raise RuntimeError(msg)
100
+ return self._output
101
+
102
+ async def aclose(self) -> None:
103
+ """Explicitly close stream and cleanup resources."""
104
+ if self._closed:
105
+ return
106
+
107
+ self._closed = True
108
+
109
+ # Close SSE iterator (httpx-sse connection)
110
+ if hasattr(self._sse_iterator, "aclose"):
111
+ await self._sse_iterator.aclose()
112
+
113
+
114
+ __all__ = ["Stream"]