aix 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aix/__init__.py CHANGED
@@ -1,68 +1,159 @@
- """
- Facade to key AI tools.
+ """AIX: Artificial Intelligence eXtensions

- Get a list of available chat functions (this will depend on the AI packages
- you have installed locally):
+ A clean, pythonic facade for common AI operations that abstracts away
+ provider-specific details and complexities.

- >>> from aix import chat_funcs
- >>> list(chat_funcs) # doctest: +SKIP
- ['gemini-1.5-flash',
- 'gpt-4',
- 'gpt-4-32k',
- 'gpt-4-turbo',
- 'gpt-3.5-turbo',
- 'o1-preview',
- 'o1-mini',
- 'gpt-4o',
- 'gpt-4o-mini']
+ Quick Start:
+ >>> from aix import chat, embeddings, prompt_func, models

- Choose a chat function and chat with it:
- >>> google_ai_chat = chat_funcs['gemini-1.5-flash'] # doctest: +SKIP
- >>> google_ai_chat("What is the meaning of life? Respond with a number.") # doctest: +SKIP
- '42'
- >>> openai_chat = chat_funcs['gpt-3.5-turbo'] # doctest: +SKIP
- >>> openai_chat("What is the meaning of life? Respond with a number.") # doctest: +SKIP
- '42'
+ # Simple chat
+ >>> response = chat("What is 2+2?") # doctest: +SKIP
+ 'The answer is 4.'

- """
+ # Create prompt-based functions
+ >>> translate = prompt_func("Translate to French: {text}")
+ >>> translate(text="Hello world") # doctest: +SKIP
+ 'Bonjour le monde'

- from aix.gen_ai import chat, chat_models, chat_funcs
- from aix import contexts
- from aix.contexts import (
-     bytes_to_markdown,  # Convert bytes to markdown (with plugin support for different types of files)
-     bytes_store_to_markdown_store,  # Convert a bytes store to a markdown store based on extensions
-     aggregate_store,  # Aggregate a store into a single value (say, a string)
- )
+ # Get embeddings
+ >>> vecs = list(embeddings(["hello", "world"])) # doctest: +SKIP
+ >>> len(vecs) # doctest: +SKIP
+ 2
+
+ # Discover models
+ >>> models.discover() # doctest: +SKIP
+ >>> list(models)[:5] # doctest: +SKIP
+ ['openai/gpt-4o', 'openai/gpt-4o-mini', ...]

- # TODO: Change this so that there's a load_pkg function that loads the packages dynamically
- # if and when use wants.
+ Main Features:
+ - chat(): Simple chat interface across providers
+ - embeddings(): Vector embeddings for text
+ - prompt_func(): Create functions from prompt templates
+ - models: Model discovery and selection
+ - generate_image(): Text-to-image generation
+ - text_to_speech(), transcribe(): Audio operations
+ - generate_video(): Text-to-video generation (provider-dependent)
+ - Batch operations for efficiency
+ - Clean, i2mint-style Mapping interfaces
+
+ Backends:
+ - Uses LiteLLM for provider interactions
+ - Supports OpenAI, Anthropic, Google, and 100+ models
+ - OpenRouter integration for multi-provider access
+
+ For detailed documentation, see: https://github.com/thorwhalen/aix
+ """

- # from aix.pd import *
- # from aix.np import *
- # from aix.sk import *
+ # Core interfaces (new clean API)
+ from aix.chat import chat, ask, chat_with_history, ChatSession
+ from aix.embeddings import (
+     embeddings,
+     embed,
+     cosine_similarity,
+     find_most_similar,
+     EmbeddingCache,
+ )
+ from aix.prompts import (
+     prompt_func,
+     prompt_to_text,
+     prompt_to_json,
+     PromptFuncs,
+     common_funcs,
+     constrained_answer,
+ )
+ from aix.models import (
+     models,
+     ModelStore,
+     discover_available_models,
+     get_model_info,
+     find_models,
+ )
+ from aix.batches import (
+     batch_chat,
+     batch_embeddings,
+     batch_process,
+     BatchProcessor,
+ )
+ from aix.image import (
+     generate_image,
+     generate_images,
+     edit_image,
+     create_variation,
+     GeneratedImage,
+ )
+ from aix.audio import (
+     text_to_speech,
+     transcribe,
+     transcribe_with_timestamps,
+     translate_audio,
+     GeneratedAudio,
+     TranscriptionResult,
+ )
+ from aix.video import (
+     generate_video,
+     animate_image as animate_image_to_video,
+     extend_video,
+     GeneratedVideo,
+     get_available_providers as get_video_providers,
+ )

- # from aix import pd
- # from aix import np
- # from aix import sk
+ # Legacy interfaces (for backward compatibility)
+ from aix.gen_ai import chat_models, chat_funcs

+ # Version info
+ __version__ = "0.1.0"

- #
- # from contextlib import suppress
- #
- # preferred_order = ['sk', 'np', 'pd']
- #
- # with suppress(ModuleNotFoundError):
- #     from aix import sk
- #
- # with suppress(ModuleNotFoundError):
- #     from aix import np
- #
- # with suppress(ModuleNotFoundError):
- #     from aix import pd
- #
- # for _module_name in preferred_order[::-1]:
- #     print(f"------ {_module_name}")
- #     _module = __import__(f'aix.{_module_name}')
- #     for _name in filter(lambda x: not x.startswith('__'), dir(_module)):
- #         print(_name, _module)
- #         locals()[_name] = getattr(_module, _name)
+ # Public API
+ __all__ = [
+     # Core chat
+     "chat",
+     "ask",
+     "chat_with_history",
+     "ChatSession",
+     # Embeddings
+     "embeddings",
+     "embed",
+     "cosine_similarity",
+     "find_most_similar",
+     "EmbeddingCache",
+     # Prompts
+     "prompt_func",
+     "prompt_to_text",
+     "prompt_to_json",
+     "PromptFuncs",
+     "common_funcs",
+     "constrained_answer",
+     # Models
+     "models",
+     "ModelStore",
+     "discover_available_models",
+     "get_model_info",
+     "find_models",
+     # Batches
+     "batch_chat",
+     "batch_embeddings",
+     "batch_process",
+     "BatchProcessor",
+     # Image
+     "generate_image",
+     "generate_images",
+     "edit_image",
+     "create_variation",
+     "GeneratedImage",
+     # Audio
+     "text_to_speech",
+     "transcribe",
+     "transcribe_with_timestamps",
+     "translate_audio",
+     "GeneratedAudio",
+     "TranscriptionResult",
+     # Video
+     "generate_video",
+     "animate_image_to_video",
+     "extend_video",
+     "GeneratedVideo",
+     "get_video_providers",
+     # Legacy
+     "chat_models",
+     "chat_funcs",
+ ]
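
The exports above replace the chat_funcs-centric API of 0.0.23 with direct callables. A minimal usage sketch, assuming the call signatures suggested by the docstring; the actual definitions live in aix.chat, aix.embeddings, and aix.prompts, which are outside this diff:

    # Sketch only: composes names exported by aix/__init__.py above.
    # Keyword usage follows the docstring examples; exact signatures
    # are defined in submodules not shown in this diff.
    from aix import chat, prompt_func, embeddings, cosine_similarity

    summarize = prompt_func("Summarize in one sentence: {text}")
    summary = summarize(text="AIX wraps LiteLLM behind a small facade.")

    # embeddings() is documented to yield one vector per input text
    vec_a, vec_b = list(embeddings(["hello", "world"]))
    score = cosine_similarity(vec_a, vec_b)  # assumed: two vectors -> float
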
aix/ai_models/__init__.py ADDED
@@ -0,0 +1,136 @@
+ """AI Model Management Module.
+
+ A unified interface for managing AI models across multiple providers.
+
+ Basic usage:
+ >>> from aix.ai_models import get_manager
+ >>> manager = get_manager()
+ >>> _ = manager.discover_from_source("openrouter", auto_register=True, verbose=False)
+ >>> models = manager.list_models(provider="openai")
+
+ Custom filtering:
+ >>> cheap_models = manager.list_models(
+ ...     custom_filter=lambda m: m.cost_per_token.get("input", 0) < 0.001
+ ... )
+
+ Get connector-specific metadata:
+ >>> openai_meta = manager.get_connector_metadata("openai/gpt-4", "openai")
+ >>> # Use with: openai.ChatCompletion.create(**openai_meta, messages=[...])
+
+ """
+
+ # Core types
+ from aix.ai_models.base import (
+     Model,
+     ModelRegistry,
+     ModelSource,
+     Connector,
+     ConnectorRegistry,
+ )
+
+ # Concrete sources and connectors
+ from aix.ai_models.sources import (
+     Connector,
+     OpenRouterSource,
+     OllamaSource,
+     ProviderAPISource,
+     OpenAIConnector,
+     OpenRouterConnector,
+     LangChainConnector,
+     OllamaConnector,
+     DSPyConnector,
+ )
+
+ # Main facade
+ from aix.ai_models.manager import (
+     ModelManager,
+     get_manager,
+ )
+
+ # Version info
+ __version__ = "0.1.0"
+
+ # Public API
+ __all__ = [
+     # Core types
+     "Model",
+     "ModelRegistry",
+     "ModelSource",
+     "Connector",
+     "ConnectorRegistry",
+     # Sources
+     "OpenRouterSource",
+     "OllamaSource",
+     "ProviderAPISource",
+     # Connectors
+     "OpenAIConnector",
+     "OpenRouterConnector",
+     "LangChainConnector",
+     "OllamaConnector",
+     "DSPyConnector",
+     # Main API
+     "ModelManager",
+     "get_manager",
+ ]
+
+
+ # Convenience functions for common operations
+
+
+ def list_available_models(
+     *,
+     provider: str | None = None,
+     is_local: bool | None = None,
+     storage_path: str | None = None,
+ ) -> list[Model]:
+     """Quick function to list available models without explicit manager.
+
+     >>> models = list_available_models(provider="openai")
+     >>> len(models) >= 0
+     True
+     """
+     manager = get_manager(storage_path=storage_path)
+     return manager.list_models(provider=provider, is_local=is_local)
+
+
+ def discover_models(
+     source_name: str = "openrouter",
+     *,
+     storage_path: str | None = None,
+     auto_register: bool = True,
+     verbose: bool = True,
+ ) -> list[Model]:
+     """Quick function to discover models from a source.
+
+     >>> models = discover_models("openrouter", auto_register=False, verbose=False)
+     >>> len(models) > 0
+     True
+     """
+     manager = get_manager(storage_path=storage_path)
+     return manager.discover_from_source(
+         source_name, auto_register=auto_register, verbose=verbose
+     )
+
+
+ def get_model_metadata(
+     model_id: str, connector_name: str, *, storage_path: str | None = None
+ ) -> dict:
+     """Quick function to get formatted metadata for a model.
+
+     >>> import tempfile, os, json
+     >>> if 'OPENROUTER_API_KEY' in os.environ:
+     ...     from aix.ai_models.manager import get_manager
+     ...     with tempfile.NamedTemporaryFile(mode='w+', suffix=".json", delete=False) as temp:
+     ...         storage_path = temp.name
+     ...         json.dump({'models': []}, temp)
+     ...     try:
+     ...         manager = get_manager(storage_path=storage_path)
+     ...         _ = manager.discover_from_source("openrouter", auto_register=True, verbose=False)
+     ...         assert 'openai/gpt-3.5-turbo' in manager.models
+     ...         metadata = get_model_metadata("openai/gpt-3.5-turbo", "openrouter", storage_path=storage_path)
+     ...         assert 'model' in metadata
+     ...     finally:
+     ...         os.remove(storage_path)
+     """
+     manager = get_manager(storage_path=storage_path)
+     return manager.get_connector_metadata(model_id, connector_name)
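
The three convenience functions are thin wrappers over a shared ModelManager. A usage sketch based on the signatures shown above; discovery queries the OpenRouter catalog, so it needs network access (and, per the get_model_metadata doctest, an OPENROUTER_API_KEY for connector metadata):

    from aix.ai_models import (
        discover_models,
        list_available_models,
        get_model_metadata,
    )

    # Populate the registry from OpenRouter, then query it locally.
    discover_models("openrouter", auto_register=True, verbose=False)
    openai_models = list_available_models(provider="openai")
    local_models = list_available_models(is_local=True)

    # Connector-formatted kwargs for a discovered model (requires the
    # model to be in the registry, as the doctest above asserts).
    meta = get_model_metadata("openai/gpt-3.5-turbo", "openrouter")
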
aix/ai_models/base.py ADDED
@@ -0,0 +1,270 @@
+ """Core types for AI model management.
+
+ This module provides a unified interface for managing, discovering, and
+ connecting to AI models across multiple providers and deployment methods.
+ """
+
+ from dataclasses import dataclass, field, asdict, fields
+ from typing import Any
+ from collections.abc import Mapping, MutableMapping, Iterator, Callable, Iterable
+ from abc import ABC, abstractmethod
+ import json
+ from pathlib import Path
+
+
+ @dataclass
+ class Model:
+     """Represents an AI model with its metadata.
+
+     >>> model = Model(
+     ...     id="gpt-4",
+     ...     provider="openai",
+     ...     context_size=8192,
+     ...     is_local=False
+     ... )
+     >>> model.id
+     'gpt-4'
+     """
+
+     id: str
+     provider: str
+     context_size: int | None = None
+     is_local: bool = False
+     capabilities: dict[str, Any] = field(default_factory=dict)
+     cost_per_token: dict[str, float] = field(default_factory=dict)
+     tags: set[str] = field(default_factory=set)
+     connector_metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
+     custom_metadata: dict[str, Any] = field(default_factory=dict)
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert model to dictionary representation."""
+         return asdict(self)
+
+     def matches_filter(self, **criteria) -> bool:
+         """Check if model matches given criteria.
+
+         >>> model = Model(id="gpt-4", provider="openai", is_local=False)
+         >>> model.matches_filter(provider="openai")
+         True
+         >>> model.matches_filter(is_local=True)
+         False
+         """
+         for key, value in criteria.items():
+             if not hasattr(self, key):
+                 return False
+             if getattr(self, key) != value:
+                 return False
+         return True
+
+     def __getitem__(self, key: str) -> Any:
+         """Get field value by name."""
+         try:
+             return getattr(self, key)
+         except AttributeError:
+             raise KeyError(key)
+
+     def __iter__(self) -> Iterator[str]:
+         """Iterate over field names."""
+         return iter(field.name for field in fields(self))
+
+     def __len__(self) -> int:
+         """Return number of fields."""
+         return len(fields(self))
+
+
+ class ModelRegistry(MutableMapping[str, Model]):
+     """Registry for managing AI models using Mapping interface.
+
+     >>> registry = ModelRegistry()
+     >>> registry["gpt-4"] = Model(id="gpt-4", provider="openai")
+     >>> "gpt-4" in registry
+     True
+     >>> len(registry)
+     1
+     """
+
+     def __init__(self, *, storage_path: Path | None = None):
+         """Initialize registry with optional persistent storage."""
+         self._models: dict[str, Model] = {}
+         self._storage_path = storage_path
+         if storage_path and storage_path.exists():
+             self._load()
+
+     def __setitem__(self, model_id: str, model: Model) -> None:
+         """Add or update a model in the registry."""
+         self._models[model_id] = model
+         if self._storage_path:
+             self._save()
+
+     def __getitem__(
+         self, key: str | list[str] | Callable[[Model], bool]
+     ) -> Model | list[Model]:
+         """Get model(s) by ID, list of IDs, or filter function.
+
+         Supports:
+         - Single ID: registry["gpt-4"]
+         - Multiple IDs: registry[["gpt-4", "claude-3"]]
+         - Filter function: registry[lambda m: m.is_local]
+         """
+         if isinstance(key, str):
+             return self._models[key]
+         elif isinstance(key, list):
+             return [self._models[k] for k in key if k in self._models]
+         elif callable(key):
+             return [m for m in self._models.values() if key(m)]
+         else:
+             raise TypeError(f"Unsupported key type: {type(key)}")
+
+     def __delitem__(self, model_id: str) -> None:
+         """Remove a model from the registry."""
+         del self._models[model_id]
+         if self._storage_path:
+             self._save()
+
+     def __iter__(self) -> Iterator[str]:
+         """Iterate over model IDs."""
+         yield from self._models.keys()
+
+     def __len__(self) -> int:
+         """Return number of models in registry."""
+         return len(self._models)
+
+     def filter(
+         self,
+         *,
+         provider: str | None = None,
+         is_local: bool | None = None,
+         min_context_size: int | None = None,
+         max_context_size: int | None = None,
+         has_capabilities: Iterable[str] | None = None,
+         tags: Iterable[str] | None = None,
+         custom_filter: Callable[[Model], bool] | None = None,
+     ) -> list[Model]:
+         """Filter models by multiple criteria.
+
+         >>> registry = ModelRegistry()
+         >>> registry["gpt-4"] = Model(id="gpt-4", provider="openai", context_size=8192)
+         >>> registry["llama2"] = Model(id="llama2", provider="ollama", is_local=True)
+         >>> local_models = registry.filter(is_local=True)
+         >>> len(local_models)
+         1
+         """
+
+         def _matches(model: Model) -> bool:
+             if provider and model.provider != provider:
+                 return False
+             if is_local is not None and model.is_local != is_local:
+                 return False
+             if min_context_size and (
+                 not model.context_size or model.context_size < min_context_size
+             ):
+                 return False
+             if max_context_size and (
+                 not model.context_size or model.context_size > max_context_size
+             ):
+                 return False
+             if has_capabilities:
+                 for cap in has_capabilities:
+                     if not model.capabilities.get(cap):
+                         return False
+             if tags and not model.tags.issuperset(tags):
+                 return False
+             if custom_filter and not custom_filter(model):
+                 return False
+             return True
+
+         return [m for m in self._models.values() if _matches(m)]
+
+     def _load(self) -> None:
+         """Load models from persistent storage."""
+         if not self._storage_path:
+             return
+
+         with open(self._storage_path) as f:
+             data = json.load(f)
+             for model_data in data["models"]:
+                 # Reconstruct set for tags
+                 model_data["tags"] = set(model_data.get("tags", []))
+                 model = Model(**model_data)
+                 self._models[model.id] = model
+
+     def _save(self) -> None:
+         """Save models to persistent storage."""
+         if not self._storage_path:
+             return
+
+         models_data = []
+         for model in self._models.values():
+             model_dict = model.to_dict()
+             # Convert set to list for JSON serialization
+             model_dict["tags"] = list(model_dict["tags"])
+             models_data.append(model_dict)
+
+         self._storage_path.parent.mkdir(parents=True, exist_ok=True)
+         with open(self._storage_path, "w") as f:
+             json.dump({"models": models_data}, f, indent=2)
+
+
+ class ModelSource(ABC):
+     """Abstract base for model discovery sources."""
+
+     @abstractmethod
+     def discover_models(self) -> Iterable[Model]:
+         """Discover available models from this source.
+
+         Yields Model instances for each discovered model.
+         """
+         pass
+
+
+ class Connector(ABC):
+     """Abstract base for model connectors/clients."""
+
+     @abstractmethod
+     def format_metadata(self, model: Model) -> dict[str, Any]:
+         """Format model metadata for this connector.
+
+         Returns a dict that can be used to instantiate/connect via this connector.
+         """
+         pass
+
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         """Unique identifier for this connector."""
+         pass
+
+
+ class ConnectorRegistry(MutableMapping[str, Connector]):
+     """Registry for managing model connectors.
+
+     >>> registry = ConnectorRegistry()
+     >>> class MyConnector(Connector):
+     ...     @property
+     ...     def name(self) -> str:
+     ...         return "my_connector"
+     ...     def format_metadata(self, model: Model) -> dict[str, Any]:
+     ...         return {"model": model.id}
+     >>> connector = MyConnector()
+     >>> registry[connector.name] = connector
+     >>> "my_connector" in registry
+     True
+     """
+
+     def __init__(self):
+         self._connectors: dict[str, Connector] = {}
+
+     def __setitem__(self, name: str, connector: Connector) -> None:
+         self._connectors[name] = connector
+
+     def __getitem__(self, name: str) -> Connector:
+         return self._connectors[name]
+
+     def __delitem__(self, name: str) -> None:
+         del self._connectors[name]
+
+     def __iter__(self) -> Iterator[str]:
+         yield from self._connectors.keys()
+
+     def __len__(self) -> int:
+         return len(self._connectors)
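
ModelRegistry is the piece worth calling out: a MutableMapping whose __getitem__ accepts a single ID, a list of IDs, or a predicate, with optional JSON persistence. A self-contained sketch against the code above (the storage path is illustrative):

    from pathlib import Path
    from aix.ai_models.base import Model, ModelRegistry

    # Persisted registry: every __setitem__/__delitem__ rewrites the JSON file.
    registry = ModelRegistry(storage_path=Path("/tmp/aix_models.json"))
    registry["gpt-4"] = Model(id="gpt-4", provider="openai", context_size=8192)
    registry["llama2"] = Model(id="llama2", provider="ollama", is_local=True)

    one = registry["gpt-4"]                 # single ID -> Model
    some = registry[["gpt-4", "llama2"]]    # list of IDs -> list[Model]
    local = registry[lambda m: m.is_local]  # predicate -> list[Model]

    # Keyword filtering, per ModelRegistry.filter:
    big = registry.filter(provider="openai", min_context_size=4096)
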