celeste-ai 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of celeste-ai might be problematic. Click here for more details.
- celeste/__init__.py +108 -17
- celeste/artifacts.py +52 -0
- celeste/client.py +150 -0
- celeste/constraints.py +180 -0
- celeste/core.py +52 -0
- celeste/credentials.py +108 -0
- celeste/http.py +201 -0
- celeste/io.py +43 -0
- celeste/mime_types.py +47 -0
- celeste/models.py +91 -0
- celeste/parameters.py +51 -0
- celeste/py.typed +1 -0
- celeste/streaming.py +114 -0
- celeste_ai-0.0.3.dist-info/METADATA +163 -0
- celeste_ai-0.0.3.dist-info/RECORD +17 -0
- celeste_ai-0.0.3.dist-info/licenses/LICENSE +201 -0
- celeste_ai-0.0.1.dist-info/METADATA +0 -59
- celeste_ai-0.0.1.dist-info/RECORD +0 -4
- {celeste_ai-0.0.1.dist-info → celeste_ai-0.0.3.dist-info}/WHEEL +0 -0
celeste/__init__.py
CHANGED
|
@@ -1,25 +1,116 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
import importlib.metadata
|
|
2
|
+
import logging
|
|
3
3
|
|
|
4
|
-
|
|
4
|
+
from pydantic import SecretStr
|
|
5
5
|
|
|
6
|
-
|
|
7
|
-
|
|
6
|
+
from celeste.client import Client, get_client_class, register_client
|
|
7
|
+
from celeste.core import Capability, Parameter, Provider
|
|
8
|
+
from celeste.credentials import credentials
|
|
9
|
+
from celeste.http import HTTPClient, close_all_http_clients
|
|
10
|
+
from celeste.io import Input, Output, Usage
|
|
11
|
+
from celeste.models import Model, get_model, list_models, register_models
|
|
12
|
+
from celeste.parameters import Parameters
|
|
8
13
|
|
|
9
|
-
|
|
10
|
-
__author__ = "agent-kai"
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
11
15
|
|
|
12
|
-
def create_client(*args, **kwargs):
|
|
13
|
-
"""
|
|
14
|
-
Placeholder for the universal client factory.
|
|
15
16
|
|
|
16
|
-
|
|
17
|
+
def _resolve_model(
    capability: Capability,
    provider: Provider | None,
    model: Model | str | None,
) -> Model:
    """Turn the caller-supplied ``model`` argument into a concrete ``Model``.

    Args:
        capability: Capability used to filter candidates during auto-selection.
        provider: Provider used to filter candidates, and required to look up
            a string model ID.
        model: A ``Model`` (returned unchanged), a string model ID (looked up
            via ``get_model``), or ``None`` (first entry of ``list_models``).

    Returns:
        The resolved ``Model``.

    Raises:
        ValueError: If auto-selection finds nothing, if a string ID is given
            without a provider, or if the string ID cannot be found.
    """
    # Already a Model instance: nothing to resolve.
    if model is not None and not isinstance(model, str):
        return model

    if model is None:
        # Auto-selection: take the first model the registry returns.
        candidates = list_models(provider=provider, capability=capability)
        if candidates:
            return candidates[0]
        raise ValueError(f"No model found for {capability}")

    # From here on ``model`` is a string ID, which is looked up per provider.
    if not provider:
        raise ValueError("provider required when model is a string ID")

    looked_up = get_model(model, provider)
    if looked_up:
        return looked_up
    raise ValueError(f"Model '{model}' not found for provider {provider}")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def create_client(
    capability: Capability,
    provider: Provider | None = None,
    model: Model | str | None = None,
    api_key: SecretStr | None = None,
) -> Client:
    """Build a client for the requested AI capability.

    The model is resolved first; the client class and the API key are then
    both looked up using the resolved model's provider.

    Args:
        capability: The AI capability to use (e.g., TEXT_GENERATION).
        provider: Optional provider. Required if model is a string ID.
        model: Model object, string model ID, or None for auto-selection.
        api_key: Optional SecretStr override, passed to the credentials
            lookup as ``override_key``.

    Returns:
        An instance of the client class registered for the capability and
        the resolved model's provider.

    Raises:
        ValueError: If no model found or resolution fails.
        NotImplementedError: If no client registered for capability/provider.
    """
    chosen_model = _resolve_model(capability, provider, model)
    chosen_provider = chosen_model.provider

    # Client class lookup happens before the credentials lookup.
    client_cls = get_client_class(capability, chosen_provider)
    key = credentials.get_credentials(chosen_provider, override_key=api_key)

    return client_cls(
        model=chosen_model,
        provider=chosen_provider,
        capability=capability,
        api_key=key,
    )
|
|
23
82
|
|
|
24
|
-
|
|
25
|
-
|
|
83
|
+
|
|
84
|
+
def _load_from_entry_points() -> None:
|
|
85
|
+
"""Load models and clients from installed packages via entry points."""
|
|
86
|
+
entry_points = importlib.metadata.entry_points(group="celeste.packages")
|
|
87
|
+
|
|
88
|
+
for ep in entry_points:
|
|
89
|
+
register_func = ep.load()
|
|
90
|
+
# The function should register models and clients when called
|
|
91
|
+
register_func()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
# Load from entry points on import
|
|
95
|
+
_load_from_entry_points()
|
|
96
|
+
|
|
97
|
+
# Exports
|
|
98
|
+
__all__ = [
|
|
99
|
+
"Capability",
|
|
100
|
+
"Client",
|
|
101
|
+
"HTTPClient",
|
|
102
|
+
"Input",
|
|
103
|
+
"Model",
|
|
104
|
+
"Output",
|
|
105
|
+
"Parameter",
|
|
106
|
+
"Parameters",
|
|
107
|
+
"Provider",
|
|
108
|
+
"Usage",
|
|
109
|
+
"close_all_http_clients",
|
|
110
|
+
"create_client",
|
|
111
|
+
"get_client_class",
|
|
112
|
+
"get_model",
|
|
113
|
+
"list_models",
|
|
114
|
+
"register_client",
|
|
115
|
+
"register_models",
|
|
116
|
+
]
|
celeste/artifacts.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Unified artifact types for Celeste."""
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
from celeste.mime_types import AudioMimeType, ImageMimeType, MimeType, VideoMimeType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Artifact(BaseModel):
    """Base class for all media artifacts."""

    # Three optional content carriers; none of them is required.
    url: str | None = None
    data: bytes | None = None
    path: str | None = None
    # Optional MIME type; subclasses narrow this to their own MIME enum.
    mime_type: MimeType | None = None
    # Free-form extra information; each instance gets its own empty dict.
    metadata: dict[str, Any] = Field(default_factory=dict)

    @property
    def has_content(self) -> bool:
        """Report whether any of ``url``, ``data`` or ``path`` carries content.

        ``url`` and ``path`` count only when they are non-blank after
        stripping whitespace; ``data`` counts when it is non-empty.
        """
        if self.url and self.url.strip():
            return True
        if self.data:
            return True
        return bool(self.path and self.path.strip())
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class ImageArtifact(Artifact):
    """Image artifact from generation/edit operations."""

    # Re-declares the inherited ``mime_type`` field with the image-specific
    # MIME type; it stays optional and defaults to None.
    mime_type: ImageMimeType | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class VideoArtifact(Artifact):
    """Video artifact from generation operations."""

    # Re-declares the inherited ``mime_type`` field with the video-specific
    # MIME type; it stays optional and defaults to None.
    mime_type: VideoMimeType | None = None
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class AudioArtifact(Artifact):
    """Audio artifact from TTS/transcription operations."""

    # Re-declares the inherited ``mime_type`` field with the audio-specific
    # MIME type; it stays optional and defaults to None.
    mime_type: AudioMimeType | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
__all__ = [
|
|
48
|
+
"Artifact",
|
|
49
|
+
"AudioArtifact",
|
|
50
|
+
"ImageArtifact",
|
|
51
|
+
"VideoArtifact",
|
|
52
|
+
]
|
celeste/client.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Base client and client registry for AI capabilities."""
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Any, Unpack
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, ConfigDict, Field, SecretStr
|
|
7
|
+
|
|
8
|
+
from celeste.core import Capability, Provider
|
|
9
|
+
from celeste.http import HTTPClient, get_http_client
|
|
10
|
+
from celeste.io import Input, Output, Usage
|
|
11
|
+
from celeste.models import Model
|
|
12
|
+
from celeste.parameters import ParameterMapper, Parameters
|
|
13
|
+
from celeste.streaming import Stream
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Client[In: Input, Out: Output](ABC, BaseModel):
|
|
17
|
+
"""Base class for all capability-specific clients."""
|
|
18
|
+
|
|
19
|
+
model_config = ConfigDict(from_attributes=True)
|
|
20
|
+
|
|
21
|
+
model: Model
|
|
22
|
+
provider: Provider
|
|
23
|
+
capability: Capability
|
|
24
|
+
api_key: SecretStr = Field(exclude=True)
|
|
25
|
+
|
|
26
|
+
def model_post_init(self, __context: object) -> None:
|
|
27
|
+
"""Validate capability compatibility."""
|
|
28
|
+
if self.capability not in self.model.capabilities:
|
|
29
|
+
raise ValueError(
|
|
30
|
+
f"Model '{self.model.id}' does not support capability {self.capability.value}"
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def http_client(self) -> HTTPClient:
|
|
35
|
+
"""Shared HTTP client with connection pooling for this provider."""
|
|
36
|
+
return get_http_client(self.provider, self.capability)
|
|
37
|
+
|
|
38
|
+
@classmethod
|
|
39
|
+
@abstractmethod
|
|
40
|
+
def parameter_mappers(cls) -> list[ParameterMapper]:
|
|
41
|
+
"""Provider-specific parameter mappers."""
|
|
42
|
+
...
|
|
43
|
+
|
|
44
|
+
@abstractmethod
|
|
45
|
+
def _init_request(self, inputs: In) -> dict[str, Any]:
|
|
46
|
+
"""Initialize provider-specific base request structure."""
|
|
47
|
+
pass
|
|
48
|
+
|
|
49
|
+
@abstractmethod
|
|
50
|
+
def _parse_usage(self, response_data: dict[str, Any]) -> Usage:
|
|
51
|
+
"""Parse usage information from provider response."""
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
@abstractmethod
|
|
55
|
+
def _parse_content(
|
|
56
|
+
self, response_data: dict[str, Any], **parameters: Unpack[Parameters]
|
|
57
|
+
) -> object:
|
|
58
|
+
"""Parse content from provider response."""
|
|
59
|
+
pass
|
|
60
|
+
|
|
61
|
+
def _transform_output(
|
|
62
|
+
self, content: object, **parameters: Unpack[Parameters]
|
|
63
|
+
) -> object:
|
|
64
|
+
"""Transform content using parameter mapper output transformations."""
|
|
65
|
+
for mapper in self.parameter_mappers():
|
|
66
|
+
value = parameters.get(mapper.name)
|
|
67
|
+
if value is not None:
|
|
68
|
+
content = mapper.parse_output(content, value)
|
|
69
|
+
return content
|
|
70
|
+
|
|
71
|
+
def _build_request(
|
|
72
|
+
self, inputs: In, **parameters: Unpack[Parameters]
|
|
73
|
+
) -> dict[str, Any]:
|
|
74
|
+
"""Build complete request by combining base request with parameters."""
|
|
75
|
+
request = self._init_request(inputs)
|
|
76
|
+
|
|
77
|
+
# Apply parameter mappers from registry
|
|
78
|
+
for mapper in self.parameter_mappers():
|
|
79
|
+
value = parameters.get(mapper.name)
|
|
80
|
+
request = mapper.map(request, value, self.model)
|
|
81
|
+
|
|
82
|
+
return request
|
|
83
|
+
|
|
84
|
+
def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]: # noqa: ANN401
|
|
85
|
+
"""Stream content - signature varies by capability.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
*args: Capability-specific positional arguments (same as generate).
|
|
89
|
+
**parameters: Capability-specific keyword arguments (same as generate).
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
Stream yielding chunks and providing final Output.
|
|
93
|
+
|
|
94
|
+
Raises:
|
|
95
|
+
NotImplementedError: If capability doesn't support streaming.
|
|
96
|
+
"""
|
|
97
|
+
msg = f"Streaming not supported for {self.capability.value} with provider {self.provider.value}"
|
|
98
|
+
raise NotImplementedError(msg)
|
|
99
|
+
|
|
100
|
+
@abstractmethod
|
|
101
|
+
async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out: # noqa: ANN401
|
|
102
|
+
"""Generate content - signature varies by capability.
|
|
103
|
+
|
|
104
|
+
Args:
|
|
105
|
+
*args: Capability-specific positional arguments (prompt, text, image_url, etc.).
|
|
106
|
+
**parameters: Capability-specific keyword arguments (temperature, max_tokens, etc.).
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
Output of the parameterized type (e.g., TextGenerationOutput).
|
|
110
|
+
"""
|
|
111
|
+
pass
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
_clients: dict[tuple[Capability, Provider], type[Client]] = {}
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def register_client(
    capability: Capability, provider: Provider, client_class: type[Client]
) -> None:
    """Record ``client_class`` as the client for a capability/provider pair.

    Registering the same pair again replaces the earlier entry.

    Args:
        capability: The capability this client implements.
        provider: The provider this client uses.
        client_class: The client class to register.
    """
    registry_key = (capability, provider)
    _clients[registry_key] = client_class
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def get_client_class(capability: Capability, provider: Provider) -> type[Client]:
    """Look up the client class registered for a capability/provider pair.

    Args:
        capability: The capability to get a client for.
        provider: The provider to use.

    Returns:
        The registered client class.

    Raises:
        NotImplementedError: If no client is registered for this capability/provider.
    """
    registered = _clients.get((capability, provider))
    if registered is None:
        msg = f"No client registered for {capability.value} with provider {provider.value}"
        raise NotImplementedError(msg)
    return registered
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
__all__ = ["Client", "get_client_class", "register_client"]
|
celeste/constraints.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
1
|
+
"""Constraint models for parameter validation."""
|
|
2
|
+
|
|
3
|
+
import math
|
|
4
|
+
import re
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any, get_args, get_origin
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Constraint(BaseModel, ABC):
    """Base constraint for parameter validation."""

    @abstractmethod
    def __call__(self, value: Any) -> Any:  # noqa: ANN401
        """Check ``value`` against this constraint and return the validated value."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Choice[T](Constraint):
|
|
21
|
+
"""Choice constraint - value must be one of the provided options."""
|
|
22
|
+
|
|
23
|
+
options: list[T] = Field(min_length=1)
|
|
24
|
+
|
|
25
|
+
def __call__(self, value: T) -> T:
|
|
26
|
+
"""Validate value is in options."""
|
|
27
|
+
if value not in self.options:
|
|
28
|
+
msg = f"Must be one of {self.options}, got {value!r}"
|
|
29
|
+
raise ValueError(msg)
|
|
30
|
+
return value
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Range(Constraint):
    """Range constraint - value must be within min/max bounds.

    If step is provided, value must be at min + (n * step) for some integer n.
    """

    min: float | int
    max: float | int
    step: float | None = None

    def __call__(self, value: float | int) -> float | int:
        """Validate value is within range and matches step increment.

        Args:
            value: The number to validate.

        Returns:
            ``value`` unchanged when it passes every check.

        Raises:
            TypeError: If ``value`` is not an int or float, or is a bool.
            ValueError: If ``value`` is outside ``[min, max]``, does not land
                on the step grid, or if ``step`` is zero.
        """
        # bool is a subclass of int; reject it explicitly, as Int and Float do.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            msg = f"Must be numeric, got {type(value).__name__}"
            raise TypeError(msg)

        if not self.min <= value <= self.max:
            msg = f"Must be between {self.min} and {self.max}, got {value}"
            raise ValueError(msg)

        if self.step is None:
            return value

        # A zero step would otherwise surface as ZeroDivisionError from the modulo.
        if self.step == 0:
            msg = "Range step must be non-zero"
            raise ValueError(msg)

        remainder = (value - self.min) % self.step
        # Use epsilon for floating-point comparison tolerance. The remainder
        # may come out just below ``step`` rather than near 0, so both ends
        # are accepted.
        epsilon = 1e-9
        on_grid = math.isclose(remainder, 0, abs_tol=epsilon) or math.isclose(
            remainder, self.step, abs_tol=epsilon
        )
        if not on_grid:
            # Calculate nearest valid values for actionable error message
            closest_below = self.min + (
                int((value - self.min) / self.step) * self.step
            )
            closest_above = closest_below + self.step
            msg = f"Value must match step {self.step}. Nearest valid: {closest_below} or {closest_above}, got {value}"
            raise ValueError(msg)

        return value
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class Pattern(Constraint):
    """Pattern constraint - value must match regex pattern."""

    # Regular expression that the whole value has to match.
    pattern: str

    def __call__(self, value: str) -> str:
        """Return ``value`` if the entire string matches ``pattern``.

        Raises:
            TypeError: If ``value`` is not a string.
            ValueError: If ``value`` does not fully match ``pattern``.
        """
        if not isinstance(value, str):
            raise TypeError(f"Must be string, got {type(value).__name__}")

        if re.fullmatch(self.pattern, value):
            return value

        raise ValueError(f"Must match pattern {self.pattern!r}, got {value!r}")
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class Str(Constraint):
    """String type constraint with optional length validation."""

    # Either bound may be None, in which case that bound is not checked.
    max_length: int | None = None
    min_length: int | None = None

    def __call__(self, value: str) -> str:
        """Return ``value`` if it is a string within the configured length bounds.

        Raises:
            TypeError: If ``value`` is not a string.
            ValueError: If ``value`` is shorter than ``min_length`` or longer
                than ``max_length``.
        """
        if not isinstance(value, str):
            raise TypeError(f"Must be string, got {type(value).__name__}")

        # The minimum bound is checked before the maximum bound.
        too_short = self.min_length is not None and len(value) < self.min_length
        if too_short:
            raise ValueError(
                f"String too short (min {self.min_length}), got {len(value)}"
            )

        too_long = self.max_length is not None and len(value) > self.max_length
        if too_long:
            raise ValueError(
                f"String too long (max {self.max_length}), got {len(value)}"
            )

        return value
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
class Int(Constraint):
    """Integer type constraint."""

    def __call__(self, value: int) -> int:
        """Return ``value`` if it is an int and not a bool.

        Raises:
            TypeError: If ``value`` is a bool or not an int.
        """
        # bool is a subclass of int, so it has to be ruled out separately.
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"Must be int, got {type(value).__name__}")
        return value
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class Float(Constraint):
    """Float type constraint (accepts int as well)."""

    def __call__(self, value: float) -> float:
        """Return ``value`` converted to float if it is an int or float.

        Raises:
            TypeError: If ``value`` is a bool or not numeric.
        """
        # bool is a subclass of int, so it has to be ruled out separately.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise TypeError(f"Must be float or int, got {type(value).__name__}")
        # Ints are accepted but always returned as float.
        return float(value)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
class Bool(Constraint):
    """Boolean type constraint."""

    def __call__(self, value: bool) -> bool:
        """Return ``value`` if it is a bool.

        Raises:
            TypeError: If ``value`` is not a bool.
        """
        if isinstance(value, bool):
            return value
        raise TypeError(f"Must be bool, got {type(value).__name__}")
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
class Schema(Constraint):
    """Schema constraint - value must be a Pydantic BaseModel subclass or list[BaseModel]."""

    def __call__(self, value: type[BaseModel]) -> type[BaseModel]:
        """Validate value is BaseModel or list[BaseModel].

        Args:
            value: A BaseModel subclass, or a ``list[...]`` alias whose first
                type argument is a BaseModel subclass.

        Returns:
            ``value`` unchanged when it is valid.

        Raises:
            TypeError: If ``value`` is neither a BaseModel subclass nor a list
                alias of one, including an unparameterized ``typing.List``.
        """
        # For list[T], validate inner type T
        if get_origin(value) is list:
            type_args = get_args(value)
            # Bare typing.List has origin ``list`` but no type arguments;
            # treat it as invalid instead of failing with IndexError.
            inner = type_args[0] if type_args else None
            if not (isinstance(inner, type) and issubclass(inner, BaseModel)):
                msg = f"List type must be BaseModel, got {inner}"
                raise TypeError(msg)
            return value

        # For plain type, validate directly
        if not (isinstance(value, type) and issubclass(value, BaseModel)):
            msg = f"Must be BaseModel, got {value}"
            raise TypeError(msg)

        return value
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
__all__ = [
|
|
171
|
+
"Bool",
|
|
172
|
+
"Choice",
|
|
173
|
+
"Constraint",
|
|
174
|
+
"Float",
|
|
175
|
+
"Int",
|
|
176
|
+
"Pattern",
|
|
177
|
+
"Range",
|
|
178
|
+
"Schema",
|
|
179
|
+
"Str",
|
|
180
|
+
]
|
celeste/core.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""Core enumerations for Celeste."""
|
|
2
|
+
|
|
3
|
+
from enum import StrEnum
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Provider(StrEnum):
|
|
7
|
+
"""Supported AI providers."""
|
|
8
|
+
|
|
9
|
+
OPENAI = "openai"
|
|
10
|
+
ANTHROPIC = "anthropic"
|
|
11
|
+
GOOGLE = "google"
|
|
12
|
+
MISTRAL = "mistral"
|
|
13
|
+
COHERE = "cohere"
|
|
14
|
+
XAI = "xai"
|
|
15
|
+
HUGGINGFACE = "huggingface"
|
|
16
|
+
REPLICATE = "replicate"
|
|
17
|
+
STABILITYAI = "stabilityai"
|
|
18
|
+
LUMA = "luma"
|
|
19
|
+
TOPAZLABS = "topazlabs"
|
|
20
|
+
PERPLEXITY = "perplexity"
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Capability(StrEnum):
|
|
24
|
+
"""Supported AI capabilities."""
|
|
25
|
+
|
|
26
|
+
# Text
|
|
27
|
+
TEXT_GENERATION = "text_generation"
|
|
28
|
+
EMBEDDINGS = "embeddings"
|
|
29
|
+
|
|
30
|
+
# Image
|
|
31
|
+
IMAGE_GENERATION = "image_generation"
|
|
32
|
+
|
|
33
|
+
# Video
|
|
34
|
+
VIDEO_INTELLIGENCE = "video_intelligence"
|
|
35
|
+
VIDEO_GENERATION = "video_generation"
|
|
36
|
+
|
|
37
|
+
# Audio
|
|
38
|
+
AUDIO_INTELLIGENCE = "audio_intelligence"
|
|
39
|
+
|
|
40
|
+
# Search
|
|
41
|
+
SEARCH = "search"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Parameter(StrEnum):
|
|
45
|
+
"""Universal parameters across most capabilities."""
|
|
46
|
+
|
|
47
|
+
TEMPERATURE = "temperature"
|
|
48
|
+
SEED = "seed"
|
|
49
|
+
MAX_TOKENS = "max_tokens"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
__all__ = ["Capability", "Parameter", "Provider"]
|
celeste/credentials.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
"""Provider API credentials management for Celeste."""
|
|
2
|
+
|
|
3
|
+
from dotenv import find_dotenv
|
|
4
|
+
from pydantic import Field, SecretStr
|
|
5
|
+
from pydantic_settings import BaseSettings
|
|
6
|
+
|
|
7
|
+
from celeste.core import Provider
|
|
8
|
+
|
|
9
|
+
# Provider to credential field mapping
|
|
10
|
+
PROVIDER_CREDENTIAL_MAP = {
|
|
11
|
+
Provider.OPENAI: "openai_api_key",
|
|
12
|
+
Provider.ANTHROPIC: "anthropic_api_key",
|
|
13
|
+
Provider.GOOGLE: "google_api_key",
|
|
14
|
+
Provider.MISTRAL: "mistral_api_key",
|
|
15
|
+
Provider.HUGGINGFACE: "huggingface_token",
|
|
16
|
+
Provider.STABILITYAI: "stabilityai_api_key",
|
|
17
|
+
Provider.REPLICATE: "replicate_api_token",
|
|
18
|
+
Provider.COHERE: "cohere_api_key",
|
|
19
|
+
Provider.XAI: "xai_api_key",
|
|
20
|
+
Provider.LUMA: "luma_api_key",
|
|
21
|
+
Provider.TOPAZLABS: "topazlabs_api_key",
|
|
22
|
+
Provider.PERPLEXITY: "perplexity_api_key",
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Credentials(BaseSettings):
    """API credentials for all supported providers."""

    # One optional secret per provider; each alias is the upper-case name of
    # the corresponding setting.
    openai_api_key: SecretStr | None = Field(None, alias="OPENAI_API_KEY")
    anthropic_api_key: SecretStr | None = Field(None, alias="ANTHROPIC_API_KEY")
    google_api_key: SecretStr | None = Field(None, alias="GOOGLE_API_KEY")
    mistral_api_key: SecretStr | None = Field(None, alias="MISTRAL_API_KEY")
    huggingface_token: SecretStr | None = Field(None, alias="HUGGINGFACE_TOKEN")
    stabilityai_api_key: SecretStr | None = Field(None, alias="STABILITYAI_API_KEY")
    replicate_api_token: SecretStr | None = Field(None, alias="REPLICATE_API_TOKEN")
    cohere_api_key: SecretStr | None = Field(None, alias="COHERE_API_KEY")
    xai_api_key: SecretStr | None = Field(None, alias="XAI_API_KEY")
    luma_api_key: SecretStr | None = Field(None, alias="LUMA_API_KEY")
    topazlabs_api_key: SecretStr | None = Field(None, alias="TOPAZLABS_API_KEY")
    perplexity_api_key: SecretStr | None = Field(None, alias="PERPLEXITY_API_KEY")

    model_config = {
        "env_file": find_dotenv(),
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",  # Ignore unknown env vars like context7_api_key
    }

    def get_credentials(
        self, provider: Provider, override_key: SecretStr | None = None
    ) -> SecretStr:
        """Return the API key to use for ``provider``.

        A truthy ``override_key`` is returned as-is, without consulting the
        configured credentials.

        Args:
            provider: The AI provider to get credentials for.
            override_key: Optional SecretStr to use instead of the configured one.

        Returns:
            SecretStr containing the API key for the provider.

        Raises:
            ValueError: If the provider has no credentials configured, or if
                the provider has no credential mapping.
        """
        if override_key:
            return override_key

        if self.has_credential(provider):
            field_name = PROVIDER_CREDENTIAL_MAP[provider]
            stored: SecretStr = getattr(self, field_name)
            return stored

        raise ValueError(f"Provider {provider} has no credentials configured.")

    def list_available_providers(self) -> list[Provider]:
        """List the mapped providers whose credential field is set.

        Returns:
            Providers from ``PROVIDER_CREDENTIAL_MAP``, in mapping order, for
            which ``has_credential`` is true.
        """
        available: list[Provider] = []
        for provider in PROVIDER_CREDENTIAL_MAP:
            if self.has_credential(provider):
                available.append(provider)
        return available

    def has_credential(self, provider: Provider) -> bool:
        """Check whether ``provider`` has a credential value set.

        Args:
            provider: The AI provider to check.

        Returns:
            True if the provider's credential field is not None, False otherwise.

        Raises:
            ValueError: If provider has no credential mapping.
        """
        field_name = PROVIDER_CREDENTIAL_MAP.get(provider)
        if field_name:
            return getattr(self, field_name, None) is not None
        raise ValueError(f"Provider {provider} has no credential mapping")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
credentials = Credentials.model_validate({})
|
|
107
|
+
|
|
108
|
+
__all__ = ["Credentials", "credentials"]
|