quantara 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1024 @@
1
+ import os
2
+ import json
3
+ import pickle
4
+ import dataclasses
5
+
6
+ from typing import (
7
+ Any,
8
+ Optional,
9
+ Dict,
10
+ List,
11
+ Tuple,
12
+ Literal
13
+ )
14
+ from dataclasses import dataclass
15
+
16
+ from quantara.utils.utils import (
17
+ get_uuid,
18
+ cosine_similarity,
19
+ dot_similarity,
20
+ euclidean_distance
21
+ )
22
+
23
+ from quantara.errors.errors import (
24
+ CollectionNotFoundError,
25
+ CollectionAlreadyExistsError,
26
+ IndexNotFoundError,
27
+ EmbeddingDimensionError
28
+ )
29
+
30
+ from quantara.indexes.bruteforce import BruteForceIndex
31
+ from quantara.indexes.registry import INDEX_REGISTRY
32
+
33
+ from quantara.database.models import Record, Config
34
+
35
+
class Database:
    """
    File-backed vector database made of named collections.

    State lives in ``self._records``: a dict mapping each collection name to a
    dict of ``record id -> Record``, plus a reserved ``"_config"`` entry that
    holds settings (``dimensions``, ``auto_dim``, ``auto_persist`` and the
    per-collection index-type map ``_collection_indexes``). Every collection
    also has an in-memory index object in ``self._indexes``. Only
    ``_records`` is pickled to ``<cwd>/<db_name>.db``; indexes are rebuilt
    from the records on load.
    """

    # Index classes available to collections, keyed by type name, plus the
    # "default" alias for the brute-force index.
    # NOTE(review): if get_indexes_dict() returns the registry's own dict
    # rather than a copy, the assignment below also adds "default" to
    # INDEX_REGISTRY itself — confirm that is intended.
    INDEXES = INDEX_REGISTRY.get_indexes_dict()
    INDEXES["default"] = BruteForceIndex

    def __init__(
        self,
        db_name: str,
        dimensions: int | None = None,
        auto_dim: bool = True,
        auto_persist: bool = True,
        **kwargs
    ):
        """
        Open (or create) the database file ``<cwd>/<db_name>.db``.

        Args:
            db_name: Database file name; ".db" is appended if missing.
            dimensions: Required embedding length. When None it is taken from
                the stored config, or inferred from the first inserted vector
                when ``auto_dim`` is True.
            auto_dim: Infer ``dimensions`` from the first inserted vector.
            auto_persist: Write the database to disk after every mutation.
            **kwargs:
                index_paths: dict of collection name -> ".index" file whose
                    contents replace that collection's rebuilt index on load.
                index_collection / index_path: legacy single-collection form
                    of ``index_paths``; ignored if ``index_paths`` already
                    names that collection.

        Raises:
            ValueError: if ``dimensions`` is not a positive int, or an index
                path is not a string ending in ".index".
            RuntimeError: if an existing database file cannot be loaded.
        """
        if dimensions is not None and (
            not isinstance(dimensions, int) or dimensions <= 0
        ):
            raise ValueError(
                f"dimensions must be a positive int or None, got {dimensions!r}."
            )

        self.db_name = db_name
        self.dimensions = dimensions
        self.auto_dim = auto_dim
        self.auto_persist = auto_persist

        # Copy the caller's dict so folding in the legacy kwargs below never
        # mutates an object we do not own.
        provided: dict[str, str] = dict(kwargs.get("index_paths") or {})

        # Backward-compat shim for the old single-collection kwargs.
        legacy_col = kwargs.get("index_collection", None)
        legacy_path = kwargs.get("index_path", None)
        if legacy_col and legacy_path and legacy_col not in provided:
            provided[legacy_col] = legacy_path

        for col, p in provided.items():
            if not isinstance(p, str) or not p.endswith(".index"):
                raise ValueError(
                    f"index_paths['{col}'] = '{p}' does not end with '.index'."
                )

        self.provided_index_paths: dict[str, str] = provided

        if not self.db_name.endswith(".db"):
            self.db_name += ".db"

        self._path = os.path.join(os.getcwd(), self.db_name)

        self._records: dict[str, dict[str, Record | Config]] = self._default_records()
        self._indexes: dict[str, Any] = {}

        self._load_doc()

    # ==========================================================
    # Internal Helpers
    # ==========================================================

    def _validate_collection(self, collection: str) -> None:
        """Raise CollectionNotFoundError if the given collection does not exist in _records."""
        if collection not in self._records:
            raise CollectionNotFoundError(
                f"Collection '{collection}' not found."
            )

    def _validate_doc(self, collection: str, id: str) -> None:
        """Raise CollectionNotFoundError or IndexNotFoundError if the collection or record id is missing."""
        self._validate_collection(collection)

        if id not in self._records[collection]:
            raise IndexNotFoundError(
                f"Id '{id}' not found in collection '{collection}'."
            )

    def _validate_index_algo(self, index: str) -> None:
        """Raise ValueError if the given index type string is not present in the index registry."""
        if index not in self.INDEXES:
            raise ValueError(
                f"Index '{index}' not found in registry. "
                f"Available indexes: {list(self.INDEXES.keys())}"
            )

    def _reject_config_collection(self, collection: str, action: str) -> None:
        """
        Raise ValueError if ``collection`` is the reserved "_config" entry.

        "_config" lives in _records but holds settings, not records, and has
        no index object, so collection operations on it would corrupt state.
        """
        if collection == "_config":
            raise ValueError(f"Cannot {action} config collection.")

    def _persist_if_enabled(self) -> None:
        """Write the database to disk when auto_persist is enabled."""
        if self.auto_persist:
            self.persist_doc()

    def _reset_to_defaults(self) -> None:
        """Replace all records with the default structure and rebuild indexes."""
        self._records = self._default_records()
        self._rebuild_indexes()

    def _index_type_for(self, collection: str) -> str:
        """
        Return the index type recorded for ``collection`` in
        _config['_collection_indexes'], or "default" when it is not recorded
        or no longer present in the registry.
        """
        col_indexes = self._records.get("_config", {}).get("_collection_indexes", {})
        index_type = col_indexes.get(collection, "default")
        return index_type if index_type in self.INDEXES else "default"

    # ----------------------------------------------------------
    # Record field validation
    #
    # These mirror the checks in _validate_index_schema so that nothing is
    # stored that would make the database file fail to load later.
    # ----------------------------------------------------------

    @staticmethod
    def _validate_name(name: Any) -> None:
        """Raise ValueError unless ``name`` is a non-empty string."""
        if not isinstance(name, str) or not name:
            raise ValueError(f"'name' must be a non-empty string, got {name!r}.")

    @staticmethod
    def _normalize_vector(vector: Any) -> list[float]:
        """
        Return ``vector`` as a new, non-empty list of Python floats.

        Converting to plain floats lets the vector pass the loader's
        int/float check, including when numpy scalars are given. Copying
        means later mutation of the caller's list cannot change stored data
        behind the index's back.

        Raises:
            ValueError: if ``vector`` is a string, not iterable, empty, or
                contains non-numeric values.
        """
        if isinstance(vector, (str, bytes)):
            raise ValueError("'vector' must be a sequence of numbers, not a string.")
        try:
            items = list(vector)
        except TypeError:
            raise ValueError(
                f"'vector' must be a sequence of numbers, "
                f"got {type(vector).__name__}."
            ) from None
        if not items:
            raise ValueError("'vector' must not be empty.")

        normalized: list[float] = []
        for value in items:
            if isinstance(value, (str, bytes)):
                raise ValueError("'vector' contains non-numeric values.")
            try:
                normalized.append(float(value))
            except (TypeError, ValueError):
                raise ValueError("'vector' contains non-numeric values.") from None
        return normalized

    @staticmethod
    def _copy_metadata(metadata: Any) -> dict[str, Any]:
        """
        Return a shallow copy of ``metadata``. None (or any other falsy
        value) maps to an empty dict.

        Raises:
            ValueError: if ``metadata`` is truthy but not a dict; the loader
                rejects non-dict metadata.
        """
        if not metadata:
            return {}
        if not isinstance(metadata, dict):
            raise ValueError(
                f"'metadata' must be a dict, got {type(metadata).__name__}."
            )
        return dict(metadata)

    def _expected_dimension(self, length: int) -> int:
        """
        Return the dimension a newly inserted vector must have.

        This is self.dimensions when set. Otherwise it is ``length`` (the
        candidate vector's length) when auto_dim is enabled.

        Raises:
            EmbeddingDimensionError: if dimensions is unset and auto_dim is off.
        """
        if self.dimensions is not None:
            return self.dimensions
        if self.auto_dim:
            return length
        raise EmbeddingDimensionError(
            "Parameter `dimensions` cannot be None or undefined when "
            "inserting into the vector database."
        )

    @staticmethod
    def _check_dimension(vector: list[float], expected: int) -> None:
        """Raise EmbeddingDimensionError if len(vector) != expected."""
        if len(vector) != expected:
            raise EmbeddingDimensionError(
                f"Expected embedding dimension {expected}, "
                f"got {len(vector)}."
            )

    def _commit_dimension(self, dim: int) -> None:
        """Record ``dim`` as the database dimension (instance and _config) if it is not set yet."""
        if self.dimensions is None:
            self.dimensions = dim
            self._records["_config"]["dimensions"] = dim

    # ==========================================================
    # Config Manager
    # ==========================================================

    def _get_config(self) -> dict[str, Any]:
        """Return the entire _config dict from _records."""
        return self._records["_config"]

    def _set_config(self, config: dict[str, Any]) -> None:
        """Replace the entire _config dict in _records with the given dict."""
        self._records["_config"] = config

    def _update_config(self, key: str, value: Any) -> None:
        """Set a single key-value pair inside _config."""
        self._records["_config"][key] = value

    def _get_config_value(self, key: str) -> Any:
        """Return the value for a given key from _config."""
        return self._records["_config"][key]

    def _has_config_value(self, key: str) -> bool:
        """Return True if the given key exists in _config."""
        return key in self._records["_config"]

    def _delete_config_value(self, key: str) -> None:
        """Delete a key from _config."""
        del self._records["_config"][key]

    def _clear_config(self) -> None:
        """Reset _config to an empty dict."""
        self._records["_config"] = {}

    # ==========================================================
    # Public API for Config
    #
    # Mutators persist under auto_persist, like every other mutator in this
    # class. Without that, config changes would be lost on restart unless
    # some unrelated write happened first.
    # ==========================================================

    def get_config(self) -> dict[str, Any]:
        """Return the full database configuration dictionary (the live object, not a copy)."""
        return self._get_config()

    def set_config(self, config: dict[str, Any]) -> None:
        """
        Replace the full database configuration with the given dict.

        Raises:
            TypeError: if ``config`` is not a dict.
        """
        if not isinstance(config, dict):
            raise TypeError(
                f"config must be a dict, got {type(config).__name__}."
            )
        self._set_config(config)
        self._persist_if_enabled()

    def update_config(self, key: str, value: Any) -> None:
        """Update a single key inside the database configuration."""
        self._update_config(key, value)
        self._persist_if_enabled()

    def get_config_value(self, key: str) -> Any:
        """Return the value of a single key from the database configuration (KeyError if absent)."""
        return self._get_config_value(key)

    def has_config_value(self, key: str) -> bool:
        """Return True if the given key exists in the database configuration."""
        return self._has_config_value(key)

    def delete_config_value(self, key: str) -> None:
        """Remove a key from the database configuration (KeyError if absent)."""
        self._delete_config_value(key)
        self._persist_if_enabled()

    def clear_config(self) -> None:
        """Wipe the entire database configuration dictionary."""
        self._clear_config()
        self._persist_if_enabled()

    # ==========================================================
    # Collections
    # ==========================================================

    def create_collection(
        self,
        collection: str,
        index: Literal["default", "bruteforce"] | str = "default"
    ) -> None:
        """
        Create a new named collection with the specified index type.

        Does nothing if the collection already exists. Stores the chosen
        index type in _config['_collection_indexes'] so it survives a reload.

        Raises:
            ValueError: if ``index`` is not a registered index type.
            TypeError: if ``collection`` is not a string.
        """
        self._validate_index_algo(index)

        if not isinstance(collection, str):
            raise TypeError("Collection name must be a string.")

        if collection in self._records:
            return

        self._records[collection] = {}
        self._indexes[collection] = self.INDEXES[index]()

        self._records["_config"].setdefault("_collection_indexes", {})[collection] = index

        self._persist_if_enabled()

    def delete_collection(self, collection: str) -> None:
        """
        Delete a named collection, its in-memory index and its index-type entry.

        Raises:
            ValueError: if ``collection`` is the reserved _config entry.
            RuntimeError: if ``collection`` is the default collection.
            CollectionNotFoundError: if the collection does not exist.
        """
        self._reject_config_collection(collection, "delete")

        if collection == "default":
            raise RuntimeError("Default collection cannot be deleted.")

        self._validate_collection(collection)

        del self._records[collection]
        del self._indexes[collection]

        self._records["_config"].get("_collection_indexes", {}).pop(collection, None)

        self._persist_if_enabled()

    def list_collections(self) -> list[str]:
        """Return a list of all user-facing collection names, excluding _config."""
        return [k for k in self._records if k != "_config"]

    def rename_collection(self, old_name: str, new_name: str) -> None:
        """
        Rename an existing collection.

        Moves both the records and the in-memory index to the new key, and
        moves the stored index type in _config along with them.

        Raises:
            ValueError: if ``old_name`` is the reserved _config entry.
            TypeError: if ``new_name`` is not a string.
            CollectionNotFoundError: if ``old_name`` does not exist.
            CollectionAlreadyExistsError: if ``new_name`` already exists.
        """
        # Renaming "_config" would move the settings out of their reserved
        # slot and then fail on the missing index, leaving no config at all.
        self._reject_config_collection(old_name, "rename")
        self._validate_collection(old_name)

        if not isinstance(new_name, str):
            raise TypeError("Collection name must be a string.")

        if new_name in self._records:
            raise CollectionAlreadyExistsError(
                f"Collection '{new_name}' already exists."
            )

        self._records[new_name] = self._records.pop(old_name)
        self._indexes[new_name] = self._indexes.pop(old_name)

        col_indexes = self._records["_config"].get("_collection_indexes", {})
        if old_name in col_indexes:
            col_indexes[new_name] = col_indexes.pop(old_name)

        self._persist_if_enabled()

    def clone_collection(self, source: str, target: str) -> None:
        """
        Copy collection ``source`` into a new, independent collection ``target``.

        Records are deep-copied, keeping their ids, and the clone gets its own
        freshly built index of the source's index type. Later inserts,
        updates or deletes on either collection do not affect the other. The
        source's index-type entry in _config, if any, is copied to the target.

        Raises:
            ValueError: if ``source`` is the reserved _config entry.
            CollectionNotFoundError: if ``source`` does not exist.
            TypeError: if ``target`` is not a string.
            CollectionAlreadyExistsError: if ``target`` already exists.
        """
        import copy

        self._reject_config_collection(source, "clone")
        self._validate_collection(source)

        if not isinstance(target, str):
            raise TypeError("Collection name must be a string.")

        if target in self._records:
            raise CollectionAlreadyExistsError(
                f"Collection '{target}' already exists."
            )

        # Copy rather than alias: sharing the dict / Record objects / index
        # would make every write to one collection appear in the other.
        cloned_records = copy.deepcopy(self._records[source])
        cloned_index = self.INDEXES[self._index_type_for(source)]()
        for record_id, record in cloned_records.items():
            cloned_index.add(record_id, record.vector)

        self._records[target] = cloned_records
        self._indexes[target] = cloned_index

        col_indexes = self._records["_config"].get("_collection_indexes", {})
        if source in col_indexes:
            col_indexes[target] = col_indexes[source]

        self._persist_if_enabled()

    # ==========================================================
    # CRUD
    # ==========================================================

    def insert_doc(
        self,
        name: Optional[str] = None,
        vector: Optional[list[float]] = None,
        metadata: Optional[dict[str, Any]] = None,
        collection: str = "default",
        **kwargs
    ) -> str:
        """
        Insert a single record into the given collection.

        Infers ``dimensions`` from this vector if it is unset and auto_dim
        is True. The vector is stored as a list of floats and the metadata
        as a shallow copy.

        Args:
            name: Non-empty record name (required).
            vector: Non-empty sequence of numbers (required).
            metadata: Optional dict; None means {}.
            collection: Target collection.
            **kwargs: Accepted for backward compatibility and ignored.

        Returns:
            The newly assigned record UUID.

        Raises:
            CollectionNotFoundError: if the collection does not exist.
            RuntimeError: if ``name`` or ``vector`` is missing.
            ValueError: if name/vector/metadata have an invalid shape or type.
            EmbeddingDimensionError: on a dimension mismatch, or when
                dimensions is unset and auto_dim is off.
        """
        self._validate_collection(collection)

        missing_params = [
            param
            for param, value in (("name", name), ("vector", vector))
            if value is None
        ]
        if missing_params:
            raise RuntimeError(
                f"Missing parameters: {', '.join(missing_params)}"
            )

        # Validate before touching any state so a rejected insert has no side effects.
        self._validate_name(name)
        clean_vector = self._normalize_vector(vector)
        clean_metadata = self._copy_metadata(metadata)

        expected = self._expected_dimension(len(clean_vector))
        self._check_dimension(clean_vector, expected)

        record_id = get_uuid()

        self._commit_dimension(expected)
        self._indexes[collection].add(record_id, clean_vector)
        self._records[collection][record_id] = Record(
            id=record_id,
            name=name,
            vector=clean_vector,
            metadata=clean_metadata
        )

        self._persist_if_enabled()

        return record_id

    def delete_doc(self, id: str, collection: str = "default") -> None:
        """
        Remove a record by id from the given collection and its index.

        Raises:
            CollectionNotFoundError: if the collection does not exist.
            IndexNotFoundError: if the id does not exist in the collection.
        """
        self._validate_doc(collection, id)

        del self._records[collection][id]
        self._indexes[collection].remove(id)

        self._persist_if_enabled()

    def update_doc(
        self,
        id: str,
        collection: str = "default",
        name: Optional[str] = None,
        vector: Optional[list[float]] = None,
        metadata: Optional[dict[str, Any]] = None
    ) -> None:
        """
        Update one or more fields of an existing record; None means "leave unchanged".

        All new values are validated before anything is modified. The index
        is updated only when a new vector is provided, and before the record
        itself, so a failing index update leaves the record untouched.

        Raises:
            CollectionNotFoundError / IndexNotFoundError: if the record is missing.
            ValueError: if a new name/vector/metadata is invalid.
            EmbeddingDimensionError: if the new vector's length differs from
                the database's dimensions.
        """
        self._validate_doc(collection, id)

        if name is not None:
            self._validate_name(name)

        new_vector: Optional[list[float]] = None
        if vector is not None:
            new_vector = self._normalize_vector(vector)
            if self.dimensions is not None:
                self._check_dimension(new_vector, self.dimensions)

        new_metadata: Optional[dict[str, Any]] = None
        if metadata is not None:
            new_metadata = self._copy_metadata(metadata)

        record = self._records[collection][id]

        if new_vector is not None:
            self._indexes[collection].update(id, new_vector)
            if self.auto_dim:
                self._commit_dimension(len(new_vector))
            record.vector = new_vector
        if name is not None:
            record.name = name
        if new_metadata is not None:
            record.metadata = new_metadata

        self._persist_if_enabled()

    def get_doc(self, id: str, collection: str = "default") -> Record:
        """Return the stored Record object (not a copy) for the given id from the specified collection."""
        self._validate_doc(collection, id)
        return self._records[collection][id]

    # ==========================================================
    # Search
    # ==========================================================

    @staticmethod
    def _matches_filters(record: Record, filters: Optional[Dict[str, Any]]) -> bool:
        """
        Return True when every (key, value) in ``filters`` equals
        ``record.metadata.get(key)``. None filters match every record.
        """
        if filters is None:
            return True
        return all(record.metadata.get(k) == v for k, v in filters.items())

    @staticmethod
    def _format_hit(
        record_id: str,
        score: Any,
        record: Record,
        return_text_outputs: bool
    ) -> tuple:
        """Build one search result tuple in the shape requested by search_doc."""
        if return_text_outputs:
            return (record_id, score, record.name, record.metadata)
        return (record_id, score)

    def search_doc(
        self,
        input_vector: list[float],
        filters: Optional[Dict[str, Any]] = None,
        top_k: int = 3,
        return_text_outputs: bool = False,
        collection: str = "default",
        metric: str = "cosine",
        basic: bool = True
    ) -> list[tuple]:
        """
        Search for the top-k nearest records to input_vector in the given collection.

        When basic=True, scans every record, keeps those matching the
        optional metadata ``filters`` (exact equality per key), and scores
        them with the chosen metric ("cosine", "dot" or "euclidean").
        Similarities are ranked highest-first; euclidean distance is ranked
        lowest-first. ``top_k=None`` returns all matches.

        When basic=False, asks the collection's index for ``top_k`` results
        and applies the metadata filters afterwards, so fewer than top_k
        results may come back. Ids the index returns that are not in the
        collection (e.g. from a stale pre-saved index file) are skipped.

        Returns:
            A list of (record_id, score) tuples, or
            (record_id, score, name, metadata) tuples when return_text_outputs=True.

        Raises:
            ValueError: if top_k is negative, or (basic mode) metric is unknown.
            EmbeddingDimensionError: if input_vector's length differs from
                the database's dimensions.
            CollectionNotFoundError: if the collection does not exist.
        """
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k must be >= 0, got {top_k}.")

        if self.dimensions is not None and len(input_vector) != self.dimensions:
            raise EmbeddingDimensionError(
                f"Expected embedding dimension {self.dimensions}, "
                f"got {len(input_vector)}."
            )

        self._validate_collection(collection)
        records = self._records[collection]

        if basic:
            scorers = {
                "cosine": cosine_similarity,
                "dot": dot_similarity,
                "euclidean": euclidean_distance,
            }
            # Validate up front so an unknown metric fails even on an empty collection.
            if metric not in scorers:
                raise ValueError(f"Unknown metric: {metric}")
            score_fn = scorers[metric]

            scored = [
                (record_id, score_fn(input_vector, record.vector), record)
                for record_id, record in records.items()
                if self._matches_filters(record, filters)
            ]
            scored.sort(key=lambda hit: hit[1], reverse=(metric != "euclidean"))
            return [
                self._format_hit(record_id, score, record, return_text_outputs)
                for record_id, score, record in scored[:top_k]
            ]

        results = self._indexes[collection].search(
            query_vector=input_vector,
            top_k=top_k,
            metric=metric
        )

        final_results = []
        for record_id, score in results:
            record = records.get(record_id)
            if record is None or not self._matches_filters(record, filters):
                continue
            final_results.append(
                self._format_hit(record_id, score, record, return_text_outputs)
            )

        return final_results

    # ==========================================================
    # Utility
    # ==========================================================

    def list_docs(self, collection: str = "default") -> list[str]:
        """Return a list of all record ids in the given collection."""
        self._validate_collection(collection)
        return list(self._records[collection].keys())

    def clear(self, collection: Optional[str] = None) -> None:
        """
        Clear records and reset indexes.

        If collection is None, resets the entire database, config included,
        to the structure returned by _default_records. If a collection name
        is given, clears only that collection's records and index.

        Raises:
            ValueError: if ``collection`` is the reserved _config entry.
            CollectionNotFoundError: if the collection does not exist.
        """
        if collection is None:
            self._reset_to_defaults()
        else:
            self._reject_config_collection(collection, "clear")
            self._validate_collection(collection)
            self._records[collection].clear()
            self._indexes[collection].clear()

        self._persist_if_enabled()

    # ==========================================================
    # Persistence
    # ==========================================================

    def persist_doc(self) -> None:
        """
        Serialise the entire _records dict to disk using pickle.

        The data is first written to ``<path>.tmp`` and then moved over the
        real file with os.replace. An interrupted write (crash, Ctrl-C, full
        disk) therefore leaves the previous database file in place instead
        of a truncated one, which _load_doc would reset to an empty database.
        """
        tmp_path = f"{self._path}.tmp"
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(self._records, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, self._path)
        finally:
            # The temp file only remains if something failed before the replace.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def _default_records(self) -> dict:
        """
        Return the baseline _records structure used when creating a new database
        or resetting an existing one.

        It holds a _config entry, seeded from the current dimensions /
        auto_dim / auto_persist attributes with "default" mapped to the
        "default" index type, and an empty default collection.
        """
        return {
            "_config": {
                "dimensions": self.dimensions,
                "auto_dim": self.auto_dim,
                "auto_persist": self.auto_persist,
                "_collection_indexes": {
                    "default": "default"
                }
            },
            "default": {}
        }

    def _rebuild_indexes(self) -> None:
        """
        Reconstruct all in-memory indexes from the current _records state.

        Each collection gets a fresh index of the type stored in
        _config['_collection_indexes']. The type falls back to "default" when
        it is unlisted or no longer registered. _indexes is replaced wholesale
        so no vector is added twice.
        """
        self._indexes = {}

        for collection_name, collection in self._records.items():
            if collection_name == "_config":
                continue

            index = self.INDEXES[self._index_type_for(collection_name)]()
            for record_id, record in collection.items():
                index.add(record_id, record.vector)
            self._indexes[collection_name] = index

    def _validate_index_schema(self, data: dict) -> None:
        """
        Validate that the loaded data has the structure this class expects.

        Checks performed:
        - The root object and its optional _config entry must be dicts.
        - Every other entry (a collection) must be a dict.
        - Every record must be an instance of Record (not a raw dict from a
          corrupt pickle).
        - Each record must have non-empty string 'id' and 'name' fields.
        - Each record must have a non-empty list/tuple 'vector' of ints or floats.
        - Each record must have a dict 'metadata' field.
        - Vector length must match _config['dimensions'] if set, and must be
          consistent across all records in the same collection.

        Raises:
            TypeError: if the root, _config, a collection or a record has the
                wrong type.
            ValueError: if a required field is invalid or vector dimensions
                are inconsistent.
        """
        if not isinstance(data, dict):
            raise TypeError(
                f"Database root must be a dict, got {type(data).__name__}."
            )

        config = data.get("_config", {})
        if not isinstance(config, dict):
            raise TypeError(
                f"'_config' must be a dict, got {type(config).__name__}."
            )

        config_dimensions: Optional[int] = config.get("dimensions", None)

        for collection_name, collection in data.items():
            if collection_name == "_config":
                continue

            if not isinstance(collection, dict):
                raise TypeError(
                    f"Collection '{collection_name}' must be a dict, "
                    f"got {type(collection).__name__}."
                )

            observed_dim: Optional[int] = None

            for record_id, record in collection.items():
                if not isinstance(record, Record):
                    raise TypeError(
                        f"Record '{record_id}' in collection '{collection_name}' "
                        f"must be a Record instance, got {type(record).__name__}. "
                        f"The database file may be corrupt or was exported as JSON "
                        f"without being re-imported correctly."
                    )

                if not isinstance(record.id, str) or not record.id:
                    raise ValueError(
                        f"Record '{record_id}' in collection '{collection_name}' "
                        f"has an invalid 'id': {record.id!r}."
                    )

                if not isinstance(record.name, str) or not record.name:
                    raise ValueError(
                        f"Record '{record_id}' in collection '{collection_name}' "
                        f"has an invalid 'name': {record.name!r}."
                    )

                if (
                    not isinstance(record.vector, (list, tuple))
                    or len(record.vector) == 0
                ):
                    raise ValueError(
                        f"Record '{record_id}' in collection '{collection_name}' "
                        f"has an invalid or empty 'vector'."
                    )

                if not all(isinstance(v, (int, float)) for v in record.vector):
                    raise ValueError(
                        f"Record '{record_id}' in collection '{collection_name}' "
                        f"contains non-numeric values in 'vector'."
                    )

                if not isinstance(record.metadata, dict):
                    raise ValueError(
                        f"Record '{record_id}' in collection '{collection_name}' "
                        f"has an invalid 'metadata': expected dict, "
                        f"got {type(record.metadata).__name__}."
                    )

                # Dimension consistency: first against the config value,
                # then against the first vector seen in this collection.
                record_dim = len(record.vector)

                if config_dimensions is not None and record_dim != config_dimensions:
                    raise ValueError(
                        f"Record '{record_id}' in collection '{collection_name}' "
                        f"has vector dimension {record_dim}, but config declares "
                        f"dimensions={config_dimensions}."
                    )

                if observed_dim is None:
                    observed_dim = record_dim
                elif record_dim != observed_dim:
                    raise ValueError(
                        f"Inconsistent vector dimensions in collection "
                        f"'{collection_name}': expected {observed_dim}, "
                        f"got {record_dim} for record '{record_id}'."
                    )

    def _apply_provided_index_paths(self) -> None:
        """
        Overwrite freshly rebuilt indexes with the pre-saved files given in
        index_paths. Collections without an index object (unknown names, or
        the reserved "_config" entry) are skipped with a UserWarning.
        """
        import warnings

        for collection, path in self.provided_index_paths.items():
            if collection not in self._indexes:
                warnings.warn(
                    f"index_paths contains collection '{collection}' which "
                    f"does not exist in the loaded database — skipping.",
                    UserWarning,
                    # Point at the user's Database(...) call:
                    # this method <- _load_doc <- __init__ <- caller.
                    stacklevel=4,
                )
                continue
            self._load_index(collection=collection, path=path)

    def _load_doc(self) -> None:
        """
        Load the database from disk.

        Behaviour:
        - If no file exists at _path, initialises a fresh default database
          and returns. index_paths are not applied in that case.
        - On EOFError while unpickling (empty/truncated file), resets to
          defaults without raising.
        - Any other unpickling failure resets to defaults and raises
          RuntimeError("Failed to load database ...").
        - Validates the unpickled data with _validate_index_schema. A
          TypeError/ValueError there resets to defaults and raises
          RuntimeError("Database schema validation failed ...").
        - Accepts the data, adding a default _config if the file had none,
          and takes dimensions from the stored config when none was passed
          to the constructor.
        - Rebuilds every index from the records, then replaces the index of
          each collection listed in index_paths with its pre-saved file. Any
          failure in this stage resets to defaults and raises RuntimeError.

        SECURITY: pickle.load can execute arbitrary code embedded in the file.
        Only open database files from a trusted source.
        """
        if not os.path.exists(self._path):
            self._reset_to_defaults()
            return

        try:
            with open(self._path, "rb") as f:
                loaded = pickle.load(f)
        except EOFError:
            self._reset_to_defaults()
            return
        except Exception as e:
            self._reset_to_defaults()
            raise RuntimeError(
                f"Failed to load database from '{self._path}': {e}"
            ) from e

        try:
            self._validate_index_schema(loaded)
        except (TypeError, ValueError) as e:
            self._reset_to_defaults()
            raise RuntimeError(
                f"Database schema validation failed for '{self._path}': {e}"
            ) from e

        try:
            # Every mutator indexes _records["_config"] directly, so make sure it exists.
            loaded.setdefault("_config", self._default_records()["_config"])
            self._records = loaded

            # NOTE(review): if the constructor's `dimensions` differs from the
            # stored value, the constructor value wins for new inserts while
            # _config keeps the stored one — confirm whether that should error.
            stored_dim = self._records["_config"].get("dimensions", None)
            if self.dimensions is None and stored_dim is not None:
                self.dimensions = stored_dim

            self._rebuild_indexes()
            self._apply_provided_index_paths()
        except Exception as e:
            self._reset_to_defaults()
            raise RuntimeError(
                f"Failed to load database from '{self._path}': {e}"
            ) from e

    # ==========================================================
    # Statistics
    # ==========================================================

    def collection_stats(self, collection: str = "default") -> dict[str, Any]:
        """
        Return statistics for a single collection: document count, average
        vector length (0 when empty) and the index type stored in _config
        ("default" when none is stored).
        """
        self._validate_collection(collection)

        dimensions = [
            len(record.vector)
            for record in self._records[collection].values()
        ]

        avg_dim = sum(dimensions) / len(dimensions) if dimensions else 0

        col_indexes = self._records["_config"].get("_collection_indexes", {})

        return {
            "collection": collection,
            "documents": len(dimensions),
            "average_dimension": avg_dim,
            "index_type": col_indexes.get(collection, "default")
        }

    def stats(self) -> dict[str, Any]:
        """
        Return aggregate statistics for the whole database: collection
        count, total document count, average vector length across all
        records (0 when empty), and the database file path.
        """
        all_dims = [
            len(record.vector)
            for collection_name, collection in self._records.items()
            if collection_name != "_config"
            for record in collection.values()
        ]

        avg_dim = sum(all_dims) / len(all_dims) if all_dims else 0

        return {
            "collections": len(self.list_collections()),
            "documents": len(all_dims),
            "average_dimension": avg_dim,
            "database_path": self._path
        }

    # ==========================================================
    # Magic Methods
    # ==========================================================

    def __len__(self) -> int:
        """Return the total number of records across all collections."""
        return sum(
            len(collection)
            for name, collection in self._records.items()
            if name != "_config"
        )

    def __contains__(self, id: str) -> bool:
        """Return True if the given record id exists in any collection."""
        return any(
            id in collection
            for name, collection in self._records.items()
            if name != "_config"
        )

    def __repr__(self) -> str:
        """Return a concise string representation of the Database instance."""
        return (
            f"Database("
            f"name='{self.db_name}', "
            f"collections={len(self.list_collections())}, "
            f"documents={len(self)}"
            f")"
        )

    # ==========================================================
    # Import / Export
    # ==========================================================

    def _validate_consistent_format(self, data: dict) -> None:
        """
        Validate the raw structure of a JSON export before importing it.

        The root and the optional _config entry (null is allowed) must be
        objects. Every other entry must be an object of records, and each
        record must be an object with the keys id, vector, metadata and name.

        Raises:
            ValueError: on the first structural problem found.
        """
        if not isinstance(data, dict):
            raise ValueError(
                f"Invalid JSON export: expected an object, got {type(data).__name__}."
            )

        config = data.get("_config")
        if config is not None and not isinstance(config, dict):
            raise ValueError(
                f"Invalid '_config' format: expected an object, "
                f"got {type(config).__name__}."
            )

        for collection_name, collection in data.items():
            if collection_name == "_config":
                continue
            if not isinstance(collection, dict):
                raise ValueError(
                    f"Invalid collection format for '{collection_name}': {collection}"
                )
            for record in collection.values():
                if not isinstance(record, dict):
                    raise ValueError(f"Invalid record format: {record}")
                for key in ("id", "vector", "metadata", "name"):
                    if key not in record:
                        raise ValueError(f"Missing '{key}' in record: {record}")

    def import_from_json(self, json_path: str) -> None:
        """
        Import records from a JSON file produced by export_to_json.

        Collections in the file are merged into the current database. New
        collections get the index type recorded in the file's _config, and
        a record whose id already exists replaces the stored record and its
        vector in the index. When the file has a _config object it replaces
        the current config (adopting its auto_dim / auto_persist), except
        that collections already present here keep their current index-type
        entry, since their in-memory index objects are not rebuilt.

        The whole file is validated before anything is modified, so a bad
        file leaves the database untouched.

        Raises:
            FileNotFoundError: if json_path does not exist.
            ValueError / TypeError: if the file is malformed or its records
                fail the checks applied when loading the database file.
            EmbeddingDimensionError: if the file's dimension differs from the
                dimension of records already stored here.
        """
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"JSON file not found: {json_path}")

        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        self._validate_consistent_format(data)

        imported_config: Optional[dict[str, Any]] = data.get("_config")

        # Stage Record objects so the whole import can be validated at once.
        staged: dict[str, dict[str, Record]] = {
            collection_name: {
                record_id: Record(
                    id=record_data["id"],
                    name=record_data["name"],
                    vector=record_data["vector"],
                    metadata=record_data["metadata"]
                )
                for record_id, record_data in collection.items()
            }
            for collection_name, collection in data.items()
            if collection_name != "_config"
        }

        # Work out the dimension the imported vectors must have. A null
        # "dimensions" in the file must not erase an already known dimension.
        target_dim = self.dimensions
        imported_dim = imported_config.get("dimensions") if imported_config else None
        if imported_dim is not None:
            if target_dim is not None and len(self) > 0 and imported_dim != target_dim:
                raise EmbeddingDimensionError(
                    f"Expected embedding dimension {target_dim}, "
                    f"got {imported_dim}."
                )
            target_dim = imported_dim

        if target_dim is None:
            first = next(
                (rec for col in staged.values() for rec in col.values()), None
            )
            if first is not None and isinstance(first.vector, (list, tuple)):
                target_dim = len(first.vector)

        self._validate_index_schema({"_config": {"dimensions": target_dim}, **staged})

        # ---- Everything is valid: apply the import. ----
        if imported_config is not None:
            self.auto_dim = imported_config.get("auto_dim", self.auto_dim)
            self.auto_persist = imported_config.get("auto_persist", self.auto_persist)

            existing_indexes = self._records["_config"].get("_collection_indexes", {})
            new_config = dict(imported_config)
            merged_indexes = dict(
                new_config.get("_collection_indexes", {"default": "default"})
            )
            for name in self.list_collections():
                if name in existing_indexes:
                    merged_indexes[name] = existing_indexes[name]
            new_config["_collection_indexes"] = merged_indexes
            self._records["_config"] = new_config

        if target_dim is not None:
            self.dimensions = target_dim
            self._records["_config"]["dimensions"] = target_dim

        for collection_name, records in staged.items():
            target_records = self._records.setdefault(collection_name, {})

            if collection_name not in self._indexes:
                self._indexes[collection_name] = self.INDEXES[
                    self._index_type_for(collection_name)
                ]()
            index = self._indexes[collection_name]

            for record_id, record in records.items():
                # Re-imported ids must replace, not duplicate, their index entry.
                if record_id in target_records:
                    index.update(record_id, record.vector)
                else:
                    index.add(record_id, record.vector)
                target_records[record_id] = record

        self._persist_if_enabled()

    def export_to_json(self, json_path: str) -> None:
        """
        Serialise the entire database, _config included, to an indented JSON file.

        Record dataclass instances are converted to dicts via
        dataclasses.asdict. The JSON text is built in memory first and the
        file is created with exclusive mode "x", so a serialisation error
        never leaves a partial file behind.

        Raises:
            FileExistsError: if json_path already exists.
            TypeError: if the database contains a value JSON cannot encode.
        """
        if os.path.exists(json_path):
            raise FileExistsError(f"JSON file already exists: {json_path}")

        def encode(obj: Any) -> Any:
            """json `default` hook: dataclass instances -> dicts; anything else is an error."""
            if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
                return dataclasses.asdict(obj)
            raise TypeError(
                f"Object of type {type(obj).__name__} is not JSON serializable"
            )

        payload = json.dumps(self._records, indent=4, default=encode)

        with open(json_path, "x", encoding="utf-8") as f:
            f.write(payload)

    # ==========================================================
    # Batch Management
    # ==========================================================

    def batch_insert_docs(
        self,
        objects: List[Tuple[str, list[float], dict[str, Any]]],
        collection: str = "default"
    ) -> List[str]:
        """
        Insert multiple records into a collection in one call.

        Accepts a list of (name, vector, metadata) tuples. Every item is
        validated (name, vector, metadata and dimension) before any is
        inserted, so one bad item no longer leaves the earlier ones in memory
        but unpersisted. Persists once at the end rather than once per
        record.

        Returns:
            The new record UUIDs, in input order.

        Raises:
            ValueError: if objects is empty or an item is invalid.
            CollectionNotFoundError: if the collection does not exist.
            EmbeddingDimensionError: on a dimension mismatch, or when
                dimensions is unset and auto_dim is off.
        """
        if not objects:
            raise ValueError("objects must not be empty.")

        self._validate_collection(collection)

        # Phase 1: validate everything without touching state.
        prepared: list[tuple[str, list[float], dict[str, Any]]] = []
        expected: Optional[int] = self.dimensions

        for name, vector, metadata in objects:
            self._validate_name(name)
            clean_vector = self._normalize_vector(vector)
            clean_metadata = self._copy_metadata(metadata)

            if expected is None:
                expected = self._expected_dimension(len(clean_vector))
            self._check_dimension(clean_vector, expected)

            prepared.append((name, clean_vector, clean_metadata))

        # Phase 2: commit.
        self._commit_dimension(expected)

        records = self._records[collection]
        index = self._indexes[collection]
        record_ids = []

        for name, clean_vector, clean_metadata in prepared:
            record_id = get_uuid()
            index.add(record_id, clean_vector)
            records[record_id] = Record(
                id=record_id,
                name=name,
                vector=clean_vector,
                metadata=clean_metadata
            )
            record_ids.append(record_id)

        self._persist_if_enabled()

        return record_ids

    # ==========================================================
    # Index Import / Export
    # ==========================================================

    def _default_index_path(self, collection: str) -> str:
        """
        Return ``<db path without its .db suffix>_<collection>.index``.

        os.path.splitext strips only the final extension. The former
        str.replace(".db", "") also removed ".db" anywhere else in the path,
        such as a directory named "data.db_files".
        """
        base, _ = os.path.splitext(self._path)
        return f"{base}_{collection}.index"

    def save_index(self, collection: str = "default", path: Optional[str] = None) -> str:
        """
        Persist a collection's in-memory index to a .index file via its save() method.

        Defaults to <db_name>_<collection>.index alongside the database file.
        Returns the path the index was saved to.

        Raises:
            CollectionNotFoundError: if the collection does not exist.
        """
        self._validate_collection(collection)

        if path is None:
            path = self._default_index_path(collection)

        self._indexes[collection].save(path)
        return path

    def _load_index(self, collection: str = "default", path: Optional[str] = None) -> None:
        """
        Load an index file from disk into a collection's in-memory index via its load() method.

        Defaults to <db_name>_<collection>.index alongside the database file.

        Raises:
            CollectionNotFoundError: if the collection does not exist.
            FileNotFoundError: if the index file does not exist.
        """
        self._validate_collection(collection)

        if path is None:
            path = self._default_index_path(collection)

        if not os.path.exists(path):
            raise FileNotFoundError(f"Index file not found: {path}")

        self._indexes[collection].load(path)