vector-inspector 0.2.7__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff shows the changes between two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the packages' contents as they appear in their public registry.
- vector_inspector/config/__init__.py +4 -0
- vector_inspector/config/known_embedding_models.json +432 -0
- vector_inspector/core/connections/__init__.py +2 -1
- vector_inspector/core/connections/base_connection.py +42 -1
- vector_inspector/core/connections/chroma_connection.py +47 -11
- vector_inspector/core/connections/pinecone_connection.py +768 -0
- vector_inspector/core/embedding_providers/__init__.py +14 -0
- vector_inspector/core/embedding_providers/base_provider.py +128 -0
- vector_inspector/core/embedding_providers/clip_provider.py +260 -0
- vector_inspector/core/embedding_providers/provider_factory.py +176 -0
- vector_inspector/core/embedding_providers/sentence_transformer_provider.py +203 -0
- vector_inspector/core/embedding_utils.py +69 -42
- vector_inspector/core/model_registry.py +205 -0
- vector_inspector/services/backup_restore_service.py +16 -0
- vector_inspector/services/settings_service.py +117 -1
- vector_inspector/ui/components/connection_manager_panel.py +7 -0
- vector_inspector/ui/components/profile_manager_panel.py +61 -14
- vector_inspector/ui/dialogs/__init__.py +2 -1
- vector_inspector/ui/dialogs/cross_db_migration.py +20 -1
- vector_inspector/ui/dialogs/embedding_config_dialog.py +166 -27
- vector_inspector/ui/dialogs/provider_type_dialog.py +189 -0
- vector_inspector/ui/main_window.py +33 -2
- vector_inspector/ui/views/connection_view.py +55 -10
- vector_inspector/ui/views/info_panel.py +83 -36
- vector_inspector/ui/views/search_view.py +1 -1
- vector_inspector/ui/views/visualization_view.py +20 -6
- {vector_inspector-0.2.7.dist-info → vector_inspector-0.3.2.dist-info}/METADATA +7 -2
- vector_inspector-0.3.2.dist-info/RECORD +55 -0
- vector_inspector-0.2.7.dist-info/RECORD +0 -45
- {vector_inspector-0.2.7.dist-info → vector_inspector-0.3.2.dist-info}/WHEEL +0 -0
- {vector_inspector-0.2.7.dist-info → vector_inspector-0.3.2.dist-info}/entry_points.txt +0 -0
vector_inspector/core/embedding_providers/__init__.py (new file)
@@ -0,0 +1,14 @@
+"""Embedding provider system for loading and managing embedding models."""
+
+from .base_provider import EmbeddingProvider, EmbeddingMetadata
+from .sentence_transformer_provider import SentenceTransformerProvider
+from .clip_provider import CLIPProvider
+from .provider_factory import ProviderFactory
+
+__all__ = [
+    'EmbeddingProvider',
+    'EmbeddingMetadata',
+    'SentenceTransformerProvider',
+    'CLIPProvider',
+    'ProviderFactory',
+]
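The `__all__` list above defines the package's new public surface; a downstream import (a sketch using only the names exported here) would look like:

```python
# Sketch: consuming the public API added in 0.3.2.
from vector_inspector.core.embedding_providers import (
    EmbeddingProvider,   # abstract base class
    EmbeddingMetadata,   # model metadata dataclass
    ProviderFactory,     # type-dispatching factory
)
```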
vector_inspector/core/embedding_providers/base_provider.py (new file)
@@ -0,0 +1,128 @@
+"""Base interface for embedding providers."""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import List, Union, Optional, Any
+from enum import Enum
+import numpy as np
+
+
+class Modality(Enum):
+    """Embedding modality types."""
+    TEXT = "text"
+    IMAGE = "image"
+    MULTIMODAL = "multimodal"
+
+
+class Normalization(Enum):
+    """Embedding normalization types."""
+    NONE = "none"
+    L2 = "l2"
+
+
+@dataclass
+class EmbeddingMetadata:
+    """Metadata about an embedding model."""
+    name: str  # Model identifier (e.g., "all-MiniLM-L6-v2")
+    dimension: int  # Vector dimension
+    modality: Modality  # text, image, or multimodal
+    normalization: Normalization  # none or l2
+    model_type: str  # sentence-transformer, clip, openai, etc.
+    source: str = "unknown"  # hf, local, custom, cloud
+    version: Optional[str] = None  # Model version if available
+    max_sequence_length: Optional[int] = None  # Maximum input length
+    description: Optional[str] = None  # Human-readable description
+
+
+class EmbeddingProvider(ABC):
+    """Abstract base class for embedding providers.
+
+    Providers handle loading, encoding, and metadata extraction for embedding models.
+    They implement lazy-loading to avoid UI freezes when working with large models.
+    """
+
+    def __init__(self, model_name: str):
+        """Initialize provider with a model name.
+
+        Args:
+            model_name: Model identifier (HuggingFace ID, path, or API model name)
+        """
+        self.model_name = model_name
+        self._model = None  # Lazy-loaded model instance
+        self._is_loaded = False
+
+    @abstractmethod
+    def get_metadata(self) -> EmbeddingMetadata:
+        """Get metadata about the embedding model.
+
+        This should be fast and not require loading the full model if possible.
+
+        Returns:
+            EmbeddingMetadata with model information
+        """
+        pass
+
+    @abstractmethod
+    def encode(
+        self,
+        inputs: Union[str, List[str], Any],
+        normalize: bool = True,
+        show_progress: bool = False
+    ) -> np.ndarray:
+        """Encode inputs into embeddings.
+
+        Args:
+            inputs: Text strings, images, or other inputs depending on modality
+            normalize: Whether to L2-normalize the embeddings
+            show_progress: Whether to show progress bar for batch encoding
+
+        Returns:
+            numpy array of embeddings, shape (n_inputs, dimension)
+        """
+        pass
+
+    def warmup(self, progress_callback=None):
+        """Load and initialize the model (warm up for faster subsequent calls).
+
+        Args:
+            progress_callback: Optional callback(message: str, progress: float) for UI updates
+        """
+        if self._is_loaded:
+            return
+
+        if progress_callback:
+            progress_callback(f"Loading {self.model_name}...", 0.0)
+
+        self._load_model()
+        self._is_loaded = True
+
+        if progress_callback:
+            progress_callback(f"Model {self.model_name} loaded", 1.0)
+
+    @abstractmethod
+    def _load_model(self):
+        """Internal method to load the actual model. Override in subclasses."""
+        pass
+
+    def close(self):
+        """Release model resources and cleanup."""
+        if self._model is not None:
+            # Try to free memory
+            del self._model
+            self._model = None
+        self._is_loaded = False
+
+    @property
+    def is_loaded(self) -> bool:
+        """Check if model is currently loaded in memory."""
+        return self._is_loaded
+
+    def __enter__(self):
+        """Context manager support."""
+        self.warmup()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager cleanup."""
+        self.close()
+        return False
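To illustrate the contract this base class defines, here is a minimal concrete subclass. `DummyProvider`, its fixed dimension, and the hash-seeded vectors are hypothetical illustrations, not part of the package:

```python
# Minimal sketch of a concrete provider built on the base class above.
import numpy as np

from vector_inspector.core.embedding_providers.base_provider import (
    EmbeddingProvider,
    EmbeddingMetadata,
    Modality,
    Normalization,
)


class DummyProvider(EmbeddingProvider):
    """Returns deterministic pseudo-embeddings; useful for tests (hypothetical)."""

    DIM = 8  # hypothetical fixed dimension

    def get_metadata(self) -> EmbeddingMetadata:
        # Cheap: no model load required, as the base class docstring asks.
        return EmbeddingMetadata(
            name=self.model_name,
            dimension=self.DIM,
            modality=Modality.TEXT,
            normalization=Normalization.L2,
            model_type="dummy",
        )

    def _load_model(self):
        self._model = object()  # stand-in for a real model handle

    def encode(self, inputs, normalize=True, show_progress=False) -> np.ndarray:
        texts = [inputs] if isinstance(inputs, str) else list(inputs)
        # Deterministic per-text vectors, seeded from the text's hash.
        vecs = np.stack([
            np.random.default_rng(abs(hash(t)) % 2**32).standard_normal(self.DIM)
            for t in texts
        ])
        if normalize:
            vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)
        return vecs


# Lazy loading and context-manager behavior come from the base class:
# warmup() on __enter__, close() on __exit__.
with DummyProvider("dummy-model") as p:
    print(p.is_loaded)              # True
    print(p.encode("hello").shape)  # (1, 8)
```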
vector_inspector/core/embedding_providers/clip_provider.py (new file)
@@ -0,0 +1,260 @@
+"""CLIP embedding provider for multimodal (text + image) embeddings."""
+
+from typing import List, Union, Optional, Any
+import numpy as np
+from pathlib import Path
+
+from .base_provider import (
+    EmbeddingProvider,
+    EmbeddingMetadata,
+    Modality,
+    Normalization
+)
+
+
+class CLIPProvider(EmbeddingProvider):
+    """Provider for CLIP models supporting text and image embeddings.
+
+    Lazy-loads the transformers library and CLIP model on first use.
+    Supports OpenAI CLIP and LAION CLIP variants:
+    - openai/clip-vit-base-patch32
+    - openai/clip-vit-large-patch14
+    - laion/CLIP-ViT-B-32-laion2B-s34B-b79K
+    - And other CLIP-compatible models
+    """
+
+    def __init__(self, model_name: str):
+        """Initialize CLIP provider.
+
+        Args:
+            model_name: HuggingFace model ID (e.g., "openai/clip-vit-base-patch32")
+        """
+        super().__init__(model_name)
+        self._processor = None
+        self._metadata = None
+
+    def get_metadata(self) -> EmbeddingMetadata:
+        """Get metadata about the CLIP model."""
+        if self._metadata is not None:
+            return self._metadata
+
+        try:
+            from transformers import CLIPConfig
+
+            # Try to get config without loading full model
+            try:
+                config = CLIPConfig.from_pretrained(self.model_name)
+                dimension = config.projection_dim
+                max_length = config.text_config.max_position_embeddings
+            except Exception:
+                # Fallback dimensions for common CLIP models
+                dimension_map = {
+                    "openai/clip-vit-base-patch32": 512,
+                    "openai/clip-vit-base-patch16": 512,
+                    "openai/clip-vit-large-patch14": 768,
+                    "openai/clip-vit-large-patch14-336": 768,
+                }
+                dimension = dimension_map.get(self.model_name, 512)
+                max_length = 77  # Standard CLIP text length
+
+            self._metadata = EmbeddingMetadata(
+                name=self.model_name,
+                dimension=dimension,
+                modality=Modality.MULTIMODAL,
+                normalization=Normalization.L2,  # CLIP normalizes embeddings
+                model_type="clip",
+                source="huggingface",
+                max_sequence_length=max_length,
+                description=f"CLIP multimodal model: {self.model_name}"
+            )
+
+        except ImportError:
+            raise ImportError(
+                "transformers library not installed. "
+                "Install with: pip install transformers"
+            )
+        except Exception as e:
+            # Fallback metadata
+            self._metadata = EmbeddingMetadata(
+                name=self.model_name,
+                dimension=512,  # Common CLIP dimension
+                modality=Modality.MULTIMODAL,
+                normalization=Normalization.L2,
+                model_type="clip",
+                source="huggingface",
+                description=f"CLIP model: {self.model_name} (dimension not verified)"
+            )
+
+        return self._metadata
+
+    def _load_model(self):
+        """Load the CLIP model and processor."""
+        try:
+            from transformers import CLIPModel, CLIPProcessor
+        except ImportError:
+            raise ImportError(
+                "transformers library not installed. "
+                "Install with: pip install transformers"
+            )
+
+        self._model = CLIPModel.from_pretrained(self.model_name)
+        self._processor = CLIPProcessor.from_pretrained(self.model_name)
+
+        # Move to GPU if available
+        try:
+            import torch
+            if torch.cuda.is_available():
+                self._model = self._model.to('cuda')
+        except ImportError:
+            pass  # PyTorch not available, stay on CPU
+
+    def encode(
+        self,
+        inputs: Union[str, List[str], Any],
+        normalize: bool = True,
+        show_progress: bool = False,
+        input_type: str = "text"
+    ) -> np.ndarray:
+        """Encode text or images into embeddings.
+
+        Args:
+            inputs: Text strings, image paths, or PIL images
+            normalize: Whether to L2-normalize embeddings
+            show_progress: Whether to show progress (not implemented for CLIP)
+            input_type: "text" or "image"
+
+        Returns:
+            numpy array of embeddings
+        """
+        if not self._is_loaded:
+            self.warmup()
+
+        try:
+            import torch
+        except ImportError:
+            raise ImportError("PyTorch required for CLIP. Install with: pip install torch")
+
+        # Convert single input to list
+        if isinstance(inputs, str) or not isinstance(inputs, list):
+            inputs = [inputs]
+
+        if self._processor is None:
+            raise RuntimeError("Model not loaded. Call warmup() first.")
+
+        with torch.no_grad():
+            if input_type == "text":
+                # Process text
+                processed = self._processor(
+                    text=inputs,
+                    return_tensors="pt",
+                    padding=True,
+                    truncation=True
+                )
+
+                # Move to same device as model
+                if next(self._model.parameters()).is_cuda:
+                    processed = {k: v.cuda() for k, v in processed.items()}
+
+                # Get text embeddings
+                embeddings = self._model.get_text_features(**processed)
+
+            elif input_type == "image":
+                # Load images if they're paths
+                images = []
+                for inp in inputs:
+                    if isinstance(inp, (str, Path)):
+                        from PIL import Image
+                        images.append(Image.open(inp))
+                    else:
+                        images.append(inp)  # Assume already PIL Image
+
+                # Process images
+                processed = self._processor(
+                    images=images,
+                    return_tensors="pt"
+                )
+
+                # Move to same device as model
+                if next(self._model.parameters()).is_cuda:
+                    processed = {k: v.cuda() for k, v in processed.items()}
+
+                # Get image embeddings
+                embeddings = self._model.get_image_features(**processed)
+            else:
+                raise ValueError(f"Unknown input_type: {input_type}. Use 'text' or 'image'")
+
+        # Convert to numpy
+        embeddings = embeddings.cpu().numpy()
+
+        # Normalize if requested (CLIP typically normalizes, but we can ensure it)
+        if normalize:
+            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+            embeddings = embeddings / norms
+
+        return embeddings
+
+    def encode_text(
+        self,
+        texts: Union[str, List[str]],
+        normalize: bool = True
+    ) -> np.ndarray:
+        """Encode text inputs into embeddings.
+
+        Args:
+            texts: Single text or list of texts
+            normalize: Whether to L2-normalize
+
+        Returns:
+            numpy array of text embeddings
+        """
+        return self.encode(texts, normalize=normalize, input_type="text")
+
+    def encode_image(
+        self,
+        images: Union[str, Path, Any, List],
+        normalize: bool = True
+    ) -> np.ndarray:
+        """Encode images into embeddings.
+
+        Args:
+            images: Image path(s) or PIL Image(s)
+            normalize: Whether to L2-normalize
+
+        Returns:
+            numpy array of image embeddings
+        """
+        return self.encode(images, normalize=normalize, input_type="image")
+
+    def similarity(
+        self,
+        query: Union[str, np.ndarray],
+        corpus: List[str],
+        query_type: str = "text",
+        corpus_type: str = "text"
+    ) -> np.ndarray:
+        """Compute similarity between query and corpus (text or image).
+
+        Args:
+            query: Query string/image or embedding
+            corpus: List of corpus items (text or images)
+            query_type: "text" or "image"
+            corpus_type: "text" or "image"
+
+        Returns:
+            Similarity scores (cosine similarity)
+        """
+        if not self._is_loaded:
+            self.warmup()
+
+        # Get embeddings
+        if isinstance(query, np.ndarray):
+            query_emb = query
+        else:
+            query_emb = self.encode(query, normalize=True, input_type=query_type)
+
+        corpus_emb = self.encode(corpus, normalize=True, input_type=corpus_type)
+
+        # Compute cosine similarity (dot product if normalized)
+        similarities = np.dot(corpus_emb, query_emb.T).squeeze()
+
+        return similarities
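To show how the new provider is meant to be driven, here is a usage sketch. It assumes transformers, torch, and Pillow are installed and that the model weights can be downloaded; `photos/cat.jpg` is a hypothetical placeholder path:

```python
# Sketch: text-image similarity with the CLIPProvider added above.
from vector_inspector.core.embedding_providers import CLIPProvider

provider = CLIPProvider("openai/clip-vit-base-patch32")
provider.warmup()  # explicit load; encode() would also trigger it lazily

meta = provider.get_metadata()
print(meta.dimension)  # 512 for clip-vit-base-patch32

text_emb = provider.encode_text(["a photo of a cat", "a photo of a dog"])
img_emb = provider.encode_image("photos/cat.jpg")  # hypothetical path

# Embeddings are L2-normalized, so dot products are cosine similarities.
print(img_emb @ text_emb.T)  # the "cat" caption should score higher

# similarity() wraps the same computation for raw inputs:
scores = provider.similarity(
    "a photo of a cat", ["photos/cat.jpg"], corpus_type="image"
)

provider.close()  # release the model when done
```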
vector_inspector/core/embedding_providers/provider_factory.py (new file)
@@ -0,0 +1,176 @@
+"""Factory for creating embedding providers."""
+
+from typing import Optional, Dict, Type
+from .base_provider import EmbeddingProvider
+from .sentence_transformer_provider import SentenceTransformerProvider
+from .clip_provider import CLIPProvider
+from ..model_registry import get_model_registry
+
+
+class ProviderFactory:
+    """Factory for creating appropriate embedding providers based on model type."""
+
+    # Registry of provider classes by type
+    _PROVIDER_REGISTRY: Dict[str, Type[EmbeddingProvider]] = {
+        "sentence-transformer": SentenceTransformerProvider,
+        "clip": CLIPProvider,
+    }
+
+    # Model name patterns to auto-detect provider type
+    _MODEL_PATTERNS = {
+        "clip": ["clip", "CLIP"],
+        "sentence-transformer": [
+            "sentence-transformers/",
+            "all-MiniLM",
+            "all-mpnet",
+            "all-roberta",
+            "paraphrase-",
+            "multi-qa-",
+            "msmarco-",
+            "gtr-",
+            "bge-",
+            "gte-",
+            "e5-",
+            "jina-",
+            "nomic-",
+        ]
+    }
+
+    @classmethod
+    def create(
+        cls,
+        model_name: str,
+        model_type: Optional[str] = None,
+        **kwargs
+    ) -> EmbeddingProvider:
+        """Create an embedding provider for the given model.
+
+        Args:
+            model_name: Model identifier (HF ID, path, or API name)
+            model_type: Explicit provider type (sentence-transformer, clip, openai, etc.)
+                If None, will attempt auto-detection based on model name
+            **kwargs: Additional arguments passed to provider constructor
+
+        Returns:
+            Appropriate EmbeddingProvider instance
+
+        Raises:
+            ValueError: If model type is unknown or cannot be auto-detected
+        """
+        # Auto-detect provider type if not specified
+        if model_type is None:
+            model_type = cls._detect_provider_type(model_name)
+
+        # Normalize model type
+        model_type = model_type.lower()
+
+        # Get provider class from registry
+        provider_class = cls._PROVIDER_REGISTRY.get(model_type)
+
+        if provider_class is None:
+            # Check if it's a cloud provider (not yet implemented)
+            if model_type in ["openai", "cohere", "vertex-ai", "voyage"]:
+                raise NotImplementedError(
+                    f"Cloud provider '{model_type}' not yet implemented. "
+                    f"Currently supported: {', '.join(cls._PROVIDER_REGISTRY.keys())}"
+                )
+            else:
+                raise ValueError(
+                    f"Unknown provider type: {model_type}. "
+                    f"Supported types: {', '.join(cls._PROVIDER_REGISTRY.keys())}"
+                )
+
+        # Create and return provider instance
+        return provider_class(model_name, **kwargs)
+
+    @classmethod
+    def _detect_provider_type(cls, model_name: str) -> str:
+        """Auto-detect provider type based on model name patterns.
+
+        First checks the known model registry, then falls back to pattern matching.
+
+        Args:
+            model_name: Model identifier
+
+        Returns:
+            Detected provider type
+
+        Raises:
+            ValueError: If provider type cannot be detected
+        """
+        # First, check if model is in registry
+        registry = get_model_registry()
+        model_info = registry.get_model_by_name(model_name)
+        if model_info:
+            return model_info.type
+
+        # Fall back to pattern matching
+        model_name_lower = model_name.lower()
+
+        # Check each pattern category
+        for provider_type, patterns in cls._MODEL_PATTERNS.items():
+            for pattern in patterns:
+                if pattern.lower() in model_name_lower:
+                    return provider_type
+
+        # Default to sentence-transformer for HuggingFace models
+        if "/" in model_name and not model_name.startswith("http"):
+            return "sentence-transformer"
+
+        raise ValueError(
+            f"Cannot auto-detect provider type for model: {model_name}. "
+            "Please specify model_type explicitly."
+        )
+
+    @classmethod
+    def register_provider(cls, model_type: str, provider_class: Type[EmbeddingProvider]):
+        """Register a new provider type.
+
+        Args:
+            model_type: Provider type identifier
+            provider_class: Provider class (must inherit from EmbeddingProvider)
+        """
+        if not issubclass(provider_class, EmbeddingProvider):
+            raise TypeError(f"{provider_class} must inherit from EmbeddingProvider")
+
+        cls._PROVIDER_REGISTRY[model_type.lower()] = provider_class
+
+    @classmethod
+    def list_supported_types(cls) -> list:
+        """Get list of supported provider types.
+
+        Returns:
+            List of registered provider type names
+        """
+        return list(cls._PROVIDER_REGISTRY.keys())
+
+    @classmethod
+    def supports_type(cls, model_type: str) -> bool:
+        """Check if a provider type is supported.
+
+        Args:
+            model_type: Provider type to check
+
+        Returns:
+            True if supported, False otherwise
+        """
+        return model_type.lower() in cls._PROVIDER_REGISTRY
+
+
+# Convenience function for creating providers
+def create_provider(
+    model_name: str,
+    model_type: Optional[str] = None,
+    **kwargs
+) -> EmbeddingProvider:
+    """Create an embedding provider (convenience wrapper around ProviderFactory).
+
+    Args:
+        model_name: Model identifier
+        model_type: Optional explicit provider type
+        **kwargs: Additional arguments for provider
+
+    Returns:
+        EmbeddingProvider instance
+    """
+    return ProviderFactory.create(model_name, model_type, **kwargs)
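Finally, a sketch of the factory's intended use, based only on the code above (the commented-out registration reuses the hypothetical `DummyProvider` from the earlier sketch; the OpenAI model name is illustrative):

```python
# Sketch: creating providers through the factory added above.
from vector_inspector.core.embedding_providers import ProviderFactory

# Auto-detection: "all-MiniLM" is one of the sentence-transformer patterns.
st_provider = ProviderFactory.create("all-MiniLM-L6-v2")

# Explicit type skips detection entirely.
clip_provider = ProviderFactory.create(
    "openai/clip-vit-base-patch32", model_type="clip"
)

# Cloud types are recognized but deliberately unimplemented for now.
try:
    ProviderFactory.create("text-embedding-3-small", model_type="openai")
except NotImplementedError as exc:
    print(exc)

# Extension point: register a custom provider class
# (DummyProvider is the hypothetical class from the earlier sketch).
# ProviderFactory.register_provider("dummy", DummyProvider)

print(ProviderFactory.list_supported_types())  # ['sentence-transformer', 'clip']
```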