aix 0.0.23__py3-none-any.whl → 0.0.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aix/__init__.py +149 -58
- aix/ai_models/__init__.py +136 -0
- aix/ai_models/base.py +270 -0
- aix/ai_models/examples.py +326 -0
- aix/ai_models/manager.py +282 -0
- aix/ai_models/sources.py +310 -0
- aix/audio.py +499 -0
- aix/batches.py +443 -0
- aix/chat.py +321 -0
- aix/contexts.py +1 -1101
- aix/embeddings.py +325 -0
- aix/image.py +493 -0
- aix/models.py +435 -0
- aix/prompts.py +710 -0
- aix/util.py +17 -2
- aix/video.py +455 -0
- aix-0.0.25.dist-info/METADATA +579 -0
- aix-0.0.25.dist-info/RECORD +26 -0
- {aix-0.0.23.dist-info → aix-0.0.25.dist-info}/WHEEL +1 -1
- aix/np.py +0 -9
- aix/pd.py +0 -9
- aix/sk.py +0 -17
- aix-0.0.23.dist-info/METADATA +0 -220
- aix-0.0.23.dist-info/RECORD +0 -16
- {aix-0.0.23.dist-info → aix-0.0.25.dist-info/licenses}/LICENSE +0 -0
- {aix-0.0.23.dist-info → aix-0.0.25.dist-info}/top_level.txt +0 -0
aix/__init__.py
CHANGED
@@ -1,68 +1,159 @@
-"""
-Facade to key AI tools.
+"""AIX: Artificial Intelligence eXtensions

-
-
+A clean, pythonic facade for common AI operations that abstracts away
+provider-specific details and complexities.

-
->>>
-['gemini-1.5-flash',
- 'gpt-4',
- 'gpt-4-32k',
- 'gpt-4-turbo',
- 'gpt-3.5-turbo',
- 'o1-preview',
- 'o1-mini',
- 'gpt-4o',
- 'gpt-4o-mini']
+Quick Start:
+>>> from aix import chat, embeddings, prompt_func, models

-
->>>
-
-'42'
->>> openai_chat = chat_funcs['gpt-3.5-turbo']  # doctest: +SKIP
->>> openai_chat("What is the meaning of life? Respond with a number.")  # doctest: +SKIP
-'42'
+# Simple chat
+>>> response = chat("What is 2+2?")  # doctest: +SKIP
+'The answer is 4.'

-
+# Create prompt-based functions
+>>> translate = prompt_func("Translate to French: {text}")
+>>> translate(text="Hello world")  # doctest: +SKIP
+'Bonjour le monde'

-
-
-
-
-
-
-)
+# Get embeddings
+>>> vecs = list(embeddings(["hello", "world"]))  # doctest: +SKIP
+>>> len(vecs)  # doctest: +SKIP
+2
+
+# Discover models
+>>> models.discover()  # doctest: +SKIP
+>>> list(models)[:5]  # doctest: +SKIP
+['openai/gpt-4o', 'openai/gpt-4o-mini', ...]

-
-
+Main Features:
+- chat(): Simple chat interface across providers
+- embeddings(): Vector embeddings for text
+- prompt_func(): Create functions from prompt templates
+- models: Model discovery and selection
+- generate_image(): Text-to-image generation
+- text_to_speech(), transcribe(): Audio operations
+- generate_video(): Text-to-video generation (provider-dependent)
+- Batch operations for efficiency
+- Clean, i2mint-style Mapping interfaces
+
+Backends:
+- Uses LiteLLM for provider interactions
+- Supports OpenAI, Anthropic, Google, and 100+ models
+- OpenRouter integration for multi-provider access
+
+For detailed documentation, see: https://github.com/thorwhalen/aix
+"""

-#
-
-
+# Core interfaces (new clean API)
+from aix.chat import chat, ask, chat_with_history, ChatSession
+from aix.embeddings import (
+    embeddings,
+    embed,
+    cosine_similarity,
+    find_most_similar,
+    EmbeddingCache,
+)
+from aix.prompts import (
+    prompt_func,
+    prompt_to_text,
+    prompt_to_json,
+    PromptFuncs,
+    common_funcs,
+    constrained_answer,
+)
+from aix.models import (
+    models,
+    ModelStore,
+    discover_available_models,
+    get_model_info,
+    find_models,
+)
+from aix.batches import (
+    batch_chat,
+    batch_embeddings,
+    batch_process,
+    BatchProcessor,
+)
+from aix.image import (
+    generate_image,
+    generate_images,
+    edit_image,
+    create_variation,
+    GeneratedImage,
+)
+from aix.audio import (
+    text_to_speech,
+    transcribe,
+    transcribe_with_timestamps,
+    translate_audio,
+    GeneratedAudio,
+    TranscriptionResult,
+)
+from aix.video import (
+    generate_video,
+    animate_image as animate_image_to_video,
+    extend_video,
+    GeneratedVideo,
+    get_available_providers as get_video_providers,
+)

-#
-
-# from aix import sk
+# Legacy interfaces (for backward compatibility)
+from aix.gen_ai import chat_models, chat_funcs

+# Version info
+__version__ = "0.1.0"

-#
-
-#
-
-
-
-#
-
-
-
-
-#
-
-
-
-
-
+# Public API
+__all__ = [
+    # Core chat
+    "chat",
+    "ask",
+    "chat_with_history",
+    "ChatSession",
+    # Embeddings
+    "embeddings",
+    "embed",
+    "cosine_similarity",
+    "find_most_similar",
+    "EmbeddingCache",
+    # Prompts
+    "prompt_func",
+    "prompt_to_text",
+    "prompt_to_json",
+    "PromptFuncs",
+    "common_funcs",
+    "constrained_answer",
+    # Models
+    "models",
+    "ModelStore",
+    "discover_available_models",
+    "get_model_info",
+    "find_models",
+    # Batches
+    "batch_chat",
+    "batch_embeddings",
+    "batch_process",
+    "BatchProcessor",
+    # Image
+    "generate_image",
+    "generate_images",
+    "edit_image",
+    "create_variation",
+    "GeneratedImage",
+    # Audio
+    "text_to_speech",
+    "transcribe",
+    "transcribe_with_timestamps",
+    "translate_audio",
+    "GeneratedAudio",
+    "TranscriptionResult",
+    # Video
+    "generate_video",
+    "animate_image_to_video",
+    "extend_video",
+    "GeneratedVideo",
+    "get_video_providers",
+    # Legacy
+    "chat_models",
+    "chat_funcs",
+]
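The hunk above swaps the old per-model chat_funcs lookup for a provider-agnostic facade while keeping the legacy names importable. A minimal usage sketch, based only on the docstrings shown in the diff and assuming provider API keys are configured in the environment:

from aix import chat, prompt_func, chat_funcs

# Legacy route (still importable for backward compatibility):
openai_chat = chat_funcs['gpt-3.5-turbo']
openai_chat("What is the meaning of life? Respond with a number.")

# New facade: one chat() entry point plus prompt-template functions.
chat("What is 2+2?")
translate = prompt_func("Translate to French: {text}")
translate(text="Hello world")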
aix/ai_models/__init__.py
ADDED
@@ -0,0 +1,136 @@
+"""AI Model Management Module.
+
+A unified interface for managing AI models across multiple providers.
+
+Basic usage:
+>>> from aix.ai_models import get_manager
+>>> manager = get_manager()
+>>> _ = manager.discover_from_source("openrouter", auto_register=True, verbose=False)
+>>> models = manager.list_models(provider="openai")
+
+Custom filtering:
+>>> cheap_models = manager.list_models(
+...     custom_filter=lambda m: m.cost_per_token.get("input", 0) < 0.001
+... )
+
+Get connector-specific metadata:
+>>> openai_meta = manager.get_connector_metadata("openai/gpt-4", "openai")
+>>> # Use with: openai.ChatCompletion.create(**openai_meta, messages=[...])
+
+"""
+
+# Core types
+from aix.ai_models.base import (
+    Model,
+    ModelRegistry,
+    ModelSource,
+    Connector,
+    ConnectorRegistry,
+)
+
+# Concrete sources and connectors
+from aix.ai_models.sources import (
+    Connector,
+    OpenRouterSource,
+    OllamaSource,
+    ProviderAPISource,
+    OpenAIConnector,
+    OpenRouterConnector,
+    LangChainConnector,
+    OllamaConnector,
+    DSPyConnector,
+)
+
+# Main facade
+from aix.ai_models.manager import (
+    ModelManager,
+    get_manager,
+)
+
+# Version info
+__version__ = "0.1.0"
+
+# Public API
+__all__ = [
+    # Core types
+    "Model",
+    "ModelRegistry",
+    "ModelSource",
+    "Connector",
+    "ConnectorRegistry",
+    # Sources
+    "OpenRouterSource",
+    "OllamaSource",
+    "ProviderAPISource",
+    # Connectors
+    "OpenAIConnector",
+    "OpenRouterConnector",
+    "LangChainConnector",
+    "OllamaConnector",
+    "DSPyConnector",
+    # Main API
+    "ModelManager",
+    "get_manager",
+]
+
+
+# Convenience functions for common operations
+
+
+def list_available_models(
+    *,
+    provider: str | None = None,
+    is_local: bool | None = None,
+    storage_path: str | None = None,
+) -> list[Model]:
+    """Quick function to list available models without explicit manager.
+
+    >>> models = list_available_models(provider="openai")
+    >>> len(models) >= 0
+    True
+    """
+    manager = get_manager(storage_path=storage_path)
+    return manager.list_models(provider=provider, is_local=is_local)
+
+
+def discover_models(
+    source_name: str = "openrouter",
+    *,
+    storage_path: str | None = None,
+    auto_register: bool = True,
+    verbose: bool = True,
+) -> list[Model]:
+    """Quick function to discover models from a source.
+
+    >>> models = discover_models("openrouter", auto_register=False, verbose=False)
+    >>> len(models) > 0
+    True
+    """
+    manager = get_manager(storage_path=storage_path)
+    return manager.discover_from_source(
+        source_name, auto_register=auto_register, verbose=verbose
+    )
+
+
+def get_model_metadata(
+    model_id: str, connector_name: str, *, storage_path: str | None = None
+) -> dict:
+    """Quick function to get formatted metadata for a model.
+
+    >>> import tempfile, os, json
+    >>> if 'OPENROUTER_API_KEY' in os.environ:
+    ...     from aix.ai_models.manager import get_manager
+    ...     with tempfile.NamedTemporaryFile(mode='w+', suffix=".json", delete=False) as temp:
+    ...         storage_path = temp.name
+    ...         json.dump({'models': []}, temp)
+    ...     try:
+    ...         manager = get_manager(storage_path=storage_path)
+    ...         _ = manager.discover_from_source("openrouter", auto_register=True, verbose=False)
+    ...         assert 'openai/gpt-3.5-turbo' in manager.models
+    ...         metadata = get_model_metadata("openai/gpt-3.5-turbo", "openrouter", storage_path=storage_path)
+    ...         assert 'model' in metadata
+    ...     finally:
+    ...         os.remove(storage_path)
+    """
+    manager = get_manager(storage_path=storage_path)
+    return manager.get_connector_metadata(model_id, connector_name)
aix/ai_models/base.py
ADDED
@@ -0,0 +1,270 @@
+"""Core types for AI model management.
+
+This module provides a unified interface for managing, discovering, and
+connecting to AI models across multiple providers and deployment methods.
+"""
+
+from dataclasses import dataclass, field, asdict, fields
+from typing import Any
+from collections.abc import Mapping, MutableMapping, Iterator, Callable, Iterable
+from abc import ABC, abstractmethod
+import json
+from pathlib import Path
+
+
+@dataclass
+class Model:
+    """Represents an AI model with its metadata.
+
+    >>> model = Model(
+    ...     id="gpt-4",
+    ...     provider="openai",
+    ...     context_size=8192,
+    ...     is_local=False
+    ... )
+    >>> model.id
+    'gpt-4'
+    """
+
+    id: str
+    provider: str
+    context_size: int | None = None
+    is_local: bool = False
+    capabilities: dict[str, Any] = field(default_factory=dict)
+    cost_per_token: dict[str, float] = field(default_factory=dict)
+    tags: set[str] = field(default_factory=set)
+    connector_metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
+    custom_metadata: dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert model to dictionary representation."""
+        return asdict(self)
+
+    def matches_filter(self, **criteria) -> bool:
+        """Check if model matches given criteria.
+
+        >>> model = Model(id="gpt-4", provider="openai", is_local=False)
+        >>> model.matches_filter(provider="openai")
+        True
+        >>> model.matches_filter(is_local=True)
+        False
+        """
+        for key, value in criteria.items():
+            if not hasattr(self, key):
+                return False
+            if getattr(self, key) != value:
+                return False
+        return True
+
+    def __getitem__(self, key: str) -> Any:
+        """Get field value by name."""
+        try:
+            return getattr(self, key)
+        except AttributeError:
+            raise KeyError(key)
+
+    def __iter__(self) -> Iterator[str]:
+        """Iterate over field names."""
+        return iter(field.name for field in fields(self))
+
+    def __len__(self) -> int:
+        """Return number of fields."""
+        return len(fields(self))
+
+
+class ModelRegistry(MutableMapping[str, Model]):
+    """Registry for managing AI models using Mapping interface.
+
+    >>> registry = ModelRegistry()
+    >>> registry["gpt-4"] = Model(id="gpt-4", provider="openai")
+    >>> "gpt-4" in registry
+    True
+    >>> len(registry)
+    1
+    """
+
+    def __init__(self, *, storage_path: Path | None = None):
+        """Initialize registry with optional persistent storage."""
+        self._models: dict[str, Model] = {}
+        self._storage_path = storage_path
+        if storage_path and storage_path.exists():
+            self._load()
+
+    def __setitem__(self, model_id: str, model: Model) -> None:
+        """Add or update a model in the registry."""
+        self._models[model_id] = model
+        if self._storage_path:
+            self._save()
+
+    def __getitem__(
+        self, key: str | list[str] | Callable[[Model], bool]
+    ) -> Model | list[Model]:
+        """Get model(s) by ID, list of IDs, or filter function.
+
+        Supports:
+        - Single ID: registry["gpt-4"]
+        - Multiple IDs: registry[["gpt-4", "claude-3"]]
+        - Filter function: registry[lambda m: m.is_local]
+        """
+        if isinstance(key, str):
+            return self._models[key]
+        elif isinstance(key, list):
+            return [self._models[k] for k in key if k in self._models]
+        elif callable(key):
+            return [m for m in self._models.values() if key(m)]
+        else:
+            raise TypeError(f"Unsupported key type: {type(key)}")
+
+    def __delitem__(self, model_id: str) -> None:
+        """Remove a model from the registry."""
+        del self._models[model_id]
+        if self._storage_path:
+            self._save()
+
+    def __iter__(self) -> Iterator[str]:
+        """Iterate over model IDs."""
+        yield from self._models.keys()
+
+    def __len__(self) -> int:
+        """Return number of models in registry."""
+        return len(self._models)
+
+    def filter(
+        self,
+        *,
+        provider: str | None = None,
+        is_local: bool | None = None,
+        min_context_size: int | None = None,
+        max_context_size: int | None = None,
+        has_capabilities: Iterable[str] | None = None,
+        tags: Iterable[str] | None = None,
+        custom_filter: Callable[[Model], bool] | None = None,
+    ) -> list[Model]:
+        """Filter models by multiple criteria.
+
+        >>> registry = ModelRegistry()
+        >>> registry["gpt-4"] = Model(id="gpt-4", provider="openai", context_size=8192)
+        >>> registry["llama2"] = Model(id="llama2", provider="ollama", is_local=True)
+        >>> local_models = registry.filter(is_local=True)
+        >>> len(local_models)
+        1
+        """
+
+        def _matches(model: Model) -> bool:
+            if provider and model.provider != provider:
+                return False
+            if is_local is not None and model.is_local != is_local:
+                return False
+            if min_context_size and (
+                not model.context_size or model.context_size < min_context_size
+            ):
+                return False
+            if max_context_size and (
+                not model.context_size or model.context_size > max_context_size
+            ):
+                return False
+            if has_capabilities:
+                for cap in has_capabilities:
+                    if not model.capabilities.get(cap):
+                        return False
+            if tags and not model.tags.issuperset(tags):
+                return False
+            if custom_filter and not custom_filter(model):
+                return False
+            return True
+
+        return [m for m in self._models.values() if _matches(m)]
+
+    def _load(self) -> None:
+        """Load models from persistent storage."""
+        if not self._storage_path:
+            return
+
+        with open(self._storage_path) as f:
+            data = json.load(f)
+        for model_data in data["models"]:
+            # Reconstruct set for tags
+            model_data["tags"] = set(model_data.get("tags", []))
+            model = Model(**model_data)
+            self._models[model.id] = model
+
+    def _save(self) -> None:
+        """Save models to persistent storage."""
+        if not self._storage_path:
+            return
+
+        models_data = []
+        for model in self._models.values():
+            model_dict = model.to_dict()
+            # Convert set to list for JSON serialization
+            model_dict["tags"] = list(model_dict["tags"])
+            models_data.append(model_dict)
+
+        self._storage_path.parent.mkdir(parents=True, exist_ok=True)
+        with open(self._storage_path, "w") as f:
+            json.dump({"models": models_data}, f, indent=2)
+
+
+class ModelSource(ABC):
+    """Abstract base for model discovery sources."""
+
+    @abstractmethod
+    def discover_models(self) -> Iterable[Model]:
+        """Discover available models from this source.
+
+        Yields Model instances for each discovered model.
+        """
+        pass
+
+
+class Connector(ABC):
+    """Abstract base for model connectors/clients."""
+
+    @abstractmethod
+    def format_metadata(self, model: Model) -> dict[str, Any]:
+        """Format model metadata for this connector.
+
+        Returns a dict that can be used to instantiate/connect via this connector.
+        """
+        pass
+
+    @property
+    @abstractmethod
+    def name(self) -> str:
+        """Unique identifier for this connector."""
+        pass
+
+
+class ConnectorRegistry(MutableMapping[str, Connector]):
+    """Registry for managing model connectors.
+
+    >>> registry = ConnectorRegistry()
+    >>> class MyConnector(Connector):
+    ...     @property
+    ...     def name(self) -> str:
+    ...         return "my_connector"
+    ...     def format_metadata(self, model: Model) -> dict[str, Any]:
+    ...         return {"model": model.id}
+    >>> connector = MyConnector()
+    >>> registry[connector.name] = connector
+    >>> "my_connector" in registry
+    True
+    """
+
+    def __init__(self):
+        self._connectors: dict[str, Connector] = {}
+
+    def __setitem__(self, name: str, connector: Connector) -> None:
+        self._connectors[name] = connector
+
+    def __getitem__(self, name: str) -> Connector:
+        return self._connectors[name]
+
+    def __delitem__(self, name: str) -> None:
+        del self._connectors[name]
+
+    def __iter__(self) -> Iterator[str]:
+        yield from self._connectors.keys()
+
+    def __len__(self) -> int:
+        return len(self._connectors)