vector-inspector 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vector_inspector/config/__init__.py +4 -0
- vector_inspector/config/known_embedding_models.json +432 -0
- vector_inspector/core/cache_manager.py +159 -0
- vector_inspector/core/connection_manager.py +277 -0
- vector_inspector/core/connections/__init__.py +2 -1
- vector_inspector/core/connections/base_connection.py +42 -1
- vector_inspector/core/connections/chroma_connection.py +137 -16
- vector_inspector/core/connections/pinecone_connection.py +768 -0
- vector_inspector/core/connections/qdrant_connection.py +62 -8
- vector_inspector/core/embedding_providers/__init__.py +14 -0
- vector_inspector/core/embedding_providers/base_provider.py +128 -0
- vector_inspector/core/embedding_providers/clip_provider.py +260 -0
- vector_inspector/core/embedding_providers/provider_factory.py +176 -0
- vector_inspector/core/embedding_providers/sentence_transformer_provider.py +203 -0
- vector_inspector/core/embedding_utils.py +167 -0
- vector_inspector/core/model_registry.py +205 -0
- vector_inspector/services/backup_restore_service.py +19 -29
- vector_inspector/services/credential_service.py +130 -0
- vector_inspector/services/filter_service.py +1 -1
- vector_inspector/services/profile_service.py +409 -0
- vector_inspector/services/settings_service.py +136 -1
- vector_inspector/ui/components/connection_manager_panel.py +327 -0
- vector_inspector/ui/components/profile_manager_panel.py +565 -0
- vector_inspector/ui/dialogs/__init__.py +6 -0
- vector_inspector/ui/dialogs/cross_db_migration.py +383 -0
- vector_inspector/ui/dialogs/embedding_config_dialog.py +315 -0
- vector_inspector/ui/dialogs/provider_type_dialog.py +189 -0
- vector_inspector/ui/main_window.py +456 -190
- vector_inspector/ui/views/connection_view.py +55 -10
- vector_inspector/ui/views/info_panel.py +272 -55
- vector_inspector/ui/views/metadata_view.py +71 -3
- vector_inspector/ui/views/search_view.py +44 -4
- vector_inspector/ui/views/visualization_view.py +19 -5
- {vector_inspector-0.2.6.dist-info → vector_inspector-0.3.1.dist-info}/METADATA +3 -1
- vector_inspector-0.3.1.dist-info/RECORD +55 -0
- vector_inspector-0.2.6.dist-info/RECORD +0 -35
- {vector_inspector-0.2.6.dist-info → vector_inspector-0.3.1.dist-info}/WHEEL +0 -0
- {vector_inspector-0.2.6.dist-info → vector_inspector-0.3.1.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,277 @@
|
|
|
1
|
+
"""Connection manager for handling multiple vector database connections."""
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
from typing import Dict, Optional, List, Any
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from PySide6.QtCore import QObject, Signal
|
|
7
|
+
|
|
8
|
+
from .connections.base_connection import VectorDBConnection
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ConnectionState(Enum):
    """Lifecycle states tracked for each managed connection.

    Member values are the lowercase state names.
    """

    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    ERROR = "error"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ConnectionInstance:
    """One managed database connection together with its UI-facing context.

    Bundles the backend connection object with its identity, last known
    state/error, cached collection names and the currently selected
    collection. The ConnectionManager mutates these attributes.
    """

    def __init__(
        self,
        connection_id: str,
        name: str,
        provider: str,
        connection: VectorDBConnection,
        config: Dict[str, Any]
    ):
        """
        Build a connection record that starts in the DISCONNECTED state.

        Args:
            connection_id: Unique identifier assigned by the manager
            name: Human-readable label for the connection
            provider: Backend type string (chromadb, qdrant, etc.)
            connection: Backend connection object
            config: Configuration dict the connection was created from
        """
        # Identity and wiring (fixed for the lifetime of the record,
        # except `name`, which the manager can rename).
        self.id = connection_id
        self.name = name
        self.provider = provider
        self.connection = connection
        self.config = config

        # Runtime state, updated by the manager.
        self.state = ConnectionState.DISCONNECTED
        self.error_message: Optional[str] = None
        self.collections: List[str] = []
        self.active_collection: Optional[str] = None

    def get_display_name(self) -> str:
        """Return the label shown in lists: "<name> (<provider>)"."""
        return "{} ({})".format(self.name, self.provider)

    def get_breadcrumb(self) -> str:
        """Return "<name> > <collection>", or just the name when no collection is selected."""
        collection = self.active_collection
        if not collection:
            return self.name
        return f"{self.name} > {collection}"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ConnectionManager(QObject):
    """Manages multiple open vector database connections and tracks the active one.

    Each connection is wrapped in a ConnectionInstance keyed by a generated
    UUID string; at most MAX_CONNECTIONS may be registered at once.

    Signals:
        connection_opened: Emitted when a connection is marked as opened (connection_id)
        connection_closed: Emitted when a connection is closed (connection_id)
        connection_state_changed: Emitted when connection state changes (connection_id, state)
        active_connection_changed: Emitted when active connection changes (connection_id or None)
        active_collection_changed: Emitted when active collection changes (connection_id, collection_name or None)
        collections_updated: Emitted when collections list is updated (connection_id, collections)
    """

    # Signals
    connection_opened = Signal(str)  # connection_id
    connection_closed = Signal(str)  # connection_id
    connection_state_changed = Signal(str, ConnectionState)  # connection_id, state
    active_connection_changed = Signal(object)  # connection_id or None
    active_collection_changed = Signal(str, object)  # connection_id, collection_name or None
    collections_updated = Signal(str, list)  # connection_id, collections

    MAX_CONNECTIONS = 10  # Limit to prevent resource exhaustion

    def __init__(self):
        """Initialize the connection manager with no connections and none active."""
        super().__init__()
        self._connections: Dict[str, ConnectionInstance] = {}
        self._active_connection_id: Optional[str] = None

    def create_connection(
        self,
        name: str,
        provider: str,
        connection: VectorDBConnection,
        config: Dict[str, Any]
    ) -> str:
        """
        Create a new connection instance (not yet connected).

        If no connection is currently active, the new one becomes active and
        ``active_connection_changed`` is emitted. ``connection_opened`` is NOT
        emitted here; call mark_connection_opened() after a successful connect.

        Args:
            name: User-friendly connection name
            provider: Provider type
            connection: The connection object
            config: Connection configuration

        Returns:
            The connection ID

        Raises:
            RuntimeError: If maximum connections limit reached
        """
        if len(self._connections) >= self.MAX_CONNECTIONS:
            raise RuntimeError(f"Maximum number of connections ({self.MAX_CONNECTIONS}) reached")

        connection_id = str(uuid.uuid4())
        instance = ConnectionInstance(connection_id, name, provider, connection, config)
        self._connections[connection_id] = instance

        # Make it active when nothing else is (i.e. it is the first connection).
        # Keying on the active id rather than the dict size keeps this correct
        # even if a connection ever exists without being active.
        if self._active_connection_id is None:
            self._active_connection_id = connection_id
            self.active_connection_changed.emit(connection_id)

        # Don't emit connection_opened yet - wait until actually connected
        return connection_id

    def mark_connection_opened(self, connection_id: str):
        """
        Mark a connection as opened (after successful connection).

        Emits ``connection_opened`` if the ID is known; unknown IDs are ignored.

        Args:
            connection_id: ID of connection that opened
        """
        if connection_id in self._connections:
            self.connection_opened.emit(connection_id)

    def get_connection(self, connection_id: str) -> Optional[ConnectionInstance]:
        """Get a connection instance by ID, or None if unknown."""
        return self._connections.get(connection_id)

    def get_active_connection(self) -> Optional[ConnectionInstance]:
        """Get the currently active connection instance, or None."""
        if self._active_connection_id:
            return self._connections.get(self._active_connection_id)
        return None

    def get_active_connection_id(self) -> Optional[str]:
        """Get the currently active connection ID, or None."""
        return self._active_connection_id

    def set_active_connection(self, connection_id: str) -> bool:
        """
        Set the active connection and emit ``active_connection_changed``.

        Args:
            connection_id: ID of connection to make active

        Returns:
            True if successful, False if connection not found
        """
        if connection_id not in self._connections:
            return False

        self._active_connection_id = connection_id
        self.active_connection_changed.emit(connection_id)
        return True

    def close_connection(self, connection_id: str) -> bool:
        """
        Close and remove a connection.

        Disconnect errors are reported but do not prevent removal. If the
        closed connection was active, the first remaining connection (or None)
        becomes active.

        Args:
            connection_id: ID of connection to close

        Returns:
            True if successful, False if connection not found
        """
        instance = self._connections.get(connection_id)
        if not instance:
            return False

        # Disconnect the connection (best effort: removal proceeds regardless)
        try:
            instance.connection.disconnect()
        except Exception as e:
            print(f"Error disconnecting: {e}")

        del self._connections[connection_id]

        # If this was the active connection, set a new one or None
        if self._active_connection_id == connection_id:
            if self._connections:
                # Set first available connection as active
                self._active_connection_id = next(iter(self._connections))
            else:
                self._active_connection_id = None
            self.active_connection_changed.emit(self._active_connection_id)

        self.connection_closed.emit(connection_id)
        return True

    def update_connection_state(self, connection_id: str, state: ConnectionState, error: Optional[str] = None):
        """
        Update the state of a connection and emit ``connection_state_changed``.

        Args:
            connection_id: ID of connection
            state: New connection state
            error: Optional error message if state is ERROR; a falsy value
                clears any previous message
        """
        instance = self._connections.get(connection_id)
        if instance:
            instance.state = state
            instance.error_message = error or None
            self.connection_state_changed.emit(connection_id, state)

    def update_collections(self, connection_id: str, collections: List[str]):
        """
        Update the collections list for a connection and emit ``collections_updated``.

        Args:
            connection_id: ID of connection
            collections: List of collection names
        """
        instance = self._connections.get(connection_id)
        if instance:
            # Store a private copy: keeping the caller's list object would let
            # later mutation of that list (by the caller or a signal receiver)
            # silently change the cached state.
            instance.collections = list(collections) if collections else []
            self.collections_updated.emit(connection_id, collections)

    def set_active_collection(self, connection_id: str, collection_name: Optional[str]):
        """
        Set the active collection for a connection and emit ``active_collection_changed``.

        Args:
            connection_id: ID of connection
            collection_name: Name of collection to make active, or None
        """
        instance = self._connections.get(connection_id)
        if instance:
            instance.active_collection = collection_name
            self.active_collection_changed.emit(connection_id, collection_name)

    def get_all_connections(self) -> List[ConnectionInstance]:
        """Get list of all connection instances."""
        return list(self._connections.values())

    def get_connection_count(self) -> int:
        """Get the number of registered connections."""
        return len(self._connections)

    def close_all_connections(self):
        """Close all connections. Typically called on application exit.

        The active connection is cleared once up front so that closing the
        connections one by one does not promote each remaining (about to be
        closed) connection to active and emit a spurious
        ``active_connection_changed`` for it.
        """
        if self._active_connection_id is not None:
            self._active_connection_id = None
            self.active_connection_changed.emit(None)

        for conn_id in list(self._connections):
            self.close_connection(conn_id)

    def rename_connection(self, connection_id: str, new_name: str) -> bool:
        """
        Rename a connection. No signal is emitted.

        Args:
            connection_id: ID of connection
            new_name: New name for the connection

        Returns:
            True if successful, False if connection not found
        """
        instance = self._connections.get(connection_id)
        if instance:
            instance.name = new_name
            return True
        return False
|
|
277
|
+
|
|
@@ -3,5 +3,6 @@
|
|
|
3
3
|
from .base_connection import VectorDBConnection
|
|
4
4
|
from .chroma_connection import ChromaDBConnection
|
|
5
5
|
from .qdrant_connection import QdrantConnection
|
|
6
|
+
from .pinecone_connection import PineconeConnection
|
|
6
7
|
|
|
7
|
-
__all__ = ["VectorDBConnection", "ChromaDBConnection", "QdrantConnection"]
|
|
8
|
+
__all__ = ["VectorDBConnection", "ChromaDBConnection", "QdrantConnection", "PineconeConnection"]
|
|
@@ -229,5 +229,46 @@ class VectorDBConnection(ABC):
|
|
|
229
229
|
{"name": "in", "server_side": True},
|
|
230
230
|
{"name": "not in", "server_side": True},
|
|
231
231
|
{"name": "contains", "server_side": False},
|
|
232
|
-
{"name": "not contains", "server_side": False},
|
|
233
232
|
]
|
|
233
|
+
|
|
234
|
+
def get_embedding_model(self, collection_name: str, connection_id: Optional[str] = None) -> Optional[str]:
    """
    Get the embedding model used for a collection.

    Sources are consulted in order. Each lookup is isolated, so a failure in
    one (a backend error, or a vector stored without metadata) does not stop
    the later fallbacks from running:
      1. Collection-level metadata (``embedding_model`` in get_collection_info())
      2. Metadata of one sample vector (``_embedding_model`` field)
      3. User settings (for collections we can't modify), only if connection_id is given

    Args:
        collection_name: Name of collection
        connection_id: Optional connection ID for settings lookup

    Returns:
        Model name string (e.g., "sentence-transformers/all-MiniLM-L6-v2") or None
    """
    # 1. Collection-level metadata
    try:
        info = self.get_collection_info(collection_name)
        if info and info.get("embedding_model"):
            return info["embedding_model"]
    except Exception as e:
        print(f"Failed to get embedding model: {e}")

    # 2. Metadata of a sample vector
    try:
        data = self.get_all_items(collection_name, limit=1, offset=0)
        metadatas = data.get("metadatas") if data else None
        # Explicit None/len checks avoid ambiguous truthiness if a backend
        # hands back array-like results.
        if metadatas is not None and len(metadatas) > 0:
            # A vector may have been stored without metadata (None entry).
            model = (metadatas[0] or {}).get("_embedding_model")
            if model:
                return model
    except Exception as e:
        print(f"Failed to get embedding model: {e}")

    # 3. User settings (for collections we can't modify)
    if connection_id:
        try:
            from ...services.settings_service import SettingsService
            settings = SettingsService()
            model_info = settings.get_embedding_model(connection_id, collection_name)
            if model_info:
                return model_info.get("model")
        except Exception as e:
            print(f"Failed to get embedding model: {e}")

    return None
|
|
@@ -6,10 +6,44 @@ from pathlib import Path
|
|
|
6
6
|
import chromadb
|
|
7
7
|
from chromadb.api import ClientAPI
|
|
8
8
|
from chromadb.api.models.Collection import Collection
|
|
9
|
+
from chromadb import Documents, EmbeddingFunction, Embeddings
|
|
9
10
|
|
|
10
11
|
from .base_connection import VectorDBConnection
|
|
11
12
|
|
|
12
13
|
|
|
14
|
+
class DimensionAwareEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function whose model is picked to match a vector dimension.

    The model is resolved via ``get_embedding_model_for_dimension`` the first
    time the function is called, not when it is constructed.
    """

    def __init__(self, expected_dimension: int):
        """Remember the target dimension; model loading is deferred to first use."""
        self.expected_dimension = expected_dimension
        self.model = self.model_name = self.model_type = None
        self._initialized = False

    def _ensure_model_loaded(self):
        """Load the embedding model once; subsequent calls do nothing."""
        if not self._initialized:
            from ..embedding_utils import get_embedding_model_for_dimension

            print(f"[ChromaDB] Loading embedding model for {self.expected_dimension}d vectors...")
            loaded = get_embedding_model_for_dimension(self.expected_dimension)
            self.model, self.model_name, self.model_type = loaded
            print(f"[ChromaDB] Using {self.model_type} model '{self.model_name}' for {self.expected_dimension}d embeddings")
            self._initialized = True

    # The parameter must stay named `input`: that is the name used by
    # Chroma's EmbeddingFunction call signature.
    def __call__(self, input: Documents) -> Embeddings:
        """Return one embedding per input document, in input order."""
        self._ensure_model_loaded()
        from ..embedding_utils import encode_text

        return [encode_text(text, self.model, self.model_type) for text in input]
|
|
45
|
+
|
|
46
|
+
|
|
13
47
|
class ChromaDBConnection(VectorDBConnection):
|
|
14
48
|
"""Manages connection to ChromaDB and provides query interface."""
|
|
15
49
|
|
|
@@ -90,12 +124,47 @@ class ChromaDBConnection(VectorDBConnection):
|
|
|
90
124
|
print(f"Failed to list collections: {e}")
|
|
91
125
|
return []
|
|
92
126
|
|
|
93
|
-
def
|
|
94
|
-
"""
|
|
95
|
-
|
|
127
|
+
def _get_collection_basic(self, name: str) -> Optional[Collection]:
    """Fetch a collection by name with the client's defaults (for info lookups).

    No custom embedding function is supplied. Returns None when there is no
    client or when the client lookup raises (e.g. the collection is missing).
    """
    client = self._client
    if not client:
        return None
    try:
        collection = client.get_collection(name=name)
    except Exception:
        # Any lookup failure is reported to the caller as "no collection".
        collection = None
    return collection
|
|
135
|
+
|
|
136
|
+
def _get_embedding_function_for_collection(self, name: str) -> Optional[EmbeddingFunction]:
    """Build an embedding function sized to the collection's stored vectors.

    The dimension is taken from a single sample embedding. Returns None when
    the collection cannot be fetched, holds no embeddings, or reading the
    sample fails (the failure is printed with a traceback).
    """
    collection = self._get_collection_basic(name)
    if not collection:
        return None

    try:
        sample = collection.get(limit=1, include=["embeddings"])
        vectors = sample.get("embeddings") if sample else None
        # Compare with None explicitly: embeddings may come back as numpy
        # arrays, whose truth value is ambiguous.
        if vectors is None or len(vectors) == 0:
            return None
        first = vectors[0]
        if first is None or len(first) == 0:
            return None
        dimension = len(first)
        print(f"[ChromaDB] Collection '{name}' has {dimension}d vectors")
        return DimensionAwareEmbeddingFunction(dimension)
    except Exception as e:
        print(f"[ChromaDB] Failed to determine embedding function: {e}")
        import traceback
        traceback.print_exc()

    return None
|
|
161
|
+
|
|
162
|
+
def get_collection(self, name: str, embedding_function: Optional[EmbeddingFunction] = None) -> Optional[Collection]:
|
|
163
|
+
"""Get a collection (without overriding existing embedding function).
|
|
96
164
|
|
|
97
165
|
Args:
|
|
98
166
|
name: Collection name
|
|
167
|
+
embedding_function: Optional custom embedding function (ignored if collection exists)
|
|
99
168
|
|
|
100
169
|
Returns:
|
|
101
170
|
Collection object or None if failed
|
|
@@ -103,7 +172,9 @@ class ChromaDBConnection(VectorDBConnection):
|
|
|
103
172
|
if not self._client:
|
|
104
173
|
return None
|
|
105
174
|
try:
|
|
106
|
-
|
|
175
|
+
# Just get the collection without trying to override embedding function
|
|
176
|
+
# This avoids conflicts with existing collections
|
|
177
|
+
self._current_collection = self._client.get_collection(name=name)
|
|
107
178
|
return self._current_collection
|
|
108
179
|
except Exception as e:
|
|
109
180
|
print(f"Failed to get collection: {e}")
|
|
@@ -119,7 +190,7 @@ class ChromaDBConnection(VectorDBConnection):
|
|
|
119
190
|
Returns:
|
|
120
191
|
Dictionary with collection info
|
|
121
192
|
"""
|
|
122
|
-
collection = self.
|
|
193
|
+
collection = self._get_collection_basic(name)
|
|
123
194
|
if not collection:
|
|
124
195
|
return None
|
|
125
196
|
|
|
@@ -141,27 +212,37 @@ class ChromaDBConnection(VectorDBConnection):
|
|
|
141
212
|
# ChromaDB uses cosine distance by default (or can be configured)
|
|
142
213
|
# Try to get metadata from collection if available
|
|
143
214
|
distance_metric = "Cosine (default)"
|
|
215
|
+
embedding_model = None
|
|
144
216
|
try:
|
|
145
217
|
# ChromaDB collections may have metadata about distance function
|
|
146
218
|
col_metadata = collection.metadata
|
|
147
|
-
if col_metadata
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
219
|
+
if col_metadata:
|
|
220
|
+
if "hnsw:space" in col_metadata:
|
|
221
|
+
space = col_metadata["hnsw:space"]
|
|
222
|
+
if space == "l2":
|
|
223
|
+
distance_metric = "Euclidean (L2)"
|
|
224
|
+
elif space == "ip":
|
|
225
|
+
distance_metric = "Inner Product"
|
|
226
|
+
elif space == "cosine":
|
|
227
|
+
distance_metric = "Cosine"
|
|
228
|
+
# Get embedding model if stored
|
|
229
|
+
if "embedding_model" in col_metadata:
|
|
230
|
+
embedding_model = col_metadata["embedding_model"]
|
|
155
231
|
except:
|
|
156
232
|
pass # Use default if unable to determine
|
|
157
233
|
|
|
158
|
-
|
|
234
|
+
result = {
|
|
159
235
|
"name": name,
|
|
160
236
|
"count": count,
|
|
161
237
|
"metadata_fields": metadata_fields,
|
|
162
238
|
"vector_dimension": vector_dimension,
|
|
163
239
|
"distance_metric": distance_metric,
|
|
164
240
|
}
|
|
241
|
+
|
|
242
|
+
if embedding_model:
|
|
243
|
+
result["embedding_model"] = embedding_model
|
|
244
|
+
|
|
245
|
+
return result
|
|
165
246
|
except Exception as e:
|
|
166
247
|
print(f"Failed to get collection info: {e}")
|
|
167
248
|
return None
|
|
@@ -189,10 +270,22 @@ class ChromaDBConnection(VectorDBConnection):
|
|
|
189
270
|
Returns:
|
|
190
271
|
Query results or None if failed
|
|
191
272
|
"""
|
|
273
|
+
print(f"[ChromaDB] query_collection called for '{collection_name}'")
|
|
192
274
|
collection = self.get_collection(collection_name)
|
|
193
275
|
if not collection:
|
|
276
|
+
print(f"[ChromaDB] Failed to get collection '{collection_name}'")
|
|
194
277
|
return None
|
|
195
278
|
|
|
279
|
+
# If query_texts provided, we need to manually embed them with dimension-aware model
|
|
280
|
+
if query_texts and not query_embeddings:
|
|
281
|
+
embedding_function = self._get_embedding_function_for_collection(collection_name)
|
|
282
|
+
if embedding_function:
|
|
283
|
+
print(f"[ChromaDB] Manually embedding query texts with dimension-aware model")
|
|
284
|
+
query_embeddings = embedding_function(query_texts)
|
|
285
|
+
query_texts = None # Use embeddings instead of texts
|
|
286
|
+
else:
|
|
287
|
+
print(f"[ChromaDB] Warning: Could not determine embedding function, using collection's default")
|
|
288
|
+
|
|
196
289
|
try:
|
|
197
290
|
results = collection.query(
|
|
198
291
|
query_texts=query_texts,
|
|
@@ -205,6 +298,8 @@ class ChromaDBConnection(VectorDBConnection):
|
|
|
205
298
|
return cast(Dict[str, Any], results)
|
|
206
299
|
except Exception as e:
|
|
207
300
|
print(f"Query failed: {e}")
|
|
301
|
+
import traceback
|
|
302
|
+
traceback.print_exc()
|
|
208
303
|
return None
|
|
209
304
|
|
|
210
305
|
def get_all_items(
|
|
@@ -368,8 +463,34 @@ class ChromaDBConnection(VectorDBConnection):
|
|
|
368
463
|
|
|
369
464
|
# Implement base connection uniform APIs
|
|
370
465
|
def create_collection(self, name: str, vector_size: int, distance: str = "Cosine") -> bool:
    """Create a collection, or reuse it unchanged if it already exists.

    ChromaDB infers the vector dimension from the first insert, so
    ``vector_size`` is accepted for interface compatibility only. ``distance``
    is translated into Chroma's ``hnsw:space`` collection metadata (the key
    get_collection_info() reads back) when a new collection is created.

    Args:
        name: Collection name
        vector_size: Expected vector dimension (not enforced by ChromaDB)
        distance: Distance metric name such as "Cosine", "Euclidean"/"L2",
            "Dot"/"Inner Product"; a parenthesised suffix like
            "Euclidean (L2)" is ignored. Unrecognised names leave Chroma's
            own default in place.

    Returns:
        True if the collection exists after the call, False otherwise
    """
    if not self._client:
        return False

    space_by_name = {
        "cosine": "cosine",
        "euclid": "l2",
        "euclidean": "l2",
        "l2": "l2",
        "dot": "ip",
        "dot product": "ip",
        "inner product": "ip",
        "ip": "ip",
    }
    key = (distance or "").split("(", 1)[0].strip().lower()
    space = space_by_name.get(key)
    metadata = {"hnsw:space": space} if space else None

    try:
        # Reuse an existing collection as-is; its distance function cannot be
        # changed after creation, so the requested metadata is not applied.
        try:
            existing = self._client.get_collection(name=name)
        except Exception:
            existing = None  # get_collection raises when the collection is missing
        if existing is not None:
            self._current_collection = existing
            return True

        # Prefer get_or_create_collection if available
        if hasattr(self._client, "get_or_create_collection"):
            col = self._client.get_or_create_collection(name=name, metadata=metadata)
            self._current_collection = col
            return True

        # Fallback to create_collection and then fetch
        if hasattr(self._client, "create_collection"):
            try:
                self._client.create_collection(name=name, metadata=metadata)
            except Exception:
                # Some clients may raise if already exists; ignore
                pass
            col = self._client.get_collection(name=name)
            self._current_collection = col
            return col is not None

        # As a last resort, check if collection exists
        col = self.get_collection(name)
        return col is not None
    except Exception as e:
        print(f"Failed to create collection: {e}")
        return False
|
|
373
494
|
|
|
374
495
|
def get_items(self, name: str, ids: List[str]) -> Dict[str, Any]:
|
|
375
496
|
"""Retrieve items by IDs."""
|