celeste-ai 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of celeste-ai may be problematic; see the registry's advisory page for details.

celeste/__init__.py CHANGED
@@ -1,25 +1,116 @@
1
- """
2
- Celeste AI Framework
1
+ import importlib.metadata
2
+ import logging
3
3
 
4
- A unified multi-modal AI framework for text, image, video, and audio generation.
4
+ from pydantic import SecretStr
5
5
 
6
- This is a placeholder package to reserve the name. The full framework is coming soon!
7
- """
6
+ from celeste.client import Client, get_client_class, register_client
7
+ from celeste.core import Capability, Parameter, Provider
8
+ from celeste.credentials import credentials
9
+ from celeste.http import HTTPClient, close_all_http_clients
10
+ from celeste.io import Input, Output, Usage
11
+ from celeste.models import Model, get_model, list_models, register_models
12
+ from celeste.parameters import Parameters
8
13
 
9
- __version__ = "0.0.1"
10
- __author__ = "agent-kai"
14
+ logger = logging.getLogger(__name__)
11
15
 
12
- def create_client(*args, **kwargs):
13
- """
14
- Placeholder for the universal client factory.
15
16
 
16
- The full implementation will provide a unified interface for all AI capabilities.
17
def _resolve_model(
    capability: Capability,
    provider: Provider | None,
    model: Model | str | None,
) -> Model:
    """Resolve the ``model`` argument of ``create_client`` to a Model object.

    - ``None``: auto-select the first model returned by ``list_models`` for the
      capability (restricted to ``provider`` when one is given).
    - ``str``: look the ID up with ``get_model``; ``provider`` is required.
    - ``Model``: returned as-is, after checking that it belongs to ``provider``
      when a provider was also given.

    Args:
        capability: Capability the resolved model is requested for.
        provider: Optional provider filter / lookup scope.
        model: Model object, string model ID, or None for auto-selection.

    Returns:
        The resolved Model.

    Raises:
        ValueError: If no model matches, a string ID is given without a
            provider, or an explicit Model conflicts with ``provider``.
    """
    if model is None:
        # Auto-select first available model
        models = list_models(provider=provider, capability=capability)
        if not models:
            scope = f" with provider {provider}" if provider is not None else ""
            msg = f"No model found for {capability}{scope}"
            raise ValueError(msg)
        return models[0]

    if isinstance(model, str):
        # A bare ID is ambiguous across providers, so the provider is mandatory.
        if provider is None:
            msg = "provider required when model is a string ID"
            raise ValueError(msg)
        found = get_model(model, provider)
        if not found:
            msg = f"Model '{model}' not found for provider {provider}"
            raise ValueError(msg)
        return found

    # create_client routes by model.provider, so a conflicting explicit provider
    # argument would otherwise be silently ignored.
    if provider is not None and model.provider != provider:
        msg = (
            f"Model '{model.id}' belongs to provider {model.provider}, "
            f"not {provider}"
        )
        raise ValueError(msg)
    return model
43
+
44
+
45
def create_client(
    capability: Capability,
    provider: Provider | None = None,
    model: Model | str | None = None,
    api_key: SecretStr | None = None,
) -> Client:
    """Build an async client bound to one capability, provider, and model.

    Args:
        capability: Capability the client must serve (e.g. TEXT_GENERATION).
        provider: Optional provider; mandatory when ``model`` is a string ID.
        model: A Model instance, a model ID string, or None to auto-select.
        api_key: Optional SecretStr taking precedence over configured
            credentials; when omitted the key is looked up via ``credentials``.

    Returns:
        An instance of the client class registered for ``capability`` and the
        resolved model's provider.

    Raises:
        ValueError: If the model cannot be resolved, or no credentials are
            configured for the provider.
        NotImplementedError: If no client is registered for capability/provider.
    """
    chosen = _resolve_model(capability, provider, model)
    target_provider = chosen.provider

    # Look up the implementation first, then the key for that same provider.
    client_cls = get_client_class(capability, target_provider)
    key = credentials.get_credentials(target_provider, override_key=api_key)

    return client_cls(
        model=chosen,
        provider=target_provider,
        capability=capability,
        api_key=key,
    )
23
82
 
24
- # Placeholder exports
25
- __all__ = ["create_client"]
83
+
84
+ def _load_from_entry_points() -> None:
85
+ """Load models and clients from installed packages via entry points."""
86
+ entry_points = importlib.metadata.entry_points(group="celeste.packages")
87
+
88
+ for ep in entry_points:
89
+ register_func = ep.load()
90
+ # The function should register models and clients when called
91
+ register_func()
92
+
93
+
94
+ # Load from entry points on import
95
+ _load_from_entry_points()
96
+
97
# Public API of the celeste package (order preserved).
__all__ = [
    # Types
    "Capability", "Client", "HTTPClient", "Input", "Model", "Output",
    "Parameter", "Parameters", "Provider", "Usage",
    # Functions
    "close_all_http_clients", "create_client", "get_client_class",
    "get_model", "list_models", "register_client", "register_models",
]
celeste/artifacts.py ADDED
@@ -0,0 +1,52 @@
1
+ """Unified artifact types for Celeste."""
2
+
3
+ from typing import Any
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+ from celeste.mime_types import AudioMimeType, ImageMimeType, MimeType, VideoMimeType
8
+
9
+
10
class Artifact(BaseModel):
    """Common base for media artifacts.

    Content may be given as a ``url``, inline ``data`` bytes, and/or a ``path``
    string; all are optional, alongside an optional MIME type and free-form
    metadata.
    """

    url: str | None = None
    data: bytes | None = None
    path: str | None = None
    mime_type: MimeType | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)

    @property
    def has_content(self) -> bool:
        """Return True if data is non-empty or url/path is a non-blank string."""
        if self.data:
            return True
        # Whitespace-only references do not count as content.
        return any(
            ref is not None and ref.strip() != "" for ref in (self.url, self.path)
        )
27
+
28
+
29
class ImageArtifact(Artifact):
    """Artifact carrying an image produced by a generation or edit operation."""

    mime_type: ImageMimeType | None = None


class VideoArtifact(Artifact):
    """Artifact carrying a video produced by a generation operation."""

    mime_type: VideoMimeType | None = None


class AudioArtifact(Artifact):
    """Artifact carrying audio from TTS or transcription operations."""

    mime_type: AudioMimeType | None = None


__all__ = ["Artifact", "AudioArtifact", "ImageArtifact", "VideoArtifact"]
celeste/client.py ADDED
@@ -0,0 +1,150 @@
1
+ """Base client and client registry for AI capabilities."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Unpack
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field, SecretStr
7
+
8
+ from celeste.core import Capability, Provider
9
+ from celeste.http import HTTPClient, get_http_client
10
+ from celeste.io import Input, Output, Usage
11
+ from celeste.models import Model
12
+ from celeste.parameters import ParameterMapper, Parameters
13
+ from celeste.streaming import Stream
14
+
15
+
16
+ class Client[In: Input, Out: Output](ABC, BaseModel):
17
+ """Base class for all capability-specific clients."""
18
+
19
+ model_config = ConfigDict(from_attributes=True)
20
+
21
+ model: Model
22
+ provider: Provider
23
+ capability: Capability
24
+ api_key: SecretStr = Field(exclude=True)
25
+
26
+ def model_post_init(self, __context: object) -> None:
27
+ """Validate capability compatibility."""
28
+ if self.capability not in self.model.capabilities:
29
+ raise ValueError(
30
+ f"Model '{self.model.id}' does not support capability {self.capability.value}"
31
+ )
32
+
33
+ @property
34
+ def http_client(self) -> HTTPClient:
35
+ """Shared HTTP client with connection pooling for this provider."""
36
+ return get_http_client(self.provider, self.capability)
37
+
38
+ @classmethod
39
+ @abstractmethod
40
+ def parameter_mappers(cls) -> list[ParameterMapper]:
41
+ """Provider-specific parameter mappers."""
42
+ ...
43
+
44
+ @abstractmethod
45
+ def _init_request(self, inputs: In) -> dict[str, Any]:
46
+ """Initialize provider-specific base request structure."""
47
+ pass
48
+
49
+ @abstractmethod
50
+ def _parse_usage(self, response_data: dict[str, Any]) -> Usage:
51
+ """Parse usage information from provider response."""
52
+ pass
53
+
54
+ @abstractmethod
55
+ def _parse_content(
56
+ self, response_data: dict[str, Any], **parameters: Unpack[Parameters]
57
+ ) -> object:
58
+ """Parse content from provider response."""
59
+ pass
60
+
61
+ def _transform_output(
62
+ self, content: object, **parameters: Unpack[Parameters]
63
+ ) -> object:
64
+ """Transform content using parameter mapper output transformations."""
65
+ for mapper in self.parameter_mappers():
66
+ value = parameters.get(mapper.name)
67
+ if value is not None:
68
+ content = mapper.parse_output(content, value)
69
+ return content
70
+
71
+ def _build_request(
72
+ self, inputs: In, **parameters: Unpack[Parameters]
73
+ ) -> dict[str, Any]:
74
+ """Build complete request by combining base request with parameters."""
75
+ request = self._init_request(inputs)
76
+
77
+ # Apply parameter mappers from registry
78
+ for mapper in self.parameter_mappers():
79
+ value = parameters.get(mapper.name)
80
+ request = mapper.map(request, value, self.model)
81
+
82
+ return request
83
+
84
+ def stream(self, *args: Any, **parameters: Unpack[Parameters]) -> Stream[Out]: # noqa: ANN401
85
+ """Stream content - signature varies by capability.
86
+
87
+ Args:
88
+ *args: Capability-specific positional arguments (same as generate).
89
+ **parameters: Capability-specific keyword arguments (same as generate).
90
+
91
+ Returns:
92
+ Stream yielding chunks and providing final Output.
93
+
94
+ Raises:
95
+ NotImplementedError: If capability doesn't support streaming.
96
+ """
97
+ msg = f"Streaming not supported for {self.capability.value} with provider {self.provider.value}"
98
+ raise NotImplementedError(msg)
99
+
100
+ @abstractmethod
101
+ async def generate(self, *args: Any, **parameters: Unpack[Parameters]) -> Out: # noqa: ANN401
102
+ """Generate content - signature varies by capability.
103
+
104
+ Args:
105
+ *args: Capability-specific positional arguments (prompt, text, image_url, etc.).
106
+ **parameters: Capability-specific keyword arguments (temperature, max_tokens, etc.).
107
+
108
+ Returns:
109
+ Output of the parameterized type (e.g., TextGenerationOutput).
110
+ """
111
+ pass
112
+
113
+
114
# Registry of client implementations keyed by (capability, provider).
_clients: dict[tuple[Capability, Provider], type[Client]] = {}


def register_client(
    capability: Capability, provider: Provider, client_class: type[Client]
) -> None:
    """Register a provider-specific client class for a capability.

    Registering again for the same (capability, provider) pair replaces the
    previously registered class.

    Args:
        capability: The capability this client implements.
        provider: The provider this client uses.
        client_class: The client class to register.
    """
    _clients[(capability, provider)] = client_class


def get_client_class(capability: Capability, provider: Provider) -> type[Client]:
    """Get the registered client class for a capability and provider.

    Args:
        capability: The capability to get a client for.
        provider: The provider to use.

    Returns:
        The registered client class.

    Raises:
        NotImplementedError: If no client is registered for this capability/provider.
    """
    # EAFP: a single dict lookup instead of a membership test plus indexing.
    try:
        return _clients[(capability, provider)]
    except KeyError:
        msg = f"No client registered for {capability.value} with provider {provider.value}"
        # The KeyError adds nothing for callers; suppress it as the context.
        raise NotImplementedError(msg) from None


__all__ = ["Client", "get_client_class", "register_client"]
celeste/constraints.py ADDED
@@ -0,0 +1,180 @@
1
+ """Constraint models for parameter validation."""
2
+
3
+ import math
4
+ import re
5
+ from abc import ABC, abstractmethod
6
+ from typing import Any, get_args, get_origin
7
+
8
+ from pydantic import BaseModel, Field
9
+
10
+
11
class Constraint(BaseModel, ABC):
    """Abstract validator for a single parameter value.

    Subclasses are pydantic models that hold their configuration and are
    invoked like functions: ``constraint(value)`` returns the accepted value or
    raises.
    """

    @abstractmethod
    def __call__(self, value: Any) -> Any:  # noqa: ANN401
        """Check ``value`` and return it (possibly normalized) if acceptable."""
        ...
18
+
19
+
20
+ class Choice[T](Constraint):
21
+ """Choice constraint - value must be one of the provided options."""
22
+
23
+ options: list[T] = Field(min_length=1)
24
+
25
+ def __call__(self, value: T) -> T:
26
+ """Validate value is in options."""
27
+ if value not in self.options:
28
+ msg = f"Must be one of {self.options}, got {value!r}"
29
+ raise ValueError(msg)
30
+ return value
31
+
32
+
33
class Range(Constraint):
    """Range constraint - value must be within min/max bounds (inclusive).

    If step is provided, value must be at min + (n * step) for some integer n,
    allowing a 1e-9 absolute tolerance for floating-point error. Booleans are
    rejected even though ``bool`` subclasses ``int``, matching Int and Float.
    """

    min: float | int
    max: float | int
    # gt=0: a zero step would raise ZeroDivisionError in the modulo below, and a
    # negative step defines no meaningful grid, so reject both at construction.
    step: float | None = Field(default=None, gt=0)

    def __call__(self, value: float | int) -> float | int:
        """Validate value is within range and matches step increment.

        Returns:
            The value, unchanged.

        Raises:
            TypeError: If value is not an int or float, or is a bool.
            ValueError: If value is outside [min, max] or off the step grid.
        """
        # isinstance(True, int) is True, so exclude bools explicitly
        if not isinstance(value, (int, float)) or isinstance(value, bool):
            msg = f"Must be numeric, got {type(value).__name__}"
            raise TypeError(msg)

        if not self.min <= value <= self.max:
            msg = f"Must be between {self.min} and {self.max}, got {value}"
            raise ValueError(msg)

        if self.step is not None:
            offset = value - self.min
            remainder = offset % self.step
            # Use epsilon for floating-point comparison tolerance; a remainder
            # just below `step` means value is a grid point minus rounding error.
            epsilon = 1e-9
            on_grid = math.isclose(remainder, 0, abs_tol=epsilon) or math.isclose(
                remainder, self.step, abs_tol=epsilon
            )
            if not on_grid:
                # Calculate nearest valid values for an actionable error message,
                # suggesting only grid points that are themselves within range.
                closest_below = self.min + int(offset / self.step) * self.step
                closest_above = closest_below + self.step
                if closest_above <= self.max:
                    nearest = f"{closest_below} or {closest_above}"
                else:
                    nearest = f"{closest_below}"
                msg = f"Value must match step {self.step}. Nearest valid: {nearest}, got {value}"
                raise ValueError(msg)

        return value
69
+
70
+
71
class Pattern(Constraint):
    """Accept strings that match a regular expression in full."""

    pattern: str

    def __call__(self, value: str) -> str:
        """Return ``value`` if it is a str and the whole of it matches ``pattern``."""
        if isinstance(value, str):
            if re.fullmatch(self.pattern, value) is None:
                raise ValueError(f"Must match pattern {self.pattern!r}, got {value!r}")
            return value
        raise TypeError(f"Must be string, got {type(value).__name__}")
87
+
88
+
89
class Str(Constraint):
    """Accept strings, optionally bounded in length (both bounds inclusive)."""

    max_length: int | None = None
    min_length: int | None = None

    def __call__(self, value: str) -> str:
        """Return ``value`` if it is a str whose length lies within the bounds."""
        if not isinstance(value, str):
            raise TypeError(f"Must be string, got {type(value).__name__}")

        size = len(value)
        if self.min_length is not None and size < self.min_length:
            raise ValueError(f"String too short (min {self.min_length}), got {size}")
        if self.max_length is not None and size > self.max_length:
            raise ValueError(f"String too long (max {self.max_length}), got {size}")
        return value
110
+
111
+
112
class Int(Constraint):
    """Accept ``int`` values, excluding ``bool`` (a subclass of ``int``)."""

    def __call__(self, value: int) -> int:
        """Return ``value`` if it is a genuine integer."""
        is_plain_int = isinstance(value, int) and not isinstance(value, bool)
        if is_plain_int:
            return value
        raise TypeError(f"Must be int, got {type(value).__name__}")
123
+
124
+
125
class Float(Constraint):
    """Accept ints or floats (but not bools) and normalize them to ``float``."""

    def __call__(self, value: float) -> float:
        """Return ``value`` converted with ``float()`` if it is numeric."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise TypeError(f"Must be float or int, got {type(value).__name__}")
        return float(value)
135
+
136
+
137
class Bool(Constraint):
    """Accept only ``True`` or ``False``."""

    def __call__(self, value: bool) -> bool:
        """Return ``value`` if it is a bool instance."""
        if isinstance(value, bool):
            return value
        raise TypeError(f"Must be bool, got {type(value).__name__}")
147
+
148
+
149
class Schema(Constraint):
    """Accept a pydantic BaseModel subclass, or ``list[...]`` of one."""

    def __call__(self, value: type[BaseModel]) -> type[BaseModel]:
        """Return ``value`` if it (or its list element type) is a BaseModel subclass."""
        is_list_alias = get_origin(value) is list
        # For list[T] the element type is what must be a model class.
        candidate = get_args(value)[0] if is_list_alias else value
        if isinstance(candidate, type) and issubclass(candidate, BaseModel):
            return value
        if is_list_alias:
            raise TypeError(f"List type must be BaseModel, got {candidate}")
        raise TypeError(f"Must be BaseModel, got {value}")


__all__ = [
    "Bool", "Choice", "Constraint", "Float", "Int",
    "Pattern", "Range", "Schema", "Str",
]
celeste/core.py ADDED
@@ -0,0 +1,52 @@
1
+ """Core enumerations for Celeste."""
2
+
3
+ from enum import StrEnum
4
+
5
+
6
+ class Provider(StrEnum):
7
+ """Supported AI providers."""
8
+
9
+ OPENAI = "openai"
10
+ ANTHROPIC = "anthropic"
11
+ GOOGLE = "google"
12
+ MISTRAL = "mistral"
13
+ COHERE = "cohere"
14
+ XAI = "xai"
15
+ HUGGINGFACE = "huggingface"
16
+ REPLICATE = "replicate"
17
+ STABILITYAI = "stabilityai"
18
+ LUMA = "luma"
19
+ TOPAZLABS = "topazlabs"
20
+ PERPLEXITY = "perplexity"
21
+
22
+
23
+ class Capability(StrEnum):
24
+ """Supported AI capabilities."""
25
+
26
+ # Text
27
+ TEXT_GENERATION = "text_generation"
28
+ EMBEDDINGS = "embeddings"
29
+
30
+ # Image
31
+ IMAGE_GENERATION = "image_generation"
32
+
33
+ # Video
34
+ VIDEO_INTELLIGENCE = "video_intelligence"
35
+ VIDEO_GENERATION = "video_generation"
36
+
37
+ # Audio
38
+ AUDIO_INTELLIGENCE = "audio_intelligence"
39
+
40
+ # Search
41
+ SEARCH = "search"
42
+
43
+
44
+ class Parameter(StrEnum):
45
+ """Universal parameters across most capabilities."""
46
+
47
+ TEMPERATURE = "temperature"
48
+ SEED = "seed"
49
+ MAX_TOKENS = "max_tokens"
50
+
51
+
52
+ __all__ = ["Capability", "Parameter", "Provider"]
celeste/credentials.py ADDED
@@ -0,0 +1,108 @@
1
+ """Provider API credentials management for Celeste."""
2
+
3
+ from dotenv import find_dotenv
4
+ from pydantic import Field, SecretStr
5
+ from pydantic_settings import BaseSettings
6
+
7
+ from celeste.core import Provider
8
+
9
+ # Provider to credential field mapping
10
# Provider -> name of the Credentials field that stores its key. Insertion order
# is the order in which Credentials.list_available_providers() reports providers.
PROVIDER_CREDENTIAL_MAP: dict[Provider, str] = {
    Provider.OPENAI: "openai_api_key",
    Provider.ANTHROPIC: "anthropic_api_key",
    Provider.GOOGLE: "google_api_key",
    Provider.MISTRAL: "mistral_api_key",
    Provider.HUGGINGFACE: "huggingface_token",  # token, not api_key
    Provider.STABILITYAI: "stabilityai_api_key",
    Provider.REPLICATE: "replicate_api_token",  # token, not api_key
    Provider.COHERE: "cohere_api_key",
    Provider.XAI: "xai_api_key",
    Provider.LUMA: "luma_api_key",
    Provider.TOPAZLABS: "topazlabs_api_key",
    Provider.PERPLEXITY: "perplexity_api_key",
}
24
+
25
+
26
class Credentials(BaseSettings):
    """API credentials for all supported providers.

    Each field is read from the environment variable named by its alias
    (case-insensitive) or from a ``.env`` file located by searching upward from
    the current working directory at import time.
    """

    openai_api_key: SecretStr | None = Field(None, alias="OPENAI_API_KEY")
    anthropic_api_key: SecretStr | None = Field(None, alias="ANTHROPIC_API_KEY")
    google_api_key: SecretStr | None = Field(None, alias="GOOGLE_API_KEY")
    mistral_api_key: SecretStr | None = Field(None, alias="MISTRAL_API_KEY")
    huggingface_token: SecretStr | None = Field(None, alias="HUGGINGFACE_TOKEN")
    stabilityai_api_key: SecretStr | None = Field(None, alias="STABILITYAI_API_KEY")
    replicate_api_token: SecretStr | None = Field(None, alias="REPLICATE_API_TOKEN")
    cohere_api_key: SecretStr | None = Field(None, alias="COHERE_API_KEY")
    xai_api_key: SecretStr | None = Field(None, alias="XAI_API_KEY")
    luma_api_key: SecretStr | None = Field(None, alias="LUMA_API_KEY")
    topazlabs_api_key: SecretStr | None = Field(None, alias="TOPAZLABS_API_KEY")
    perplexity_api_key: SecretStr | None = Field(None, alias="PERPLEXITY_API_KEY")

    model_config = {
        # usecwd=True: by default find_dotenv() starts from the directory of the
        # calling module (this file). For an installed wheel that is inside
        # site-packages, so the user's project .env would never be found.
        "env_file": find_dotenv(usecwd=True),
        "env_file_encoding": "utf-8",
        "case_sensitive": False,
        "extra": "ignore",  # Ignore unknown env vars like context7_api_key
    }

    def get_credentials(
        self, provider: Provider, override_key: SecretStr | None = None
    ) -> SecretStr:
        """Get credentials for a specific provider with optional override.

        Args:
            provider: The AI provider to get credentials for.
            override_key: Optional SecretStr to use instead of the configured
                value; returned as-is when truthy.

        Returns:
            SecretStr containing the API key for the provider.

        Raises:
            ValueError: If provider requires credentials but none are configured,
                or if provider is not supported (no credential mapping).
        """
        if override_key:
            return override_key

        if not self.has_credential(provider):
            msg = f"Provider {provider} has no credentials configured."
            raise ValueError(msg)

        credential: SecretStr = getattr(self, PROVIDER_CREDENTIAL_MAP[provider])
        return credential

    def list_available_providers(self) -> list[Provider]:
        """List all providers that have credentials configured.

        Returns:
            Providers with a non-empty credential, in PROVIDER_CREDENTIAL_MAP order.
        """
        return [
            provider
            for provider in PROVIDER_CREDENTIAL_MAP
            if self.has_credential(provider)
        ]

    def has_credential(self, provider: Provider) -> bool:
        """Check if a specific provider has a non-empty credential configured.

        Args:
            provider: The AI provider to check.

        Returns:
            True if a non-empty credential is set, False otherwise.

        Raises:
            ValueError: If provider has no credential mapping.
        """
        credential_field = PROVIDER_CREDENTIAL_MAP.get(provider)
        if not credential_field:
            msg = f"Provider {provider} has no credential mapping"
            raise ValueError(msg)
        credential = getattr(self, credential_field, None)
        # An env var that is set but empty (e.g. `OPENAI_API_KEY=`) yields
        # SecretStr(""); treat it as missing rather than handing out a blank key.
        return credential is not None and bool(credential.get_secret_value())


# Module-level instance shared by the package.
# NOTE(review): model_validate({}) is used instead of Credentials(); environment
# loading then depends on pydantic invoking BaseSettings.__init__ — confirm.
credentials = Credentials.model_validate({})

__all__ = ["Credentials", "credentials"]