vector-inspector 0.3.4__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vector_inspector/core/connections/base_connection.py +86 -1
- vector_inspector/core/connections/chroma_connection.py +23 -3
- vector_inspector/core/connections/pgvector_connection.py +1100 -0
- vector_inspector/core/connections/pinecone_connection.py +24 -4
- vector_inspector/core/connections/qdrant_connection.py +224 -189
- vector_inspector/core/embedding_providers/provider_factory.py +33 -38
- vector_inspector/core/embedding_utils.py +2 -2
- vector_inspector/services/backup_restore_service.py +41 -33
- vector_inspector/ui/components/connection_manager_panel.py +96 -77
- vector_inspector/ui/components/profile_manager_panel.py +315 -121
- vector_inspector/ui/dialogs/embedding_config_dialog.py +79 -58
- vector_inspector/ui/main_window.py +22 -0
- vector_inspector/ui/views/connection_view.py +215 -116
- vector_inspector/ui/views/info_panel.py +6 -6
- vector_inspector/ui/views/metadata_view.py +466 -187
- {vector_inspector-0.3.4.dist-info → vector_inspector-0.3.5.dist-info}/METADATA +4 -3
- {vector_inspector-0.3.4.dist-info → vector_inspector-0.3.5.dist-info}/RECORD +19 -18
- {vector_inspector-0.3.4.dist-info → vector_inspector-0.3.5.dist-info}/WHEEL +0 -0
- {vector_inspector-0.3.4.dist-info → vector_inspector-0.3.5.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,1100 @@
|
|
|
1
|
+
"""PgVector/PostgreSQL connection manager."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional, List, Dict, Any
|
|
4
|
+
import json
|
|
5
|
+
import psycopg2
|
|
6
|
+
from psycopg2 import sql
|
|
7
|
+
|
|
8
|
+
# register_vector is imported lazily inside connect(): the pgvector Python package is optional.
|
|
9
|
+
from vector_inspector.core.connections.base_connection import VectorDBConnection
|
|
10
|
+
from vector_inspector.core.logging import log_error, log_info
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PgVectorConnection(VectorDBConnection):
|
|
14
|
+
"""Manages connection to pgvector/PostgreSQL and provides query interface."""
|
|
15
|
+
|
|
16
|
+
def __init__(
    self,
    host: str = "localhost",
    port: int = 5432,
    database: str = "subtitles",
    user: str = "postgres",
    password: str = "postgres",
):
    """
    Remember the PostgreSQL connection parameters without connecting.

    Nothing touches the network here; ``connect()`` opens the connection
    using the values stored below.

    NOTE(review): the "subtitles" default database name looks project
    specific — confirm it is the intended default.

    Args:
        host: Hostname of the PostgreSQL server
        port: TCP port of the PostgreSQL server
        database: Name of the database to open
        user: Login role name
        password: Login password
    """
    # Connection parameters, read later by connect() and list_databases().
    self.host, self.port = host, port
    self.database = database
    self.user, self.password = user, password

    # Live psycopg2 connection; None whenever we are disconnected.
    self._client: Optional[psycopg2.extensions.connection] = None

    # How many embeddings the most recent update operation regenerated.
    self._last_regenerated_count: int = 0
|
|
42
|
+
|
|
43
|
+
def connect(self) -> bool:
    """
    Open a connection to PostgreSQL using the stored parameters.

    The new connection is put in autocommit mode so that a failed statement
    does not leave the session in an aborted transaction (which would make
    later SELECTs fail with "current transaction is aborted"). When the
    optional ``pgvector`` Python package is importable, its psycopg2 adapter
    is registered as a best-effort convenience; failure to do so is logged
    and does not fail the connection.

    If this object already holds a connection, that connection is closed
    first so repeated ``connect()`` calls do not leak server connections.

    Returns:
        True if the connection was established, False otherwise (the error
        is logged and ``self._client`` is left as None).
    """
    # Reconnecting: release the previous connection instead of dropping it.
    if self._client is not None:
        try:
            self._client.close()
        except Exception as close_err:
            log_error("Failed to close previous connection: %s", close_err)
        self._client = None

    try:
        self._client = psycopg2.connect(
            host=self.host,
            port=self.port,
            database=self.database,
            user=self.user,
            password=self.password,
        )
    except Exception as e:
        log_error("Connection failed: %s", e)
        self._client = None
        return False

    try:
        self._client.autocommit = True
    except Exception as e:
        # Some connection wrappers may not support autocommit; keep going.
        log_info("Could not enable autocommit: %s", e)

    # Register the pgvector adapter so vector values can be passed as params.
    try:
        from pgvector.psycopg2 import register_vector
    except Exception as e:
        log_info("pgvector package unavailable; vector adapter not registered: %s", e)
        return True

    try:
        register_vector(self._client)
    except Exception:
        # Some versions accept a cursor rather than a connection. Use a
        # context manager so the temporary cursor is closed, not leaked.
        try:
            with self._client.cursor() as cur:
                register_vector(cur)
        except Exception as e:
            log_info("Could not register pgvector adapter: %s", e)
    return True
|
|
85
|
+
|
|
86
|
+
def disconnect(self):
    """Close the current PostgreSQL connection, if any, and forget it."""
    client = self._client
    if not client:
        return
    client.close()
    self._client = None
|
|
91
|
+
|
|
92
|
+
@property
def is_connected(self) -> bool:
    """
    Whether this object currently holds a PostgreSQL connection.

    Only local state is inspected (set by connect(), cleared by
    disconnect()); the server is not contacted.
    """
    client = self._client
    return client is not None
|
|
96
|
+
|
|
97
|
+
def list_collections(self) -> List[str]:
    """
    Return the tables in the ``public`` schema that are treated as collections.

    A table qualifies when at least one of its columns has the pgvector
    ``vector`` type.

    Returns:
        Table names in the order the server returns them; an empty list when
        not connected or when the catalog query fails (the error is logged).
    """
    if not self._client:
        return []

    catalog_query = """
            SELECT DISTINCT table_name FROM information_schema.columns
            WHERE data_type = 'USER-DEFINED'
            AND udt_name = 'vector'
            AND table_schema = 'public'
            """
    try:
        with self._client.cursor() as cur:
            cur.execute(catalog_query)
            found = cur.fetchall()
        return [record[0] for record in found]
    except Exception as e:
        log_error("Failed to list collections: %s", e)
        return []
|
|
119
|
+
|
|
120
|
+
def list_databases(self) -> List[str]:
    """
    Return the names of all non-template databases on the server, sorted.

    The open connection is reused when there is one. Otherwise a temporary
    connection to the ``postgres`` database is opened with the stored
    host/port/credentials and always closed before returning.

    Returns:
        Database names ordered by name, or an empty list if the lookup
        fails (the error is logged).
    """
    owned_conn = None
    try:
        active = self._client
        if not active:
            # 'postgres' exists on a standard install and is a safe default.
            owned_conn = psycopg2.connect(
                host=self.host,
                port=self.port,
                database="postgres",
                user=self.user,
                password=self.password,
            )
            active = owned_conn

        with active.cursor() as cur:
            cur.execute(
                "SELECT datname FROM pg_database WHERE datistemplate = false ORDER BY datname"
            )
            fetched = cur.fetchall()
        return [entry[0] for entry in fetched]
    except Exception as e:
        log_error("Failed to list databases: %s", e)
        return []
    finally:
        if owned_conn:
            try:
                owned_conn.close()
            except Exception:
                pass
|
|
157
|
+
|
|
158
|
+
def get_collection_info(self, name: str) -> Optional[Dict[str, Any]]:
    """
    Get collection metadata and statistics.

    Args:
        name: Table name

    Returns:
        Dict with ``name``, ``count`` and ``metadata_fields``, where
        ``metadata_fields`` lists every column except ``id``, ``document``
        and ``embedding``. ``vector_dimension`` is added when a sample row's
        embedding can be parsed. ``embedding_model`` and
        ``embedding_model_type`` are added when the sample row's JSON
        ``metadata`` records the model under ``embedding_model`` or
        ``_embedding_model``. Returns None when not connected or when the
        count query fails (the error is logged).
    """
    if not self._client:
        return None
    try:
        with self._client.cursor() as cur:
            # sql.Identifier safely quotes the table name.
            cur.execute(sql.SQL("SELECT COUNT(*) FROM {}").format(sql.Identifier(name)))
            count_row = cur.fetchone()
            count = count_row[0] if count_row else 0

            schema = self._get_table_schema(name)
            metadata_fields = [
                col for col in schema.keys() if col not in ["id", "document", "embedding"]
            ]

            vector_dimension = "Unknown"
            detected_model = None
            detected_model_type = None

            # Only request the JSONB "metadata" column when the table has one.
            # Tables that keep metadata in separate columns would otherwise
            # make this sample query fail, which would also lose the vector
            # dimension.
            has_metadata_col = "metadata" in schema
            sample_columns = ["embedding", "metadata"] if has_metadata_col else ["embedding"]
            try:
                cur.execute(
                    sql.SQL("SELECT {} FROM {} LIMIT 1").format(
                        sql.SQL(", ").join(sql.Identifier(c) for c in sample_columns),
                        sql.Identifier(name),
                    )
                )
                sample = cur.fetchone()
                if sample:
                    emb_val = sample[0]
                    meta_val = sample[1] if has_metadata_col else None

                    try:
                        parsed = self._parse_vector(emb_val)
                        # Explicit length test rather than `if parsed:` in case
                        # _parse_vector yields a numpy array, whose truth value is
                        # ambiguous.
                        if parsed is not None and len(parsed) > 0:
                            vector_dimension = len(parsed)
                    except Exception:
                        vector_dimension = "Unknown"

                    meta_obj = None
                    if isinstance(meta_val, (str, bytes)):
                        try:
                            meta_obj = json.loads(meta_val)
                        except Exception:
                            meta_obj = None
                    elif isinstance(meta_val, dict):
                        meta_obj = meta_val

                    # Stored JSON may decode to a list/number; only objects carry keys.
                    if isinstance(meta_obj, dict):
                        if "embedding_model" in meta_obj:
                            detected_model = meta_obj.get("embedding_model")
                            detected_model_type = meta_obj.get("embedding_model_type", "stored")
                        elif "_embedding_model" in meta_obj:
                            detected_model = meta_obj.get("_embedding_model")
                            detected_model_type = "stored"
            except Exception:
                # Best-effort; a failed sample only omits the optional keys.
                pass

        result = {"name": name, "count": count, "metadata_fields": metadata_fields}
        if vector_dimension != "Unknown":
            result["vector_dimension"] = vector_dimension
        if detected_model:
            result["embedding_model"] = detected_model
            result["embedding_model_type"] = detected_model_type or "stored"
        return result
    except Exception as e:
        log_error("Failed to get collection info: %s", e)
        return None
|
|
237
|
+
|
|
238
|
+
def create_collection(self, name: str, vector_size: int, distance: str = "cosine") -> bool:
    """
    Create a new table for storing vectors, plus an IVFFlat index on it.

    The table layout is ``(id TEXT PRIMARY KEY, document TEXT, metadata
    JSONB, embedding vector(vector_size))``. TEXT ids allow arbitrary ids
    coming from migrations/backups.

    Args:
        name: Table name
        vector_size: Dimension of vectors
        distance: Metric that selects the index operator class: "cosine" or
            "cos", "euclidean" or "l2", "dotproduct"/"dot"/"ip". Any other
            value falls back to cosine.

    Returns:
        True if successful, False otherwise. On failure, a table that this
        call already created is dropped again, so the call can be retried.
    """
    if not self._client:
        return False

    table_created = False
    try:
        with self._client.cursor() as cur:
            # Ensure the pgvector extension is enabled.
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")

            cur.execute(
                sql.SQL(
                    "CREATE TABLE {} (id TEXT PRIMARY KEY, document TEXT, metadata JSONB, embedding vector({}))"
                ).format(sql.Identifier(name), sql.Literal(vector_size))
            )
            table_created = True

            # Map the distance metric to a pgvector index operator class.
            distance_lower = distance.lower()
            if distance_lower in ["euclidean", "l2"]:
                ops_class = "vector_l2_ops"
            elif distance_lower in ["dotproduct", "dot", "ip"]:
                ops_class = "vector_ip_ops"
            else:
                # "cosine"/"cos" and anything unrecognised.
                ops_class = "vector_cosine_ops"

            # NOTE: IVFFlat builds its lists from the rows present at creation
            # time. On a freshly created (empty) table, pgvector recommends
            # rebuilding the index after loading data for good recall.
            index_name = f"{name}_embedding_idx"
            cur.execute(
                sql.SQL("CREATE INDEX {} ON {} USING ivfflat (embedding {})").format(
                    sql.Identifier(index_name), sql.Identifier(name), sql.SQL(ops_class)
                )
            )
            self._client.commit()
        return True
    except Exception as e:
        log_error("Failed to create collection: %s", e)
        if self._client:
            try:
                self._client.rollback()
            except Exception as rb_err:
                log_error("Rollback failed: %s", rb_err)
            # In autocommit mode (enabled by connect()) the CREATE TABLE is
            # already committed and rollback() has no effect, so remove the
            # half-built collection explicitly.
            if table_created:
                try:
                    with self._client.cursor() as cur:
                        cur.execute(
                            sql.SQL("DROP TABLE IF EXISTS {}").format(sql.Identifier(name))
                        )
                        self._client.commit()
                except Exception as drop_err:
                    log_error(
                        "Failed to clean up partially created collection: %s", drop_err
                    )
        return False
|
|
290
|
+
|
|
291
|
+
def add_items(
    self,
    collection_name: str,
    documents: List[str],
    metadatas: Optional[List[Dict[str, Any]]] = None,
    ids: Optional[List[str]] = None,
    embeddings: Optional[List[List[float]]] = None,
) -> bool:
    """
    Add items to a collection (table).

    When ``embeddings`` is not supplied, they are computed from ``documents``.
    The model is resolved in this order:

    1. the per-collection setting from ``SettingsService``;
    2. the model recorded in the collection's metadata;
    3. a model matched to the collection's vector dimension;
    4. ``DEFAULT_MODEL``.

    All rows are inserted in one transaction, so a failing INSERT leaves none
    of the batch behind. If the connection is in autocommit mode, autocommit
    is switched off for the batch and restored afterwards.

    Args:
        collection_name: Table name
        documents: Document texts
        metadatas: Metadata for each document (optional)
        ids: IDs for each document; a random UUID is generated for any
            position without one (supply ids for proper migration support)
        embeddings: Pre-computed embeddings, one per item

    Returns:
        True if successful, False otherwise (errors are logged).
    """
    import uuid

    if not self._client:
        return False

    # Explicit None/len test so numpy arrays are accepted (their truth value
    # is ambiguous).
    if embeddings is None or len(embeddings) == 0:
        try:
            from vector_inspector.services.settings_service import SettingsService
            from vector_inspector.core.embedding_utils import (
                load_embedding_model,
                get_embedding_model_for_dimension,
                DEFAULT_MODEL,
                encode_text,
            )

            model_name = None
            model_type = None

            # 1) explicit per-collection setting
            settings = SettingsService()
            model_info = settings.get_embedding_model(self.database, collection_name)
            if model_info:
                model_name = model_info.get("model")
                model_type = model_info.get("type", "sentence-transformer")

            # 2) model recorded in the collection's metadata
            coll_info = None
            if not model_name:
                coll_info = self.get_collection_info(collection_name)
                if coll_info and coll_info.get("embedding_model"):
                    model_name = coll_info.get("embedding_model")
                    model_type = coll_info.get("embedding_model_type", "stored")

            # 3) dimension-based fallback, else the default model
            loaded_model = None
            if not model_name:
                if not coll_info:
                    coll_info = self.get_collection_info(collection_name)
                dim = None
                if coll_info and coll_info.get("vector_dimension"):
                    try:
                        dim = int(coll_info.get("vector_dimension"))
                    except Exception:
                        dim = None
                if dim:
                    loaded_model, model_name, model_type = get_embedding_model_for_dimension(dim)
                else:
                    model_name, model_type = DEFAULT_MODEL

            if not loaded_model:
                loaded_model = load_embedding_model(model_name, model_type)

            if model_type != "clip":
                embeddings = loaded_model.encode(documents, show_progress_bar=False).tolist()
            else:
                embeddings = [encode_text(d, loaded_model, model_type) for d in documents]
        except Exception as e:
            log_error("Failed to compute embeddings on add: %s", e)
            return False

    client = self._client
    restore_autocommit = False
    try:
        # Read the schema while still in autocommit so a failure inside it
        # cannot leave the batch transaction below in an aborted state.
        schema = self._get_table_schema(collection_name)
        has_metadata_col = "metadata" in schema

        if getattr(client, "autocommit", False):
            client.autocommit = False
            restore_autocommit = True

        jsonb_insert = sql.SQL(
            "INSERT INTO {} (id, document, metadata, embedding) VALUES (%s, %s, %s, %s)"
        ).format(sql.Identifier(collection_name))

        with client.cursor() as cur:
            for i, emb in enumerate(embeddings):
                # Use the provided ID or generate a UUID.
                item_id = ids[i] if ids and i < len(ids) else str(uuid.uuid4())
                doc = documents[i] if i < len(documents) else None
                metadata = metadatas[i] if metadatas and i < len(metadatas) else {}
                # numpy rows -> plain lists, which psycopg2 can always adapt.
                if hasattr(emb, "tolist"):
                    emb = emb.tolist()

                if has_metadata_col:
                    metadata_json = json.dumps(metadata) if metadata else None
                    cur.execute(jsonb_insert, (item_id, doc, metadata_json, emb))
                    continue

                # Flattened schema: map metadata keys onto existing columns.
                columns = ["id", "embedding"]
                values = [item_id, emb]
                if "document" in schema and doc is not None:
                    columns.append("document")
                    values.append(doc)
                for key, value in (metadata or {}).items():
                    # Skip keys already set above (e.g. a metadata "id"),
                    # which would otherwise name the same column twice.
                    if key in schema and key not in columns:
                        columns.append(key)
                        values.append(value)

                placeholders = ", ".join(["%s"] * len(values))
                cur.execute(
                    sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
                        sql.Identifier(collection_name),
                        sql.SQL(", ").join(sql.Identifier(c) for c in columns),
                        sql.SQL(placeholders),
                    ),
                    values,
                )
        client.commit()
        return True
    except Exception as e:
        log_error("Failed to add items: %s", e)
        try:
            client.rollback()
        except Exception as rb_err:
            log_error("Rollback failed: %s", rb_err)
        return False
    finally:
        # Runs after commit()/rollback(), so no transaction is open here.
        if restore_autocommit:
            try:
                client.autocommit = True
            except Exception as ac_err:
                log_error("Failed to restore autocommit: %s", ac_err)
|
|
430
|
+
|
|
431
|
+
def get_items(self, name: str, ids: List[str]) -> Dict[str, Any]:
    """
    Fetch the rows of table ``name`` whose ``id`` is one of ``ids``.

    Args:
        name: Table name
        ids: IDs to look up

    Returns:
        Dict of parallel lists under ``ids``, ``documents``, ``metadatas`` and
        ``embeddings``, in the order the database returns the rows. Metadata
        comes from the JSONB ``metadata`` column when the table has one;
        otherwise every column except id/document/embedding is used. An empty
        dict is returned when not connected or on error (the error is logged).
    """
    if not self._client:
        return {}
    try:
        uses_json_metadata = "metadata" in self._get_table_schema(name)

        with self._client.cursor() as cur:
            cur.execute(
                sql.SQL("SELECT * FROM {} WHERE id = ANY(%s)").format(sql.Identifier(name)),
                (ids,),
            )
            fetched = cur.fetchall()
            column_names = [d[0] for d in cur.description]

        out = {"ids": [], "documents": [], "metadatas": [], "embeddings": []}
        for values in fetched:
            record = dict(zip(column_names, values))
            out["ids"].append(str(record.get("id", "")))
            out["documents"].append(record.get("document", ""))

            if not uses_json_metadata:
                meta = {
                    col: val
                    for col, val in record.items()
                    if col not in ["id", "document", "embedding"]
                }
            else:
                raw = record.get("metadata")
                if isinstance(raw, dict):
                    meta = raw
                elif isinstance(raw, (str, bytes)):
                    try:
                        meta = json.loads(raw)
                    except Exception:
                        meta = {}
                else:
                    meta = {}
            out["metadatas"].append(meta)

            out["embeddings"].append(self._parse_vector(record.get("embedding", "")))
        return out
    except Exception as e:
        log_error("Failed to get items: %s", e)
        return {}
|
|
502
|
+
|
|
503
|
+
def delete_collection(self, name: str) -> bool:
    """
    Delete a table (collection) if it exists.

    ``DROP TABLE ... CASCADE`` is used, so objects that depend on the table
    (for example views) are dropped as well.

    Args:
        name: Table name

    Returns:
        True if successful, False otherwise (the error is logged).
    """
    if not self._client:
        return False
    try:
        with self._client.cursor() as cur:
            cur.execute(sql.SQL("DROP TABLE IF EXISTS {} CASCADE").format(sql.Identifier(name)))
            self._client.commit()
        return True
    except Exception as e:
        log_error("Failed to delete collection: %s", e)
        if self._client:
            try:
                self._client.rollback()
            except Exception as rb_err:
                # e.g. the connection was already closed; still honour the
                # documented "return False on failure" contract.
                log_error("Rollback failed: %s", rb_err)
        return False
|
|
525
|
+
|
|
526
|
+
def count_collection(self, name: str) -> int:
    """
    Count the rows in a collection (table).

    Args:
        name: Table name

    Returns:
        The row count; 0 when not connected or when the query fails (the
        error is logged).
    """
    if not self._client:
        return 0
    try:
        count_stmt = sql.SQL("SELECT COUNT(*) FROM {}").format(sql.Identifier(name))
        with self._client.cursor() as cur:
            cur.execute(count_stmt)
            first_row = cur.fetchone()
        return 0 if not first_row else first_row[0]
    except Exception as e:
        log_error("Failed to count collection: %s", e)
        return 0
|
|
547
|
+
|
|
548
|
+
def query_collection(
    self,
    collection_name: str,
    query_texts: Optional[List[str]] = None,
    query_embeddings: Optional[List[List[float]]] = None,
    n_results: int = 10,
    where: Optional[Dict[str, Any]] = None,
    where_document: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """
    Query a collection for the nearest neighbours of one or more vectors.

    Rows are ranked with pgvector's cosine-distance operator ``<=>``.
    NOTE(review): this applies even when the collection's index was built
    with ``vector_l2_ops``/``vector_ip_ops``, which ``<=>`` cannot use —
    confirm whether per-collection metrics should be honoured here.

    Args:
        collection_name: Table name
        query_texts: Texts to embed when ``query_embeddings`` is not given.
            The model is resolved from settings, then collection metadata,
            then vector dimension, then the default model.
        query_embeddings: Embedding vectors to search
        n_results: Number of results to return per query
        where: Equality filter ``{key: value}``.
            With a JSONB ``metadata`` column, keys are matched against
            ``metadata->>key``; booleans are compared as JSON
            ``true``/``false`` and other values via ``str()``. Otherwise,
            keys are matched against columns of the same name, and keys that
            are not columns are ignored. A ``None`` value matches a missing
            or NULL value.
        where_document: Document filter (not implemented; ignored)

    Returns:
        Dict of per-query lists (``ids``, ``documents``, ``metadatas``,
        ``embeddings``, ``distances``), or None when not connected, when
        there is nothing to query, or on error (errors are logged).
    """
    if not self._client:
        return None

    # Explicit None/len test so numpy arrays are accepted as well.
    no_embeddings = query_embeddings is None or len(query_embeddings) == 0
    if no_embeddings and query_texts:
        try:
            from vector_inspector.services.settings_service import SettingsService
            from vector_inspector.core.embedding_utils import (
                load_embedding_model,
                get_embedding_model_for_dimension,
                DEFAULT_MODEL,
                encode_text,
            )

            model_name = None
            model_type = None

            # 1) explicit per-collection setting
            settings = SettingsService()
            model_info = settings.get_embedding_model(self.database, collection_name)
            if model_info:
                model_name = model_info.get("model")
                model_type = model_info.get("type", "sentence-transformer")

            # 2) model recorded in the collection's metadata
            coll_info = None
            if not model_name:
                coll_info = self.get_collection_info(collection_name)
                if coll_info and coll_info.get("embedding_model"):
                    model_name = coll_info.get("embedding_model")
                    model_type = coll_info.get("embedding_model_type", "stored")

            # 3) dimension-based fallback, else the default model
            loaded_model = None
            if not model_name:
                if not coll_info:
                    coll_info = self.get_collection_info(collection_name)
                dim = None
                if coll_info and coll_info.get("vector_dimension"):
                    try:
                        dim = int(coll_info.get("vector_dimension"))
                    except Exception:
                        dim = None
                if dim:
                    loaded_model, model_name, model_type = get_embedding_model_for_dimension(dim)
                else:
                    model_name, model_type = DEFAULT_MODEL

            if not loaded_model:
                loaded_model = load_embedding_model(model_name, model_type)

            # Compute embeddings for the query texts (helper for CLIP models).
            if model_type != "clip":
                query_embeddings = loaded_model.encode(query_texts, show_progress_bar=False).tolist()
            else:
                query_embeddings = [encode_text(t, loaded_model, model_type) for t in query_texts]
        except Exception as e:
            log_error("Failed to compute query embeddings: %s", e)
            return None

    if query_embeddings is None:
        log_error("Query failed: %s", "no query embeddings or query texts provided")
        return None

    try:
        schema = self._get_table_schema(collection_name)
        has_metadata_col = "metadata" in schema

        # The filter does not depend on the query vector, so build it once.
        conditions = []
        filter_params: List[Any] = []
        for key, value in (where or {}).items():
            if has_metadata_col and key != "metadata":
                if value is None:
                    conditions.append(sql.SQL("metadata->>%s IS NULL"))
                    filter_params.append(key)
                else:
                    # ->> renders JSON booleans as 'true'/'false', not 'True'/'False'.
                    if isinstance(value, bool):
                        text_value = "true" if value else "false"
                    else:
                        text_value = str(value)
                    conditions.append(sql.SQL("metadata->>%s = %s"))
                    filter_params.extend([key, text_value])
            elif key in schema:
                if value is None:
                    conditions.append(sql.SQL("{} IS NULL").format(sql.Identifier(key)))
                else:
                    conditions.append(sql.SQL("{} = %s").format(sql.Identifier(key)))
                    filter_params.append(value)

        query_parts = [
            sql.SQL("SELECT *, embedding <=> %s::vector AS distance FROM {}").format(
                sql.Identifier(collection_name)
            )
        ]
        if conditions:
            query_parts.append(sql.SQL(" WHERE "))
            query_parts.append(sql.SQL(" AND ").join(conditions))
        query_parts.append(sql.SQL(" ORDER BY distance ASC LIMIT %s"))
        query = sql.SQL("").join(query_parts)

        per_ids: List[List[str]] = []
        per_docs: List[List[str]] = []
        per_metas: List[List[Dict[str, Any]]] = []
        per_embeds: List[List[List[float]]] = []
        per_dists: List[List[float]] = []

        # One SELECT per query vector so each query gets its own top-N.
        with self._client.cursor() as cur:
            for emb in query_embeddings:
                # numpy rows -> plain lists, which psycopg2 can always adapt.
                query_vec = emb.tolist() if hasattr(emb, "tolist") else emb
                cur.execute(query, [query_vec, *filter_params, n_results])
                rows = cur.fetchall()
                colnames = [desc[0] for desc in cur.description]

                ids_q: List[str] = []
                docs_q: List[str] = []
                metas_q: List[Dict[str, Any]] = []
                embeds_q: List[List[float]] = []
                dists_q: List[float] = []

                for row in rows:
                    row_dict = dict(zip(colnames, row))
                    ids_q.append(str(row_dict.get("id", "")))
                    docs_q.append(row_dict.get("document", ""))

                    if has_metadata_col:
                        meta = row_dict.get("metadata")
                        if isinstance(meta, (str, bytes)):
                            try:
                                parsed_meta = json.loads(meta)
                            except Exception:
                                parsed_meta = {}
                        elif isinstance(meta, dict):
                            parsed_meta = meta
                        else:
                            parsed_meta = {}
                        metas_q.append(parsed_meta)
                    else:
                        metas_q.append(
                            {
                                k: v
                                for k, v in row_dict.items()
                                if k not in ["id", "document", "embedding", "distance"]
                            }
                        )

                    embeds_q.append(self._parse_vector(row_dict.get("embedding", "")))
                    dists_q.append(float(row_dict.get("distance", 0)))

                per_ids.append(ids_q)
                per_docs.append(docs_q)
                per_metas.append(metas_q)
                per_embeds.append(embeds_q)
                per_dists.append(dists_q)

        # Same per-query list-of-lists shape as the other providers.
        return {
            "ids": per_ids,
            "documents": per_docs,
            "metadatas": per_metas,
            "embeddings": per_embeds,
            "distances": per_dists,
        }
    except Exception as e:
        log_error("Query failed: %s", e)
        return None
|
|
730
|
+
|
|
731
|
+
def get_all_items(
    self,
    collection_name: str,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    where: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, Any]]:
    """
    Get items from a collection, optionally filtered and paginated.

    When ``limit`` or ``offset`` is used and the table has an ``id`` column,
    rows are ordered by ``id``. PostgreSQL guarantees no row order otherwise,
    so consecutive pages could repeat or skip rows.

    Args:
        collection_name: Table name
        limit: Maximum number of items; None or 0 means no limit
        offset: Number of items to skip; None or 0 means none
        where: Equality filter ``{key: value}``.
            With a JSONB ``metadata`` column, keys are matched against
            ``metadata->>key``; booleans are compared as JSON
            ``true``/``false`` and other values via ``str()``. Otherwise,
            keys are matched against columns of the same name, and keys that
            are not columns are ignored. A ``None`` value matches a missing
            or NULL value.

    Returns:
        Dict of parallel lists under ``ids``, ``documents``, ``metadatas`` and
        ``embeddings``, or None when not connected or on error (the error is
        logged).
    """
    if not self._client:
        return None
    try:
        schema = self._get_table_schema(collection_name)
        has_metadata_col = "metadata" in schema

        with self._client.cursor() as cur:
            query_parts = [sql.SQL("SELECT * FROM {}").format(sql.Identifier(collection_name))]
            params: List[Any] = []

            conditions = []
            for key, value in (where or {}).items():
                if has_metadata_col and key != "metadata":
                    # Filter on the JSONB metadata column.
                    if value is None:
                        conditions.append(sql.SQL("metadata->>%s IS NULL"))
                        params.append(key)
                    else:
                        # ->> renders JSON booleans as 'true'/'false'.
                        if isinstance(value, bool):
                            text_value = "true" if value else "false"
                        else:
                            text_value = str(value)
                        conditions.append(sql.SQL("metadata->>%s = %s"))
                        params.extend([key, text_value])
                elif key in schema:
                    # Filter on an actual column.
                    if value is None:
                        conditions.append(sql.SQL("{} IS NULL").format(sql.Identifier(key)))
                    else:
                        conditions.append(sql.SQL("{} = %s").format(sql.Identifier(key)))
                        params.append(value)

            if conditions:
                query_parts.append(sql.SQL(" WHERE "))
                query_parts.append(sql.SQL(" AND ").join(conditions))

            # Stable ordering is required for LIMIT/OFFSET paging to be consistent.
            if (limit or offset) and "id" in schema:
                query_parts.append(sql.SQL(" ORDER BY {}").format(sql.Identifier("id")))

            if limit:
                query_parts.append(sql.SQL(" LIMIT %s"))
                params.append(limit)
            if offset:
                query_parts.append(sql.SQL(" OFFSET %s"))
                params.append(offset)

            query = sql.SQL("").join(query_parts)
            cur.execute(query, params if params else None)
            rows = cur.fetchall()
            colnames = [desc[0] for desc in cur.description]

        result_ids = []
        result_docs = []
        result_metas = []
        result_embeds = []

        for row in rows:
            row_dict = dict(zip(colnames, row))
            result_ids.append(str(row_dict.get("id", "")))
            result_docs.append(row_dict.get("document", ""))

            if has_metadata_col:
                meta = row_dict.get("metadata")
                if isinstance(meta, (str, bytes)):
                    try:
                        parsed_meta = json.loads(meta)
                    except Exception:
                        parsed_meta = {}
                elif isinstance(meta, dict):
                    parsed_meta = meta
                else:
                    parsed_meta = {}
                result_metas.append(parsed_meta)
            else:
                # Reconstruct metadata from the non-core columns.
                result_metas.append(
                    {k: v for k, v in row_dict.items() if k not in ["id", "document", "embedding"]}
                )

            result_embeds.append(self._parse_vector(row_dict.get("embedding", "")))

        return {
            "ids": result_ids,
            "documents": result_docs,
            "metadatas": result_metas,
            "embeddings": result_embeds,
        }
    except Exception as e:
        log_error("Failed to get items: %s", e)
        return None
|
|
834
|
+
|
|
835
|
+
def update_items(
    self,
    collection_name: str,
    ids: List[str],
    documents: Optional[List[str]] = None,
    metadatas: Optional[List[Dict[str, Any]]] = None,
    embeddings: Optional[List[List[float]]] = None,
) -> bool:
    """
    Update items in a collection.

    Each item is updated with its own ``UPDATE ... WHERE id = %s`` statement;
    all statements are committed together at the end, or rolled back if any
    step fails.

    When ``embeddings`` is not provided but ``documents`` is, embeddings are
    computed locally for every non-empty document. The model is resolved in
    this order: the per-collection setting from ``SettingsService``, the
    collection's recorded ``embedding_model``, a model chosen from the
    collection's vector dimension, and finally ``DEFAULT_MODEL``. If that
    computation fails, the update still proceeds without touching stored
    embeddings. The number of embeddings generated is stored in
    ``self._last_regenerated_count``.

    Args:
        collection_name: Table name
        ids: IDs to update
        documents: New docs (``documents[i]`` applies to ``ids[i]``)
        metadatas: New metadata. Stored as JSON when the table has a
            ``metadata`` column; otherwise each key is written to the column
            of the same name if such a column exists. The columns ``id``,
            ``document`` and ``embedding`` are never written from metadata.
        embeddings: New embeddings

    Returns:
        True if successful, False otherwise
    """
    if not self._client or not ids:
        return False
    try:
        # Table schema decides how metadata is written (jsonb column vs flattened cols)
        schema = self._get_table_schema(collection_name)
        has_metadata_col = "metadata" in schema

        # Reset regen counter for this update operation
        self._last_regenerated_count = 0
        embeddings_local = embeddings
        if (not embeddings) and documents:
            try:
                from vector_inspector.services.settings_service import SettingsService
                from vector_inspector.core.embedding_utils import (
                    load_embedding_model,
                    get_embedding_model_for_dimension,
                    DEFAULT_MODEL,
                )

                model_name = None
                model_type = None
                loaded_model = None

                # 1) explicit per-collection setting
                model_info = SettingsService().get_embedding_model(
                    self.database, collection_name
                )
                if model_info:
                    model_name = model_info.get("model")
                    model_type = model_info.get("type", "sentence-transformer")

                # Collection info feeds steps 2 and 3; fetch it at most once.
                coll_info: Dict[str, Any] = {}
                if not model_name:
                    coll_info = self.get_collection_info(collection_name) or {}

                # 2) model recorded in collection metadata
                if not model_name and coll_info.get("embedding_model"):
                    model_name = coll_info.get("embedding_model")
                    model_type = coll_info.get("embedding_model_type", "stored")

                # 3) dimension-based fallback, then the default model
                if not model_name:
                    dim = None
                    if coll_info.get("vector_dimension"):
                        try:
                            dim = int(coll_info.get("vector_dimension"))
                        except (TypeError, ValueError, OverflowError):
                            dim = None
                    if dim:
                        loaded_model, model_name, model_type = (
                            get_embedding_model_for_dimension(dim)
                        )
                    else:
                        model_name, model_type = DEFAULT_MODEL

                if loaded_model is None:
                    loaded_model = load_embedding_model(model_name, model_type)

                # Only encode non-empty documents that actually map to an id;
                # indices past len(ids) would overflow embeddings_local below.
                compute_idxs = [
                    i for i, doc in enumerate(documents[: len(ids)]) if doc
                ]
                if compute_idxs:
                    docs_to_compute = [documents[i] for i in compute_idxs]
                    if model_type != "clip":
                        # SentenceTransformer-style batch encode
                        computed = loaded_model.encode(
                            docs_to_compute, show_progress_bar=False
                        ).tolist()
                    else:
                        # CLIP type - encode per document using helper
                        from vector_inspector.core.embedding_utils import encode_text

                        computed = [
                            encode_text(doc, loaded_model, model_type)
                            for doc in docs_to_compute
                        ]
                    embeddings_local = [None] * len(ids)
                    for idx, emb in zip(compute_idxs, computed):
                        embeddings_local[idx] = emb

                    self._last_regenerated_count = len(compute_idxs)
                    log_info(
                        "[PgVectorConnection] Computed %d embeddings for update in %s",
                        self._last_regenerated_count,
                        collection_name,
                    )
            except Exception as e:
                # Best effort: continue updating documents/metadata but leave
                # the stored embeddings untouched.
                log_error("Failed to compute embeddings on update: %s", e)
                embeddings_local = [None] * len(ids)
                self._last_regenerated_count = 0

        # Columns metadata must never overwrite: "id" is the row key used in the
        # WHERE clause, and "document"/"embedding" come from their own arguments
        # (assigning the same column twice in one UPDATE is a SQL error).
        reserved_cols = {"id", "document", "embedding"}

        with self._client.cursor() as cur:
            for i, item_id in enumerate(ids):
                assignments = []
                params: List[Any] = []

                if documents and i < len(documents):
                    assignments.append(sql.SQL("document = %s"))
                    params.append(documents[i])

                if metadatas and i < len(metadatas):
                    meta = metadatas[i]
                    if has_metadata_col:
                        assignments.append(sql.SQL("metadata = %s"))
                        params.append(json.dumps(meta))
                    else:
                        # Map metadata keys to existing, non-reserved columns only
                        for key, value in (meta or {}).items():
                            if key in schema and key not in reserved_cols:
                                assignments.append(
                                    sql.SQL("{} = %s").format(sql.Identifier(key))
                                )
                                params.append(value)

                # embeddings_local is either the caller's embeddings or the
                # locally computed ones (None entries mean "leave unchanged").
                emb_to_use = None
                if embeddings_local and i < len(embeddings_local):
                    emb_to_use = embeddings_local[i]

                if emb_to_use is not None:
                    if not isinstance(emb_to_use, str):
                        # psycopg2 cannot adapt numpy arrays/scalars; use plain floats.
                        emb_to_use = [float(x) for x in emb_to_use]
                    # Cast parameter to pgvector to ensure correct operator typing
                    assignments.append(sql.SQL("embedding = %s::vector"))
                    params.append(emb_to_use)

                if not assignments:
                    continue

                params.append(item_id)
                query = sql.SQL("UPDATE {} SET {} WHERE id = %s").format(
                    sql.Identifier(collection_name), sql.SQL(", ").join(assignments)
                )
                cur.execute(query, params)

        self._client.commit()
        return True
    except Exception as e:
        log_error("Failed to update items: %s", e)
        if self._client:
            try:
                self._client.rollback()
            except Exception as rollback_err:
                log_error("Rollback after failed update also failed: %s", rollback_err)
        return False
|
|
1001
|
+
|
|
1002
|
+
def delete_items(
    self,
    collection_name: str,
    ids: Optional[List[str]] = None,
    where: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Delete items from a collection.

    Args:
        collection_name: Table name
        ids: IDs to delete
        where: Metadata filter (not implemented; only logged and ignored)

    Returns:
        True if successful, False otherwise. A call without ``ids`` always
        returns False because filter-based deletion is not supported.
    """
    if not self._client:
        return False
    if not ids:
        if where:
            log_error(
                "Delete by metadata filter is not supported for PgVector; pass ids instead"
            )
        return False
    try:
        with self._client.cursor() as cur:
            # ``IN %s`` with a tuple renders untyped literals that PostgreSQL
            # coerces to the id column's type (text, integer, uuid, ...).
            # ``= ANY(%s)`` with a list of str renders a text[] array, which
            # fails against non-text id columns ("integer = text").
            cur.execute(
                sql.SQL("DELETE FROM {} WHERE id IN %s").format(
                    sql.Identifier(collection_name)
                ),
                (tuple(ids),),
            )
        self._client.commit()
        return True
    except Exception as e:
        log_error("Failed to delete items: %s", e)
        if self._client:
            try:
                self._client.rollback()
            except Exception as rollback_err:
                log_error("Rollback after failed delete also failed: %s", rollback_err)
        return False
|
|
1036
|
+
|
|
1037
|
+
def get_connection_info(self) -> Dict[str, Any]:
    """
    Describe the current PostgreSQL/pgvector connection.

    Returns:
        Mapping with the provider label, the configured host, port, database
        and user, and the live ``is_connected`` flag under ``"connected"``.
    """
    provider_label = "PgVector/PostgreSQL"
    details: Dict[str, Any] = dict(
        provider=provider_label,
        host=self.host,
        port=self.port,
        database=self.database,
        user=self.user,
    )
    details["connected"] = self.is_connected
    return details
|
|
1052
|
+
|
|
1053
|
+
def _get_table_schema(self, table_name: str) -> Dict[str, str]:
|
|
1054
|
+
"""
|
|
1055
|
+
Get the schema (column names and types) for a table.
|
|
1056
|
+
|
|
1057
|
+
Args:
|
|
1058
|
+
table_name: Name of the table
|
|
1059
|
+
|
|
1060
|
+
Returns:
|
|
1061
|
+
Dict mapping column names to their SQL types
|
|
1062
|
+
"""
|
|
1063
|
+
if not self._client:
|
|
1064
|
+
return {}
|
|
1065
|
+
try:
|
|
1066
|
+
with self._client.cursor() as cur:
|
|
1067
|
+
cur.execute(
|
|
1068
|
+
"""SELECT column_name, data_type, udt_name
|
|
1069
|
+
FROM information_schema.columns
|
|
1070
|
+
WHERE table_name = %s AND table_schema = 'public'
|
|
1071
|
+
ORDER BY ordinal_position""",
|
|
1072
|
+
(table_name,),
|
|
1073
|
+
)
|
|
1074
|
+
schema = {}
|
|
1075
|
+
for row in cur.fetchall():
|
|
1076
|
+
col_name, data_type, udt_name = row
|
|
1077
|
+
# Use udt_name for custom types like vector
|
|
1078
|
+
schema[col_name] = udt_name if data_type == "USER-DEFINED" else data_type
|
|
1079
|
+
return schema
|
|
1080
|
+
except Exception as e:
|
|
1081
|
+
log_error("Failed to get table schema: %s", e)
|
|
1082
|
+
return {}
|
|
1083
|
+
|
|
1084
|
+
def _parse_vector(self, vector_str: Any) -> List[float]:
|
|
1085
|
+
"""
|
|
1086
|
+
Parse pgvector string format to Python list.
|
|
1087
|
+
|
|
1088
|
+
Args:
|
|
1089
|
+
vector_str: Vector in string format from database
|
|
1090
|
+
|
|
1091
|
+
Returns:
|
|
1092
|
+
List of floats
|
|
1093
|
+
"""
|
|
1094
|
+
if isinstance(vector_str, list):
|
|
1095
|
+
return vector_str
|
|
1096
|
+
if isinstance(vector_str, str):
|
|
1097
|
+
# Remove brackets and split by comma
|
|
1098
|
+
vector_str = vector_str.strip("[]")
|
|
1099
|
+
return [float(x) for x in vector_str.split(",")]
|
|
1100
|
+
return []
|