vector-inspector 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1100 @@
1
+ """PgVector/PostgreSQL connection manager."""
2
+
3
+ from typing import Optional, List, Dict, Any
4
+ import json
5
+ import psycopg2
6
+ from psycopg2 import sql
7
+
8
+ ## No need to import register_vector; pgvector extension is enabled at table creation
9
+ from vector_inspector.core.connections.base_connection import VectorDBConnection
10
+ from vector_inspector.core.logging import log_error, log_info
11
+
12
+
13
+ class PgVectorConnection(VectorDBConnection):
14
+ """Manages connection to pgvector/PostgreSQL and provides query interface."""
15
+
16
+ def __init__(
17
+ self,
18
+ host: str = "localhost",
19
+ port: int = 5432,
20
+ database: str = "subtitles",
21
+ user: str = "postgres",
22
+ password: str = "postgres",
23
+ ):
24
+ """
25
+ Initialize PgVector/PostgreSQL connection.
26
+
27
+ Args:
28
+ host: Database host
29
+ port: Database port
30
+ database: Database name
31
+ user: Username
32
+ password: Password
33
+ """
34
+ self.host = host
35
+ self.port = port
36
+ self.database = database
37
+ self.user = user
38
+ self.password = password
39
+ self._client: Optional[psycopg2.extensions.connection] = None
40
+ # Track how many embeddings were regenerated by the last update operation
41
+ self._last_regenerated_count: int = 0
42
+
43
+ def connect(self) -> bool:
44
+ """
45
+ Establish connection to PostgreSQL.
46
+
47
+ Returns:
48
+ True if connection successful, False otherwise
49
+ """
50
+ try:
51
+ self._client = psycopg2.connect(
52
+ host=self.host,
53
+ port=self.port,
54
+ database=self.database,
55
+ user=self.user,
56
+ password=self.password,
57
+ )
58
+ # Use autocommit to avoid leaving the connection in an aborted
59
+ # transaction state after non-fatal errors. This prevents
60
+ # subsequent SELECTs from failing with "current transaction is aborted".
61
+ try:
62
+ self._client.autocommit = True
63
+ except Exception:
64
+ # Some connection wrappers may not support autocommit; ignore
65
+ pass
66
+ # Register pgvector adapter so Python lists can be passed as vector params
67
+ try:
68
+ from pgvector.psycopg2 import register_vector
69
+
70
+ try:
71
+ register_vector(self._client)
72
+ except Exception:
73
+ # Some versions accept connection or cursor; try both
74
+ try:
75
+ register_vector(self._client.cursor())
76
+ except Exception:
77
+ pass
78
+ except Exception:
79
+ pass
80
+ return True
81
+ except Exception as e:
82
+ log_error("Connection failed: %s", e)
83
+ self._client = None
84
+ return False
85
+
86
+ def disconnect(self):
87
+ """Close connection to PostgreSQL."""
88
+ if self._client:
89
+ self._client.close()
90
+ self._client = None
91
+
92
+ @property
93
+ def is_connected(self) -> bool:
94
+ """Check if connected to PostgreSQL."""
95
+ return self._client is not None
96
+
97
+ def list_collections(self) -> List[str]:
98
+ """
99
+ Get list of all vector tables (collections).
100
+
101
+ Returns:
102
+ List of table names containing vector columns
103
+ """
104
+ if not self._client:
105
+ return []
106
+ try:
107
+ with self._client.cursor() as cur:
108
+ cur.execute("""
109
+ SELECT DISTINCT table_name FROM information_schema.columns
110
+ WHERE data_type = 'USER-DEFINED'
111
+ AND udt_name = 'vector'
112
+ AND table_schema = 'public'
113
+ """)
114
+ tables = [row[0] for row in cur.fetchall()]
115
+ return tables
116
+ except Exception as e:
117
+ log_error("Failed to list collections: %s", e)
118
+ return []
119
+
120
+ def list_databases(self) -> List[str]:
121
+ """
122
+ List available databases on the server (non-template databases).
123
+
124
+ Returns:
125
+ List of database names, or empty list on error
126
+ """
127
+ # Prefer using the existing client if available, otherwise open a short-lived connection
128
+ conn = self._client
129
+ tmp_conn = None
130
+ try:
131
+ if not conn:
132
+ # Try connecting to the standard 'postgres' database as a safe default
133
+ tmp_conn = psycopg2.connect(
134
+ host=self.host,
135
+ port=self.port,
136
+ database="postgres",
137
+ user=self.user,
138
+ password=self.password,
139
+ )
140
+ conn = tmp_conn
141
+
142
+ with conn.cursor() as cur:
143
+ cur.execute(
144
+ "SELECT datname FROM pg_database WHERE datistemplate = false ORDER BY datname"
145
+ )
146
+ rows = cur.fetchall()
147
+ return [r[0] for r in rows]
148
+ except Exception as e:
149
+ log_error("Failed to list databases: %s", e)
150
+ return []
151
+ finally:
152
+ if tmp_conn:
153
+ try:
154
+ tmp_conn.close()
155
+ except Exception:
156
+ pass
157
+
158
    def get_collection_info(self, name: str) -> Optional[Dict[str, Any]]:
        """
        Get collection metadata and statistics.

        Counts rows, lists metadata columns, and samples a single row to
        (best-effort) infer the vector dimension and any embedding-model name
        stored in that row's metadata.

        Args:
            name: Table name

        Returns:
            Dictionary with "name", "count" and "metadata_fields"; also
            "vector_dimension" and "embedding_model"/"embedding_model_type"
            when they could be detected. None when disconnected or on error.
        """
        if not self._client:
            return None
        try:
            with self._client.cursor() as cur:
                # Use sql.Identifier to safely quote table name
                cur.execute(sql.SQL("SELECT COUNT(*) FROM {}").format(sql.Identifier(name)))
                result = cur.fetchone()
                count = result[0] if result else 0

                # Get schema to identify metadata columns (exclude id, document, embedding)
                schema = self._get_table_schema(name)
                metadata_fields = [
                    col for col in schema.keys() if col not in ["id", "document", "embedding"]
                ]

                # Try to determine vector dimension and detect stored embedding model from a sample row
                vector_dimension = "Unknown"
                detected_model = None
                detected_model_type = None

                try:
                    # Assumes the table has "embedding" and "metadata" columns;
                    # if not, this query fails and the outer defaults stand.
                    cur.execute(
                        sql.SQL("SELECT embedding, metadata FROM {} LIMIT 1").format(
                            sql.Identifier(name)
                        )
                    )
                    sample = cur.fetchone()
                    if sample:
                        emb_val, meta_val = sample[0], sample[1]
                        # Determine vector dimension from the parsed sample vector
                        try:
                            parsed = self._parse_vector(emb_val)
                            if parsed:
                                vector_dimension = len(parsed)
                        except Exception:
                            vector_dimension = "Unknown"

                        # Try to detect embedding model from metadata (accepts
                        # a JSON string or an already-decoded dict)
                        meta_obj = None
                        if isinstance(meta_val, (str, bytes)):
                            try:
                                meta_obj = json.loads(meta_val)
                            except Exception:
                                meta_obj = None
                        elif isinstance(meta_val, dict):
                            meta_obj = meta_val

                        if meta_obj:
                            # "embedding_model" takes precedence over the
                            # legacy "_embedding_model" key
                            if "embedding_model" in meta_obj:
                                detected_model = meta_obj.get("embedding_model")
                                detected_model_type = meta_obj.get("embedding_model_type", "stored")
                            elif "_embedding_model" in meta_obj:
                                detected_model = meta_obj.get("_embedding_model")
                                detected_model_type = "stored"
                except Exception:
                    # Best-effort; non-fatal
                    pass

                # NOTE: "result" is reused here (was the COUNT row above)
                result = {"name": name, "count": count, "metadata_fields": metadata_fields}
                if vector_dimension != "Unknown":
                    result["vector_dimension"] = vector_dimension
                if detected_model:
                    result["embedding_model"] = detected_model
                    result["embedding_model_type"] = detected_model_type or "stored"

                return result
        except Exception as e:
            log_error("Failed to get collection info: %s", e)
            return None
237
+
238
+ def create_collection(self, name: str, vector_size: int, distance: str = "cosine") -> bool:
239
+ """
240
+ Create a new table for storing vectors.
241
+
242
+ Args:
243
+ name: Table name
244
+ vector_size: Dimension of vectors
245
+ distance: Distance metric (cosine, euclidean, dotproduct, euclidean)
246
+
247
+ Returns:
248
+ True if successful, False otherwise
249
+ """
250
+ if not self._client:
251
+ return False
252
+ try:
253
+ with self._client.cursor() as cur:
254
+ # Ensure pgvector extension is enabled
255
+ cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
256
+
257
+ # Create table with TEXT id to support custom IDs from migrations/backups
258
+ cur.execute(
259
+ sql.SQL(
260
+ "CREATE TABLE {} (id TEXT PRIMARY KEY, document TEXT, metadata JSONB, embedding vector({}))"
261
+ ).format(sql.Identifier(name), sql.Literal(vector_size))
262
+ )
263
+
264
+ # Map distance metric to pgvector index operator
265
+ distance_lower = distance.lower()
266
+ if distance_lower in ["cosine", "cos"]:
267
+ ops_class = "vector_cosine_ops"
268
+ elif distance_lower in ["euclidean", "l2"]:
269
+ ops_class = "vector_l2_ops"
270
+ elif distance_lower in ["dotproduct", "dot", "ip"]:
271
+ ops_class = "vector_ip_ops"
272
+ else:
273
+ # Default to cosine
274
+ ops_class = "vector_cosine_ops"
275
+
276
+ # Create index for vector similarity search
277
+ index_name = f"{name}_embedding_idx"
278
+ cur.execute(
279
+ sql.SQL("CREATE INDEX {} ON {} USING ivfflat (embedding {})").format(
280
+ sql.Identifier(index_name), sql.Identifier(name), sql.SQL(ops_class)
281
+ )
282
+ )
283
+ self._client.commit()
284
+ return True
285
+ except Exception as e:
286
+ log_error("Failed to create collection: %s", e)
287
+ if self._client:
288
+ self._client.rollback()
289
+ return False
290
+
291
    def add_items(
        self,
        collection_name: str,
        documents: List[str],
        metadatas: Optional[List[Dict[str, Any]]] = None,
        ids: Optional[List[str]] = None,
        embeddings: Optional[List[List[float]]] = None,
    ) -> bool:
        """
        Add items to a collection.

        When *embeddings* is omitted, embeddings are computed locally by
        resolving a model in this order: per-collection settings, a model name
        stored in the collection's metadata, then a dimension-based fallback
        (DEFAULT_MODEL when the dimension cannot be determined).

        Args:
            collection_name: Table name
            documents: Document texts
            metadatas: Metadata for each document (optional)
            ids: IDs for each document (required for proper migration support)
            embeddings: Pre-computed embeddings

        Returns:
            True if successful, False otherwise
        """
        if not self._client:
            return False

        # If embeddings weren't provided, try to compute them using configured/default model
        if not embeddings:
            try:
                # Imported lazily so heavy embedding dependencies load only when needed
                from vector_inspector.services.settings_service import SettingsService
                from vector_inspector.core.embedding_utils import (
                    load_embedding_model,
                    get_embedding_model_for_dimension,
                    DEFAULT_MODEL,
                    encode_text,
                )

                model_name = None
                model_type = None

                # 1) settings
                settings = SettingsService()
                model_info = settings.get_embedding_model(self.database, collection_name)
                if model_info:
                    model_name = model_info.get("model")
                    model_type = model_info.get("type", "sentence-transformer")

                # 2) collection metadata
                coll_info = None
                if not model_name:
                    coll_info = self.get_collection_info(collection_name)
                    if coll_info and coll_info.get("embedding_model"):
                        model_name = coll_info.get("embedding_model")
                        model_type = coll_info.get("embedding_model_type", "stored")

                # 3) dimension-based fallback
                loaded_model = None
                if not model_name:
                    # Try to get vector dimension
                    dim = None
                    if not coll_info:
                        coll_info = self.get_collection_info(collection_name)
                    if coll_info and coll_info.get("vector_dimension"):
                        try:
                            dim = int(coll_info.get("vector_dimension"))
                        except Exception:
                            dim = None
                    if dim:
                        loaded_model, model_name, model_type = get_embedding_model_for_dimension(
                            dim
                        )
                    else:
                        model_name, model_type = DEFAULT_MODEL

                # Load model (unless the dimension lookup already returned one)
                if not loaded_model:
                    loaded_model = load_embedding_model(model_name, model_type)

                # Compute embeddings for all documents
                if model_type != "clip":
                    embeddings = loaded_model.encode(documents, show_progress_bar=False).tolist()
                else:
                    embeddings = [encode_text(d, loaded_model, model_type) for d in documents]
            except Exception as e:
                log_error("Failed to compute embeddings on add: %s", e)
                return False
        try:
            import uuid

            # Get table schema to determine column structure
            schema = self._get_table_schema(collection_name)
            has_metadata_col = "metadata" in schema

            with self._client.cursor() as cur:
                # One INSERT per embedding; iteration is driven by embeddings,
                # so items without a computed embedding are not inserted.
                for i, emb in enumerate(embeddings):
                    # Use provided ID or generate a UUID
                    item_id = ids[i] if ids and i < len(ids) else str(uuid.uuid4())
                    doc = documents[i] if i < len(documents) else None
                    metadata = metadatas[i] if metadatas and i < len(metadatas) else {}
                    # Build insert statement based on schema
                    if has_metadata_col:
                        # Use JSONB metadata column
                        metadata_json = json.dumps(metadata) if metadata else None
                        cur.execute(
                            sql.SQL(
                                "INSERT INTO {} (id, document, metadata, embedding) VALUES (%s, %s, %s, %s)"
                            ).format(sql.Identifier(collection_name)),
                            (item_id, doc, metadata_json, emb),
                        )
                    else:
                        # Map metadata to specific columns
                        columns = ["id", "embedding"]
                        values = [item_id, emb]

                        if "document" in schema and doc is not None:
                            columns.append("document")
                            values.append(doc)

                        # Add metadata fields that exist as columns; others are
                        # silently dropped
                        if metadata:
                            for key, value in metadata.items():
                                if key in schema:
                                    columns.append(key)
                                    values.append(value)

                        placeholders = ", ".join(["%s"] * len(values))
                        cur.execute(
                            sql.SQL("INSERT INTO {} ({}) VALUES ({})").format(
                                sql.Identifier(collection_name),
                                sql.SQL(", ").join(sql.Identifier(c) for c in columns),
                                sql.SQL(placeholders),
                            ),
                            values,
                        )
            self._client.commit()
            return True
        except Exception as e:
            log_error("Failed to add items: %s", e)
            if self._client:
                self._client.rollback()
            return False
430
+
431
+ def get_items(self, name: str, ids: List[str]) -> Dict[str, Any]:
432
+ """
433
+ Retrieve items by IDs.
434
+
435
+ Args:
436
+ name: Table name
437
+ ids: List of IDs
438
+
439
+ Returns:
440
+ Dict with 'documents', 'metadatas', 'embeddings'
441
+ """
442
+ if not self._client:
443
+ return {}
444
+ try:
445
+ schema = self._get_table_schema(name)
446
+ has_metadata_col = "metadata" in schema
447
+
448
+ with self._client.cursor() as cur:
449
+ # Select all columns
450
+ cur.execute(
451
+ sql.SQL("SELECT * FROM {} WHERE id = ANY(%s)").format(sql.Identifier(name)),
452
+ (ids,),
453
+ )
454
+ rows = cur.fetchall()
455
+ colnames = [desc[0] for desc in cur.description]
456
+
457
+ # Build results
458
+ result_ids = []
459
+ result_docs = []
460
+ result_metas = []
461
+ result_embeds = []
462
+
463
+ for row in rows:
464
+ row_dict = dict(zip(colnames, row))
465
+ result_ids.append(str(row_dict.get("id", "")))
466
+ result_docs.append(row_dict.get("document", ""))
467
+
468
+ # Handle metadata
469
+ if has_metadata_col:
470
+ meta = row_dict.get("metadata")
471
+ if isinstance(meta, (str, bytes)):
472
+ try:
473
+ parsed_meta = json.loads(meta)
474
+ except Exception:
475
+ parsed_meta = {}
476
+ elif isinstance(meta, dict):
477
+ parsed_meta = meta
478
+ else:
479
+ parsed_meta = {}
480
+ result_metas.append(parsed_meta)
481
+ else:
482
+ # Reconstruct metadata from columns
483
+ metadata = {
484
+ k: v
485
+ for k, v in row_dict.items()
486
+ if k not in ["id", "document", "embedding"]
487
+ }
488
+ result_metas.append(metadata)
489
+
490
+ # Handle embedding
491
+ result_embeds.append(self._parse_vector(row_dict.get("embedding", "")))
492
+
493
+ return {
494
+ "ids": result_ids,
495
+ "documents": result_docs,
496
+ "metadatas": result_metas,
497
+ "embeddings": result_embeds,
498
+ }
499
+ except Exception as e:
500
+ log_error("Failed to get items: %s", e)
501
+ return {}
502
+
503
+ def delete_collection(self, name: str) -> bool:
504
+ """
505
+ Delete a table (collection).
506
+
507
+ Args:
508
+ name: Table name
509
+
510
+ Returns:
511
+ True if successful, False otherwise
512
+ """
513
+ if not self._client:
514
+ return False
515
+ try:
516
+ with self._client.cursor() as cur:
517
+ cur.execute(sql.SQL("DROP TABLE IF EXISTS {} CASCADE").format(sql.Identifier(name)))
518
+ self._client.commit()
519
+ return True
520
+ except Exception as e:
521
+ log_error("Failed to delete collection: %s", e)
522
+ if self._client:
523
+ self._client.rollback()
524
+ return False
525
+
526
+ def count_collection(self, name: str) -> int:
527
+ """
528
+ Return the number of items in the collection.
529
+
530
+ Args:
531
+ name: Table name
532
+
533
+ Returns:
534
+ Number of items
535
+ """
536
+ if not self._client:
537
+ return 0
538
+ try:
539
+ with self._client.cursor() as cur:
540
+ cur.execute(sql.SQL("SELECT COUNT(*) FROM {}").format(sql.Identifier(name)))
541
+ result = cur.fetchone()
542
+ count = result[0] if result else 0
543
+ return count
544
+ except Exception as e:
545
+ log_error("Failed to count collection: %s", e)
546
+ return 0
547
+
548
    def query_collection(
        self,
        collection_name: str,
        query_texts: Optional[List[str]] = None,
        query_embeddings: Optional[List[List[float]]] = None,
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
        where_document: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Query a collection for similar vectors.

        When only *query_texts* is given, embeddings are computed locally by
        resolving a model from settings, collection metadata, or a
        dimension-based fallback (DEFAULT_MODEL as a last resort).

        Args:
            collection_name: Table name
            query_texts: Texts to embed and search with (used when
                query_embeddings is not provided)
            query_embeddings: Embedding vectors to search
            n_results: Number of results to return per query
            where: Metadata filter (dict of column:value pairs)
            where_document: Document filter (not implemented; ignored)

        Returns:
            Per-query list-of-lists dict with "ids", "documents", "metadatas",
            "embeddings" and "distances", or None on error
        """
        if not self._client:
            return None

        # If caller provided query texts (not embeddings), compute embeddings using configured model
        if (not query_embeddings) and query_texts:
            try:
                from vector_inspector.services.settings_service import SettingsService
                from vector_inspector.core.embedding_utils import (
                    load_embedding_model,
                    get_embedding_model_for_dimension,
                    DEFAULT_MODEL,
                    encode_text,
                )

                model_name = None
                model_type = None

                # 1) settings
                settings = SettingsService()
                model_info = settings.get_embedding_model(self.database, collection_name)
                if model_info:
                    model_name = model_info.get("model")
                    model_type = model_info.get("type", "sentence-transformer")

                # 2) collection metadata
                if not model_name:
                    coll_info = self.get_collection_info(collection_name)
                    if coll_info and coll_info.get("embedding_model"):
                        model_name = coll_info.get("embedding_model")
                        model_type = coll_info.get("embedding_model_type", "stored")

                # 3) dimension-based fallback
                loaded_model = None
                if not model_name:
                    dim = None
                    coll_info = self.get_collection_info(collection_name)
                    if coll_info and coll_info.get("vector_dimension"):
                        try:
                            dim = int(coll_info.get("vector_dimension"))
                        except Exception:
                            dim = None
                    if dim:
                        loaded_model, model_name, model_type = get_embedding_model_for_dimension(
                            dim
                        )
                    else:
                        model_name, model_type = DEFAULT_MODEL

                if not loaded_model:
                    loaded_model = load_embedding_model(model_name, model_type)

                # Compute embeddings for the provided query_texts (use helper for CLIP)
                if model_type != "clip":
                    computed = loaded_model.encode(query_texts, show_progress_bar=False).tolist()
                else:
                    computed = [encode_text(t, loaded_model, model_type) for t in query_texts]

                query_embeddings = computed
            except Exception as e:
                log_error("Failed to compute query embeddings: %s", e)
                return None
        try:
            schema = self._get_table_schema(collection_name)
            has_metadata_col = "metadata" in schema

            # For each query embedding, run a separate SELECT ordered by distance
            # so callers receive the top-N results per query (matching SearchView expectations).
            with self._client.cursor() as cur:
                # Prepare containers for per-query results
                per_ids: List[List[str]] = []
                per_docs: List[List[str]] = []
                per_metas: List[List[Dict[str, Any]]] = []
                per_embeds: List[List[List[float]]] = []
                per_dists: List[List[float]] = []

                for emb in query_embeddings:
                    # Build base query for this single embedding.
                    # "<=>" is pgvector's cosine-distance operator.
                    query_parts = [
                        sql.SQL("SELECT *, embedding <=> %s::vector AS distance FROM {}").format(
                            sql.Identifier(collection_name)
                        )
                    ]
                    params = [emb]

                    # Add WHERE clause for filtering
                    if where:
                        conditions = []
                        for key, value in where.items():
                            if has_metadata_col and key != "metadata":
                                # Compare against the text value of metadata[key]
                                conditions.append(sql.SQL("metadata->>%s = %s"))
                                params.extend([key, str(value)])
                            elif key in schema:
                                conditions.append(sql.SQL("{} = %s").format(sql.Identifier(key)))
                                params.append(value)

                        if conditions:
                            query_parts.append(sql.SQL(" WHERE "))
                            query_parts.append(sql.SQL(" AND ").join(conditions))

                    query_parts.append(sql.SQL(" ORDER BY distance ASC LIMIT %s"))
                    params.append(n_results)

                    query = sql.SQL("").join(query_parts)
                    cur.execute(query, params)
                    rows = cur.fetchall()
                    colnames = [desc[0] for desc in cur.description]

                    # Build per-query result lists
                    ids_q: List[str] = []
                    docs_q: List[str] = []
                    metas_q: List[Dict[str, Any]] = []
                    embeds_q: List[List[float]] = []
                    dists_q: List[float] = []

                    for row in rows:
                        row_dict = dict(zip(colnames, row))
                        ids_q.append(str(row_dict.get("id", "")))
                        docs_q.append(row_dict.get("document", ""))

                        # Handle metadata (JSONB column may arrive decoded or as text)
                        if has_metadata_col:
                            meta = row_dict.get("metadata")
                            if isinstance(meta, (str, bytes)):
                                try:
                                    parsed_meta = json.loads(meta)
                                except Exception:
                                    parsed_meta = {}
                            elif isinstance(meta, dict):
                                parsed_meta = meta
                            else:
                                parsed_meta = {}
                            metas_q.append(parsed_meta)
                        else:
                            # Reconstruct metadata from flattened columns
                            metadata = {
                                k: v
                                for k, v in row_dict.items()
                                if k not in ["id", "document", "embedding", "distance"]
                            }
                            metas_q.append(metadata)

                        embeds_q.append(self._parse_vector(row_dict.get("embedding", "")))
                        dists_q.append(float(row_dict.get("distance", 0)))

                    per_ids.append(ids_q)
                    per_docs.append(docs_q)
                    per_metas.append(metas_q)
                    per_embeds.append(embeds_q)
                    per_dists.append(dists_q)

            # Return results in the same per-query list-of-lists format as other providers
            return {
                "ids": per_ids,
                "documents": per_docs,
                "metadatas": per_metas,
                "embeddings": per_embeds,
                "distances": per_dists,
            }
        except Exception as e:
            log_error("Query failed: %s", e)
            return None
730
+
731
    def get_all_items(
        self,
        collection_name: str,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        where: Optional[Dict[str, Any]] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Get all items from a collection.

        Args:
            collection_name: Table name
            limit: Max items (NOTE: 0 is falsy and is treated as "no limit")
            offset: Offset (NOTE: 0 is falsy and is treated as "no offset")
            where: Metadata filter (dict of column:value pairs)

        Returns:
            Dict with "ids", "documents", "metadatas" and "embeddings" lists,
            or None when disconnected or on error
        """
        if not self._client:
            return None
        try:
            schema = self._get_table_schema(collection_name)
            has_metadata_col = "metadata" in schema

            with self._client.cursor() as cur:
                query_parts = [sql.SQL("SELECT * FROM {}").format(sql.Identifier(collection_name))]
                params = []

                # Add WHERE clause for filtering
                if where:
                    conditions = []
                    for key, value in where.items():
                        if has_metadata_col and key != "metadata":
                            # Filter on JSONB metadata column (text comparison)
                            conditions.append(sql.SQL("metadata->>%s = %s"))
                            params.extend([key, str(value)])
                        elif key in schema:
                            # Filter on actual column
                            conditions.append(sql.SQL("{} = %s").format(sql.Identifier(key)))
                            params.append(value)

                    if conditions:
                        query_parts.append(sql.SQL(" WHERE "))
                        query_parts.append(sql.SQL(" AND ").join(conditions))

                if limit:
                    query_parts.append(sql.SQL(" LIMIT %s"))
                    params.append(limit)
                if offset:
                    query_parts.append(sql.SQL(" OFFSET %s"))
                    params.append(offset)

                query = sql.SQL("").join(query_parts)
                cur.execute(query, params if params else None)
                rows = cur.fetchall()
                colnames = [desc[0] for desc in cur.description]

                # Build results
                result_ids = []
                result_docs = []
                result_metas = []
                result_embeds = []

                for row in rows:
                    row_dict = dict(zip(colnames, row))
                    result_ids.append(str(row_dict.get("id", "")))
                    result_docs.append(row_dict.get("document", ""))

                    # Handle metadata (JSONB may arrive decoded or as text)
                    if has_metadata_col:
                        meta = row_dict.get("metadata")
                        if isinstance(meta, (str, bytes)):
                            try:
                                parsed_meta = json.loads(meta)
                            except Exception:
                                parsed_meta = {}
                        elif isinstance(meta, dict):
                            parsed_meta = meta
                        else:
                            parsed_meta = {}
                        result_metas.append(parsed_meta)
                    else:
                        # Reconstruct metadata from columns
                        metadata = {
                            k: v
                            for k, v in row_dict.items()
                            if k not in ["id", "document", "embedding"]
                        }
                        result_metas.append(metadata)

                    # Handle embedding
                    result_embeds.append(self._parse_vector(row_dict.get("embedding", "")))

                return {
                    "ids": result_ids,
                    "documents": result_docs,
                    "metadatas": result_metas,
                    "embeddings": result_embeds,
                }
        except Exception as e:
            log_error("Failed to get items: %s", e)
            return None
834
+
835
    def update_items(
        self,
        collection_name: str,
        ids: List[str],
        documents: Optional[List[str]] = None,
        metadatas: Optional[List[Dict[str, Any]]] = None,
        embeddings: Optional[List[List[float]]] = None,
    ) -> bool:
        """
        Update items in a collection.

        When *documents* are given without *embeddings*, embeddings are
        regenerated locally for the non-empty documents and the count of
        regenerated vectors is recorded in ``self._last_regenerated_count``.

        Args:
            collection_name: Table name
            ids: IDs to update
            documents: New docs
            metadatas: New metadata
            embeddings: New embeddings

        Returns:
            True if successful, False otherwise
        """
        if not self._client or not ids:
            return False
        try:
            # Get table schema to decide how to update metadata (jsonb column vs flattened cols)
            schema = self._get_table_schema(collection_name)
            has_metadata_col = "metadata" in schema

            # If embeddings are not provided but documents were, compute embeddings
            embeddings_local = embeddings
            # Reset regen counter for this update operation
            self._last_regenerated_count = 0
            if (not embeddings) and documents:
                try:
                    # Resolve model for this collection: prefer settings -> collection metadata -> dimension-based
                    from vector_inspector.services.settings_service import SettingsService
                    from vector_inspector.core.embedding_utils import (
                        load_embedding_model,
                        get_embedding_model_for_dimension,
                        DEFAULT_MODEL,
                    )

                    model_name = None
                    model_type = None

                    # 1) settings
                    settings = SettingsService()
                    model_info = settings.get_embedding_model(self.database, collection_name)
                    if model_info:
                        model_name = model_info.get("model")
                        model_type = model_info.get("type", "sentence-transformer")

                    # 2) collection metadata
                    if not model_name:
                        coll_info = self.get_collection_info(collection_name)
                        if coll_info and coll_info.get("embedding_model"):
                            model_name = coll_info.get("embedding_model")
                            model_type = coll_info.get("embedding_model_type", "stored")

                    # 3) dimension-based fallback
                    loaded_model = None
                    if not model_name:
                        # Try to get vector dimension
                        dim = None
                        coll_info = self.get_collection_info(collection_name)
                        if coll_info and coll_info.get("vector_dimension"):
                            try:
                                dim = int(coll_info.get("vector_dimension"))
                            except Exception:
                                dim = None
                        if dim:
                            loaded_model, model_name, model_type = (
                                get_embedding_model_for_dimension(dim)
                            )
                        else:
                            # Use default model
                            model_name, model_type = DEFAULT_MODEL

                    # Load model if not already loaded
                    if not loaded_model:
                        loaded_model = load_embedding_model(model_name, model_type)

                    # Compute embeddings only for documents that are present (non-empty)
                    compute_idxs = [i for i, d in enumerate(documents) if d]
                    if compute_idxs:
                        docs_to_compute = [documents[i] for i in compute_idxs]
                        # Use SentenceTransformer batch encode when possible
                        if model_type != "clip":
                            computed = loaded_model.encode(
                                docs_to_compute, show_progress_bar=False
                            ).tolist()
                        else:
                            # CLIP type - encode per document using helper
                            from vector_inspector.core.embedding_utils import encode_text

                            computed = [
                                encode_text(d, loaded_model, model_type) for d in docs_to_compute
                            ]
                        # Scatter computed vectors back to their item positions;
                        # items with empty documents keep None (no embedding update).
                        embeddings_local = [None] * len(ids)
                        for idx, emb in zip(compute_idxs, computed):
                            embeddings_local[idx] = emb
                        # Record how many embeddings we generated
                        try:
                            self._last_regenerated_count = len(compute_idxs)
                            log_info(
                                "[PgVectorConnection] Computed %d embeddings for update in %s",
                                self._last_regenerated_count,
                                collection_name,
                            )
                        except Exception:
                            pass
                except Exception as e:
                    log_error("Failed to compute embeddings on update: %s", e)
                    embeddings_local = [None] * len(ids)
                    self._last_regenerated_count = 0

            with self._client.cursor() as cur:
                # Issue one UPDATE per item, only for the fields provided.
                for i, item_id in enumerate(ids):
                    updates = []
                    params = []

                    if documents and i < len(documents):
                        updates.append(sql.SQL("document = %s"))
                        params.append(documents[i])

                    # Handle metadata update depending on schema
                    if metadatas and i < len(metadatas):
                        meta = metadatas[i]
                        if has_metadata_col:
                            updates.append(sql.SQL("metadata = %s"))
                            params.append(json.dumps(meta))
                        else:
                            # Map metadata keys to existing columns only
                            for key, value in meta.items():
                                if key in schema:
                                    updates.append(sql.SQL("{} = %s").format(sql.Identifier(key)))
                                    params.append(value)

                    # Use provided embeddings if present, otherwise use locally computed embedding
                    emb_to_use = None
                    if embeddings and i < len(embeddings):
                        emb_to_use = embeddings[i]
                        # caller provided embeddings -> no regeneration
                        self._last_regenerated_count = 0
                    elif embeddings_local and i < len(embeddings_local):
                        emb_to_use = embeddings_local[i]

                    if emb_to_use is not None:
                        # Cast parameter to pgvector to ensure correct operator typing
                        updates.append(sql.SQL("embedding = %s::vector"))
                        params.append(emb_to_use)

                    if updates:
                        params.append(item_id)
                        query = sql.SQL("UPDATE {} SET {} WHERE id = %s").format(
                            sql.Identifier(collection_name), sql.SQL(", ").join(updates)
                        )
                        cur.execute(query, params)

            self._client.commit()
            return True
        except Exception as e:
            log_error("Failed to update items: %s", e)
            if self._client:
                self._client.rollback()
            return False
1001
+
1002
+ def delete_items(
1003
+ self,
1004
+ collection_name: str,
1005
+ ids: Optional[List[str]] = None,
1006
+ where: Optional[Dict[str, Any]] = None,
1007
+ ) -> bool:
1008
+ """
1009
+ Delete items from a collection.
1010
+
1011
+ Args:
1012
+ collection_name: Table name
1013
+ ids: IDs to delete
1014
+ where: Metadata filter (not implemented)
1015
+
1016
+ Returns:
1017
+ True if successful, False otherwise
1018
+ """
1019
+ if not self._client or not ids:
1020
+ return False
1021
+ try:
1022
+ with self._client.cursor() as cur:
1023
+ cur.execute(
1024
+ sql.SQL("DELETE FROM {} WHERE id = ANY(%s)").format(
1025
+ sql.Identifier(collection_name)
1026
+ ),
1027
+ (ids,),
1028
+ )
1029
+ self._client.commit()
1030
+ return True
1031
+ except Exception as e:
1032
+ log_error("Failed to delete items: %s", e)
1033
+ if self._client:
1034
+ self._client.rollback()
1035
+ return False
1036
+
1037
+ def get_connection_info(self) -> Dict[str, Any]:
1038
+ """
1039
+ Get information about the current connection.
1040
+
1041
+ Returns:
1042
+ Dictionary with connection details
1043
+ """
1044
+ return {
1045
+ "provider": "PgVector/PostgreSQL",
1046
+ "host": self.host,
1047
+ "port": self.port,
1048
+ "database": self.database,
1049
+ "user": self.user,
1050
+ "connected": self.is_connected,
1051
+ }
1052
+
1053
+ def _get_table_schema(self, table_name: str) -> Dict[str, str]:
1054
+ """
1055
+ Get the schema (column names and types) for a table.
1056
+
1057
+ Args:
1058
+ table_name: Name of the table
1059
+
1060
+ Returns:
1061
+ Dict mapping column names to their SQL types
1062
+ """
1063
+ if not self._client:
1064
+ return {}
1065
+ try:
1066
+ with self._client.cursor() as cur:
1067
+ cur.execute(
1068
+ """SELECT column_name, data_type, udt_name
1069
+ FROM information_schema.columns
1070
+ WHERE table_name = %s AND table_schema = 'public'
1071
+ ORDER BY ordinal_position""",
1072
+ (table_name,),
1073
+ )
1074
+ schema = {}
1075
+ for row in cur.fetchall():
1076
+ col_name, data_type, udt_name = row
1077
+ # Use udt_name for custom types like vector
1078
+ schema[col_name] = udt_name if data_type == "USER-DEFINED" else data_type
1079
+ return schema
1080
+ except Exception as e:
1081
+ log_error("Failed to get table schema: %s", e)
1082
+ return {}
1083
+
1084
+ def _parse_vector(self, vector_str: Any) -> List[float]:
1085
+ """
1086
+ Parse pgvector string format to Python list.
1087
+
1088
+ Args:
1089
+ vector_str: Vector in string format from database
1090
+
1091
+ Returns:
1092
+ List of floats
1093
+ """
1094
+ if isinstance(vector_str, list):
1095
+ return vector_str
1096
+ if isinstance(vector_str, str):
1097
+ # Remove brackets and split by comma
1098
+ vector_str = vector_str.strip("[]")
1099
+ return [float(x) for x in vector_str.split(",")]
1100
+ return []