celeste-ai 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of celeste-ai might be problematic. Click here for more details.
- celeste/__init__.py +105 -16
- celeste/artifacts.py +16 -20
- celeste/client.py +150 -0
- celeste/constraints.py +180 -0
- celeste/core.py +27 -10
- celeste/credentials.py +59 -30
- celeste/http.py +201 -0
- celeste/io.py +43 -0
- celeste/mime_types.py +13 -12
- celeste/models.py +91 -0
- celeste/parameters.py +51 -0
- celeste/py.typed +1 -0
- celeste/streaming.py +114 -0
- celeste_ai-0.0.3.dist-info/METADATA +163 -0
- celeste_ai-0.0.3.dist-info/RECORD +17 -0
- celeste_ai-0.0.3.dist-info/licenses/LICENSE +201 -0
- celeste_ai-0.0.2.dist-info/METADATA +0 -73
- celeste_ai-0.0.2.dist-info/RECORD +0 -9
- {celeste_ai-0.0.2.dist-info → celeste_ai-0.0.3.dist-info}/WHEEL +0 -0
celeste/__init__.py
CHANGED
|
@@ -1,27 +1,116 @@
|
|
|
1
|
-
|
|
1
|
+
import importlib.metadata
|
|
2
|
+
import logging
|
|
2
3
|
|
|
3
|
-
|
|
4
|
-
__author__ = "agent-kai"
|
|
4
|
+
from pydantic import SecretStr
|
|
5
5
|
|
|
6
|
-
from celeste.
|
|
6
|
+
from celeste.client import Client, get_client_class, register_client
|
|
7
|
+
from celeste.core import Capability, Parameter, Provider
|
|
8
|
+
from celeste.credentials import credentials
|
|
9
|
+
from celeste.http import HTTPClient, close_all_http_clients
|
|
10
|
+
from celeste.io import Input, Output, Usage
|
|
11
|
+
from celeste.models import Model, get_model, list_models, register_models
|
|
12
|
+
from celeste.parameters import Parameters
|
|
13
|
+
|
|
14
|
+
# Module-level logger named after this module (``celeste`` for the package __init__).
logger = logging.getLogger(name=__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _resolve_model(
    capability: Capability,
    provider: Provider | None,
    model: Model | str | None,
) -> Model:
    """Resolve the ``model`` argument of ``create_client`` to a ``Model``.

    Args:
        capability: Capability the model is requested for; used to filter
            auto-selection.
        provider: Optional provider. Narrows auto-selection and is required
            when ``model`` is a string ID.
        model: A ``Model`` instance (returned unchanged), a string model ID
            (looked up with ``get_model``), or ``None`` to auto-select the
            first entry returned by ``list_models``.

    Returns:
        The resolved ``Model``.

    Raises:
        ValueError: If auto-selection finds no model, if a string ID is given
            without a provider, or if the string ID is not found.
    """
    if model is None:
        # Auto-select: take the first registered model matching the filters.
        candidates = list_models(provider=provider, capability=capability)
        if not candidates:
            # Name the provider when it narrowed the search; otherwise the
            # message wrongly implies no provider supports the capability.
            scope = f" with provider {provider}" if provider is not None else ""
            msg = f"No model found for {capability}{scope}"
            raise ValueError(msg)
        return candidates[0]

    if isinstance(model, str):
        # A string ID is looked up per provider, so one must be given.
        if provider is None:
            msg = "provider required when model is a string ID"
            raise ValueError(msg)
        found = get_model(model, provider)
        if found is None:
            msg = f"Model '{model}' not found for provider {provider}"
            raise ValueError(msg)
        return found

    return model
|
|
7
43
|
|
|
8
44
|
|
|
9
45
|
def create_client(
    capability: Capability,
    provider: Provider | None = None,
    model: Model | str | None = None,
    api_key: SecretStr | None = None,
) -> Client:
    """Build an async client for ``capability``.

    The model is resolved first (auto-selected, looked up by ID, or used as
    given). The resolved model's provider then selects the registered client
    class and the credentials passed to it.

    Args:
        capability: The AI capability to use (e.g., TEXT_GENERATION).
        provider: Optional provider. Required if ``model`` is a string ID.
        model: ``Model`` object, string model ID, or ``None`` for auto-selection.
        api_key: Optional ``SecretStr`` override. When omitted, the key comes
            from ``credentials.get_credentials`` (presumably environment-backed).

    Returns:
        An instance of the client class registered for ``capability`` and the
        resolved model's provider.

    Raises:
        ValueError: If no model is found or resolution fails.
        NotImplementedError: If no client is registered for capability/provider.
    """
    chosen = _resolve_model(capability, provider, model)
    target_provider = chosen.provider

    client_cls = get_client_class(capability, target_provider)
    key = credentials.get_credentials(target_provider, override_key=api_key)

    return client_cls(
        model=chosen,
        provider=target_provider,
        capability=capability,
        api_key=key,
    )
|
|
24
82
|
|
|
25
83
|
|
|
26
|
-
|
|
27
|
-
|
|
84
|
+
def _load_from_entry_points() -> None:
|
|
85
|
+
"""Load models and clients from installed packages via entry points."""
|
|
86
|
+
entry_points = importlib.metadata.entry_points(group="celeste.packages")
|
|
87
|
+
|
|
88
|
+
for ep in entry_points:
|
|
89
|
+
register_func = ep.load()
|
|
90
|
+
# The function should register models and clients when called
|
|
91
|
+
register_func()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Run plugin discovery at import time so installed packages can register
# their models and clients.
_load_from_entry_points()

# Public API of the celeste package.
__all__ = [
    # Types
    "Capability",
    "Client",
    "HTTPClient",
    "Input",
    "Model",
    "Output",
    "Parameter",
    "Parameters",
    "Provider",
    "Usage",
    # Functions
    "close_all_http_clients",
    "create_client",
    "get_client_class",
    "get_model",
    "list_models",
    "register_client",
    "register_models",
]
|
celeste/artifacts.py
CHANGED
|
@@ -1,51 +1,47 @@
|
|
|
1
|
-
"""Unified artifact types for
|
|
1
|
+
"""Unified artifact types for Celeste."""
|
|
2
2
|
|
|
3
|
-
from typing import Any
|
|
3
|
+
from typing import Any
|
|
4
4
|
|
|
5
5
|
from pydantic import BaseModel, Field
|
|
6
6
|
|
|
7
|
-
from celeste.mime_types import AudioMimeType, ImageMimeType, VideoMimeType
|
|
7
|
+
from celeste.mime_types import AudioMimeType, ImageMimeType, MimeType, VideoMimeType
|
|
8
8
|
|
|
9
9
|
|
|
10
10
|
class Artifact(BaseModel):
    """Base class for all media artifacts."""

    # Where the content lives; every field is optional, so any combination
    # may be set. has_content reports whether at least one is usable.
    url: str | None = None
    data: bytes | None = None
    path: str | None = None
    mime_type: MimeType | None = None
    # Free-form extra information; a fresh dict per instance.
    metadata: dict[str, Any] = Field(default_factory=dict)

    @property
    def has_content(self) -> bool:
        """True if url or path is non-blank, or data is non-empty."""
        if self.url is not None and self.url.strip():
            return True
        if self.data:
            return True
        return bool(self.path is not None and self.path.strip())
|
|
31
27
|
|
|
32
28
|
|
|
33
29
|
class ImageArtifact(Artifact):
    """Image artifact from generation/edit operations."""

    # Narrow the inherited MIME type field to image MIME types.
    mime_type: ImageMimeType | None = Field(default=None)
|
|
37
33
|
|
|
38
34
|
|
|
39
35
|
class VideoArtifact(Artifact):
    """Video artifact from generation operations."""

    # Narrow the inherited MIME type field to video MIME types.
    mime_type: VideoMimeType | None = Field(default=None)
|
|
43
39
|
|
|
44
40
|
|
|
45
41
|
class AudioArtifact(Artifact):
    """Audio artifact from TTS/transcription operations."""

    # Narrow the inherited MIME type field to audio MIME types.
    mime_type: AudioMimeType | None = Field(default=None)
|
|
49
45
|
|
|
50
46
|
|
|
51
47
|
__all__ = [
|
celeste/client.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Base client and client registry for AI capabilities."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Unpack
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
|
7
|
+
|
|
8
|
+
from celeste.core import Capability, Provider
|
|
9
|
+
from celeste.http import HTTPClient, get_http_client
|
|
10
|
+
from celeste.io import Input, Output, Usage
|
|
11
|
+
from celeste.models import Model
|
|
12
|
+
from celeste.parameters import ParameterMapper, Parameters
|
|
13
|
+
from celeste.streaming import Stream
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Client[In: Input, Out: Output](ABC, BaseModel):
|
|
17
|
+
"""Base class for all capability-specific clients."""
|
|
18
|
+
|
|
19
|
+
model_config = ConfigDict(from_attributes=True)
|
|
20
|
+
|
|
21
|
+
model: Model
|
|
22
|
+
provider: Provider
|
|
23
|
+
capability: Capability
|
|
24
|
+
api_key: SecretStr = Field(exclude=True)
|
|
25
|
+
|
|
26
|
+
def model_post_init(self, __context: object) -> None:
|
|
27
|
+
"""Validate capability compatibility."""
|
|
28
|
+
if self.capability not in self.model.capabilities:
|
|
29
|
+
raise ValueError(
|
|
30
|
+
f"Model '{self.model.id}' does not support capability {self.capability.value}"
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def http_client(self) -> HTTPClient:
|
|
35
|
+
"""Shared HTTP client with connection pooling for this provider."""
|
|
36
|
+
return get_http_client(self.provider, self.capability)
|
|
37
|
+
|
|
38
|
+
@classmethod
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def parameter_mappers(cls) -> list[ParameterMapper]:
|
|
41
|
+
"""Provider-specific parameter mappers."""
|
|
42
|
+
...
|
|
43
|
+
|
|
44
|
+
@abstractmethod
|
|
45
|
+
def _init_request(self, inputs: In) -> dict[str, Any]:
|
|
46
|
+
"""Initialize provider-specific base request structure."""
|
|
47
|
+
pass
|
|
48
|
+
|
|
49
|
+
@abstractmethod
|
|
50
|
+
def _parse_usage(self, response_data: dict[str, Any]) -> Usage:
|
|
51
|
+
"""Parse usage information from provider response."""
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
@abstractmethod
|
|
55
|
+
def _parse_content(
|
|
56
|
+
self, response_data: dict[str, Any], **parameters: Unpack[Parameters]
|
|
57
|
+
) -> object:
|
|
58
|
+
"""Parse content from provider response."""
|
|
59
|
+
pass
|
|
60
|
+
|
|
61
|
+
def _transform_output(
|
|
62
|
+
self, content: object, **parameters: Unpack[Parameters]
|
|
63
|
+
) -> object:
|
|
64
|
+
"""Transform content using parameter mapper output transformations."""
|
|
65
|
+
for mapper in self.parameter_mappers():
|
|
66
|
+
value = parameters.get(mapper.name)
|
|
67
|
+
if value is not None:
|
|
68
|
+
content = mapper.parse_output(content, value)
|
|
69
|
+
return content
|
|
70
|
+
|
|
71
|
+
def _build_request(
|
|
72
|
+
self, inputs: In, **parameters: Unpack[Parameters]
|
|
73
|
+
) -> dict[str, Any]:
|
|
74
|
+
"""Build complete request by combining base request with parameters."""
|
|
75
|
+
request = self._init_request(inputs)
|
|
76
|
+
|
|
77
|
+
# Apply parameter mappers from registry
|
|
78
|
+
for mapper in self.parameter_mappers():
|
|
79
|
+
value = parameters.get(mapper.name)
|
|
80
|
+
request = mapper.map(request, value, self.model)
|
|
81
|
+
|
|
82
|
+
return request
|
|
83
|
+
|
|
84
|
+
def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]: # noqa: ANN401
|
|
85
|
+
"""Stream content - signature varies by capability.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
*args: Capability-specific positional arguments (same as generate).
|
|
89
|
+
**parameters: Capability-specific keyword arguments (same as generate).
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
Stream yielding chunks and providing final Output.
|
|
93
|
+
|
|
94
|
+
Raises:
|
|
95
|
+
NotImplementedError: If capability doesn't support streaming.
|
|
96
|
+
"""
|
|
97
|
+
msg = f"Streaming not supported for {self.capability.value} with provider {self.provider.value}"
|
|
98
|
+
raise NotImplementedError(msg)
|
|
99
|
+
|
|
100
|
+
@abstractmethod
|
|
101
|
+
async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out: # noqa: ANN401
|
|
102
|
+
"""Generate content - signature varies by capability.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
*args: Capability-specific positional arguments (prompt, text, image_url, etc.).
|
|
106
|
+
**parameters: Capability-specific keyword arguments (temperature, max_tokens, etc.).
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
Output of the parameterized type (e.g., TextGenerationOutput).
|
|
110
|
+
"""
|
|
111
|
+
pass
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# Registry of client classes keyed by (capability, provider). register_client
# writes to it; get_client_class reads from it.
_clients: dict[tuple[Capability, Provider], type[Client]] = {}


def register_client(
    capability: Capability, provider: Provider, client_class: type[Client]
) -> None:
    """Make ``client_class`` the client used for ``capability`` on ``provider``.

    A later registration for the same pair replaces the earlier one.

    Args:
        capability: Capability implemented by the client.
        provider: Provider the client uses.
        client_class: Client class to store in the registry.
    """
    registry_key = (capability, provider)
    _clients[registry_key] = client_class
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def get_client_class(capability: Capability, provider: Provider) -> type[Client]:
    """Get the registered client class for a capability and provider.

    Args:
        capability: The capability to get a client for.
        provider: The provider to use.

    Returns:
        The class most recently passed to ``register_client`` for this pair.

    Raises:
        NotImplementedError: If no client is registered for this capability/provider.
    """
    # Single lookup instead of a membership test followed by indexing.
    client_class = _clients.get((capability, provider))
    if client_class is None:
        msg = f"No client registered for {capability.value} with provider {provider.value}"
        raise NotImplementedError(msg)
    return client_class
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# Public names of the client module.
__all__ = [
    "Client",
    "get_client_class",
    "register_client",
]
|
celeste/constraints.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""Constraint models for parameter validation."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import re
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any, get_args, get_origin
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Constraint(BaseModel, ABC):
    """Base constraint for parameter validation."""

    @abstractmethod
    def __call__(self, value: Any) -> Any:  # noqa: ANN401
        """Check ``value`` and return it (possibly converted) when acceptable.

        The subclasses in this module signal rejection by raising TypeError
        (wrong type) or ValueError (right type, disallowed value).
        """
        ...
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Choice[T](Constraint):
|
|
21
|
+
"""Choice constraint - value must be one of the provided options."""
|
|
22
|
+
|
|
23
|
+
options: list[T] = Field(min_length=1)
|
|
24
|
+
|
|
25
|
+
def __call__(self, value: T) -> T:
|
|
26
|
+
"""Validate value is in options."""
|
|
27
|
+
if value not in self.options:
|
|
28
|
+
msg = f"Must be one of {self.options}, got {value!r}"
|
|
29
|
+
raise ValueError(msg)
|
|
30
|
+
return value
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Range(Constraint):
    """Range constraint - value must be within min/max bounds.

    If step is provided, value must be at min + (n * step) for some integer n.
    """

    min: float | int
    max: float | int
    step: float | None = None

    def model_post_init(self, __context: object) -> None:
        """Reject bound/step combinations that make the constraint unusable.

        Raises:
            ValueError: If ``min`` exceeds ``max`` or ``step`` is not positive.
                Pydantic may surface this as a ``ValidationError``, which is a
                ``ValueError`` subclass.
        """
        if self.min > self.max:
            msg = f"min ({self.min}) must not exceed max ({self.max})"
            raise ValueError(msg)
        # step == 0 would make the modulo in __call__ raise ZeroDivisionError,
        # and a negative step produces wrong "nearest valid" suggestions.
        if self.step is not None and self.step <= 0:
            msg = f"step must be positive, got {self.step}"
            raise ValueError(msg)

    def __call__(self, value: float | int) -> float | int:
        """Return ``value`` if it lies in [min, max] and on the step grid.

        Raises:
            TypeError: If ``value`` is not an int or float (bool is rejected).
            ValueError: If ``value`` is out of range or does not match ``step``.
        """
        # bool subclasses int; reject it explicitly, as Int and Float do.
        if not isinstance(value, (int, float)) or isinstance(value, bool):
            msg = f"Must be numeric, got {type(value).__name__}"
            raise TypeError(msg)

        if not self.min <= value <= self.max:
            msg = f"Must be between {self.min} and {self.max}, got {value}"
            raise ValueError(msg)

        if self.step is not None:
            remainder = (value - self.min) % self.step
            # Tolerate floating-point error: a remainder within epsilon of 0 or
            # of step means value sits on the grid.
            epsilon = 1e-9
            on_grid = math.isclose(remainder, 0, abs_tol=epsilon) or math.isclose(
                remainder, self.step, abs_tol=epsilon
            )
            if not on_grid:
                # Calculate nearest valid values for an actionable error message.
                closest_below = self.min + (
                    int((value - self.min) / self.step) * self.step
                )
                closest_above = closest_below + self.step
                msg = f"Value must match step {self.step}. Nearest valid: {closest_below} or {closest_above}, got {value}"
                raise ValueError(msg)

        return value
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class Pattern(Constraint):
    """Pattern constraint - value must match regex pattern."""

    pattern: str

    def __call__(self, value: str) -> str:
        """Return ``value`` if the entire string matches ``pattern``.

        Raises:
            TypeError: If ``value`` is not a str.
            ValueError: If ``re.fullmatch`` finds no match.
        """
        if not isinstance(value, str):
            msg = f"Must be string, got {type(value).__name__}"
            raise TypeError(msg)

        match = re.fullmatch(self.pattern, value)
        if match is None:
            msg = f"Must match pattern {self.pattern!r}, got {value!r}"
            raise ValueError(msg)

        return value
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class Str(Constraint):
    """String type constraint with optional length validation."""

    max_length: int | None = None
    min_length: int | None = None

    def __call__(self, value: str) -> str:
        """Return ``value`` if it is a str within the configured length bounds.

        Raises:
            TypeError: If ``value`` is not a str.
            ValueError: If its length falls outside ``min_length``/``max_length``.
        """
        if not isinstance(value, str):
            msg = f"Must be string, got {type(value).__name__}"
            raise TypeError(msg)

        length = len(value)
        if self.min_length is not None and length < self.min_length:
            msg = f"String too short (min {self.min_length}), got {length}"
            raise ValueError(msg)
        if self.max_length is not None and length > self.max_length:
            msg = f"String too long (max {self.max_length}), got {length}"
            raise ValueError(msg)

        return value
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class Int(Constraint):
    """Integer type constraint."""

    def __call__(self, value: int) -> int:
        """Return ``value`` if it is an int; bools are rejected.

        Raises:
            TypeError: If ``value`` is not an int, or is a bool.
        """
        # bool subclasses int, so it needs its own exclusion.
        if isinstance(value, int) and not isinstance(value, bool):
            return value
        msg = f"Must be int, got {type(value).__name__}"
        raise TypeError(msg)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class Float(Constraint):
    """Float type constraint (accepts int as well)."""

    def __call__(self, value: float) -> float:
        """Return ``value`` converted to float; ints are accepted, bools are not.

        Raises:
            TypeError: If ``value`` is not an int or float, or is a bool.
        """
        numeric = isinstance(value, (int, float)) and not isinstance(value, bool)
        if numeric:
            return float(value)
        msg = f"Must be float or int, got {type(value).__name__}"
        raise TypeError(msg)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class Bool(Constraint):
    """Boolean type constraint."""

    def __call__(self, value: bool) -> bool:
        """Return ``value`` if it is a bool; ints such as 0/1 are rejected.

        Raises:
            TypeError: If ``value`` is not a bool.
        """
        if isinstance(value, bool):
            return value
        msg = f"Must be bool, got {type(value).__name__}"
        raise TypeError(msg)
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class Schema(Constraint):
    """Schema constraint - value must be a Pydantic BaseModel subclass or list[BaseModel]."""

    def __call__(self, value: type[BaseModel]) -> type[BaseModel]:
        """Return ``value`` if it is a BaseModel subclass or ``list[BaseModel]``.

        Raises:
            TypeError: If ``value`` is neither a BaseModel subclass nor a list
                type parameterized with one.
        """
        # For list[T], validate the inner type T.
        if get_origin(value) is list:
            args = get_args(value)
            # Bare aliases such as typing.List report ``list`` as their origin
            # but carry no type arguments; indexing args[0] would raise
            # IndexError instead of TypeError.
            if not args:
                msg = f"List type must be parameterized with a BaseModel, got {value}"
                raise TypeError(msg)
            inner = args[0]
            if not (isinstance(inner, type) and issubclass(inner, BaseModel)):
                msg = f"List type must be BaseModel, got {inner}"
                raise TypeError(msg)
            return value

        # For a plain type, validate it directly.
        if not (isinstance(value, type) and issubclass(value, BaseModel)):
            msg = f"Must be BaseModel, got {value}"
            raise TypeError(msg)

        return value
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# Public constraint classes, listed alphabetically.
__all__ = [
    "Bool", "Choice", "Constraint", "Float", "Int",
    "Pattern", "Range", "Schema", "Str",
]
|
celeste/core.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
"""Core enumerations for
|
|
1
|
+
"""Core enumerations for Celeste."""
|
|
2
2
|
|
|
3
|
-
from enum import
|
|
3
|
+
from enum import StrEnum
|
|
4
4
|
|
|
5
5
|
|
|
6
|
-
class Provider(
|
|
6
|
+
class Provider(StrEnum):
|
|
7
7
|
"""Supported AI providers."""
|
|
8
8
|
|
|
9
9
|
OPENAI = "openai"
|
|
@@ -17,19 +17,36 @@ class Provider(str, Enum):
|
|
|
17
17
|
STABILITYAI = "stabilityai"
|
|
18
18
|
LUMA = "luma"
|
|
19
19
|
TOPAZLABS = "topazlabs"
|
|
20
|
-
|
|
21
|
-
TRANSFORMERS = "transformers"
|
|
22
|
-
LOCAL = "local"
|
|
20
|
+
PERPLEXITY = "perplexity"
|
|
23
21
|
|
|
24
22
|
|
|
25
|
-
class Capability(
|
|
23
|
+
class Capability(StrEnum):
|
|
26
24
|
"""Supported AI capabilities."""
|
|
27
25
|
|
|
28
26
|
# Text
|
|
29
|
-
TEXT_GENERATION =
|
|
27
|
+
TEXT_GENERATION = "text_generation"
|
|
28
|
+
EMBEDDINGS = "embeddings"
|
|
30
29
|
|
|
31
30
|
# Image
|
|
32
|
-
IMAGE_GENERATION =
|
|
31
|
+
IMAGE_GENERATION = "image_generation"
|
|
33
32
|
|
|
33
|
+
# Video
|
|
34
|
+
VIDEO_INTELLIGENCE = "video_intelligence"
|
|
35
|
+
VIDEO_GENERATION = "video_generation"
|
|
34
36
|
|
|
35
|
-
|
|
37
|
+
# Audio
|
|
38
|
+
AUDIO_INTELLIGENCE = "audio_intelligence"
|
|
39
|
+
|
|
40
|
+
# Search
|
|
41
|
+
SEARCH = "search"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Parameter(StrEnum):
|
|
45
|
+
"""Universal parameters across most capabilities."""
|
|
46
|
+
|
|
47
|
+
TEMPERATURE = "temperature"
|
|
48
|
+
SEED = "seed"
|
|
49
|
+
MAX_TOKENS = "max_tokens"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# Public enumerations of the core module.
__all__ = [
    "Capability",
    "Parameter",
    "Provider",
]
|