usearch 2.23.0__cp314-cp314t-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
usearch/index.py ADDED
@@ -0,0 +1,1721 @@
1
+ from __future__ import annotations
2
+ from inspect import signature
3
+ from collections.abc import Sequence
4
+
5
+ # The purpose of this file is to provide a Pythonic wrapper on top of
6
+ # the native precompiled CPython module. It improves compatibility with
7
+ # Python tooling, linters, and static analyzers. It also embeds JIT
8
+ # into the primary `Index` class, connecting USearch with Numba.
9
+ import os
10
+ import sys
11
+ import math
12
+ from dataclasses import dataclass
13
+ from typing import (
14
+ Any,
15
+ Optional,
16
+ Union,
17
+ NamedTuple,
18
+ List,
19
+ Iterable,
20
+ Tuple,
21
+ Dict,
22
+ Callable,
23
+ )
24
+
25
+ import numpy as np
26
+ from tqdm import tqdm
27
+
28
+ # Precompiled symbols that won't be exposed directly:
29
+ from usearch.compiled import (
30
+ Index as _CompiledIndex,
31
+ Indexes as _CompiledIndexes,
32
+ IndexStats as _CompiledIndexStats,
33
+ index_dense_metadata_from_path as _index_dense_metadata_from_path,
34
+ index_dense_metadata_from_buffer as _index_dense_metadata_from_buffer,
35
+ exact_search as _exact_search,
36
+ hardware_acceleration as _hardware_acceleration,
37
+ kmeans as _kmeans,
38
+ )
39
+
40
+ # Precompiled symbols that will be exposed
41
+ from usearch.compiled import (
42
+ MetricKind,
43
+ ScalarKind,
44
+ MetricSignature,
45
+ DEFAULT_CONNECTIVITY,
46
+ DEFAULT_EXPANSION_ADD,
47
+ DEFAULT_EXPANSION_SEARCH,
48
+ USES_OPENMP,
49
+ USES_SIMSIMD,
50
+ USES_FP16LIB,
51
+ )
52
+
53
# Metric kinds for which `_normalize_dtype` picks the `ScalarKind.B1` storage type
# when the caller does not specify a `dtype`.
MetricKindBitwise = (
    MetricKind.Hamming,
    MetricKind.Tanimoto,
    MetricKind.Sorensen,
)
58
+
59
+
60
class CompiledMetric(NamedTuple):
    """Description of a user-supplied metric accepted by `Index` as `metric`.

    `Index.__init__` forwards the three fields to the compiled index as
    `metric_pointer`, `metric_kind` and `metric_signature`.
    """

    # NOTE(review): presumably the address of the JIT-compiled function - confirm
    # against the native module.
    pointer: int
    kind: MetricKind
    signature: MetricSignature
64
+
65
+
66
# Define TypeAlias for older Python versions
if sys.version_info >= (3, 10):
    from typing import TypeAlias
else:
    TypeAlias = object  # Fallback for older Python versions

# Key arrays throughout this module are cast to this unsigned 64-bit type.
Key: TypeAlias = np.uint64

NoneType: TypeAlias = type(None)

# A single key or a collection of keys, as accepted by the `Index` methods.
KeyOrKeysLike = Union[Key, Iterable[Key], int, Iterable[int], np.ndarray, memoryview]

# A single vector or a collection of vectors.
VectorOrVectorsLike = Union[np.ndarray, Iterable[np.ndarray], memoryview]

# Scalar type given by name or as a `ScalarKind`; resolved by `_normalize_dtype`.
DTypeLike = Union[str, ScalarKind]

# Metric given by name, as a `MetricKind`, or as a `CompiledMetric`.
MetricLike = Union[str, MetricKind, CompiledMetric]

BytesLike = Union[bytes, bytearray, memoryview]

# Source of a serialized index: a filesystem path or an in-memory buffer.
PathOrBuffer = Union[str, os.PathLike, BytesLike]

# Callback invoked as `progress(processed, total)`. It has to be annotated as
# `(int, int) -> bool` to pass the `_match_signature` check.
# NOTE(review): the meaning of the returned flag is decided by the native module,
# presumably `False` requests cancellation - confirm.
ProgressCallback = Callable[[int, int], bool]
89
+
90
+
91
+ def _match_signature(func: Callable[[Any], Any], arg_types: List[type], ret_type: type) -> bool:
92
+ assert callable(func), "Not callable"
93
+ sig = signature(func)
94
+ param_types = [param.annotation for param in sig.parameters.values()]
95
+ return param_types == arg_types and sig.return_annotation == ret_type
96
+
97
+
98
def _normalize_dtype(
    dtype,
    ndim: int = 0,
    metric: MetricKind = MetricKind.Cos,
) -> ScalarKind:
    """Translate a user-facing scalar type description into a `ScalarKind`.

    :param dtype: `None`, an empty string, a `ScalarKind`, a type name, or a NumPy scalar type
    :param ndim: Number of dimensions, only used to probe hardware acceleration
    :param metric: Metric kind, only used when `dtype` is not specified
    :return: The matching `ScalarKind`
    :raises KeyError: If `dtype` is not one of the known names or NumPy types
    """
    unspecified = dtype is None or dtype == ""
    if unspecified:
        # Bitwise metrics always get the bit-packed representation.
        if metric in MetricKindBitwise:
            return ScalarKind.B1
        # Otherwise prefer half-precision kinds that report acceleration, in this order.
        for candidate in (ScalarKind.BF16, ScalarKind.F16):
            if _hardware_acceleration(dtype=candidate, ndim=ndim, metric_kind=metric):
                return candidate
        return ScalarKind.F32

    if isinstance(dtype, ScalarKind):
        return dtype

    if isinstance(dtype, str):
        dtype = dtype.lower()

    known_kinds = {
        # Short names
        "f64": ScalarKind.F64,
        "f32": ScalarKind.F32,
        "bf16": ScalarKind.BF16,
        "f16": ScalarKind.F16,
        "i8": ScalarKind.I8,
        "b1": ScalarKind.B1,
        "b1x8": ScalarKind.B1,
        "bits": ScalarKind.B1,
        # Long names
        "float64": ScalarKind.F64,
        "float32": ScalarKind.F32,
        "bfloat16": ScalarKind.BF16,
        "float16": ScalarKind.F16,
        "int8": ScalarKind.I8,
        # NumPy scalar types
        np.float64: ScalarKind.F64,
        np.float32: ScalarKind.F32,
        np.float16: ScalarKind.F16,
        np.int8: ScalarKind.I8,
        np.uint8: ScalarKind.B1,
    }
    return known_kinds[dtype]
139
+
140
+
141
def _to_numpy_dtype(dtype: ScalarKind):
    """Map a `ScalarKind` to the NumPy scalar type used to view vectors of that kind.

    :param dtype: A `ScalarKind`, or one of the NumPy types listed below
    :return: `None` for `ScalarKind.BF16`, the argument itself if it already is
        one of the target NumPy types, otherwise the mapped NumPy type
    :raises KeyError: If `dtype` has no entry in the table
    """
    if dtype == ScalarKind.BF16:
        return None

    numpy_types = {
        ScalarKind.F64: np.float64,
        ScalarKind.F32: np.float32,
        ScalarKind.F16: np.float16,
        ScalarKind.I8: np.int8,
        ScalarKind.B1: np.uint8,
    }
    # Values that are already NumPy types pass through untouched.
    already_numpy = dtype in numpy_types.values()
    return dtype if already_numpy else numpy_types[dtype]
154
+
155
+
156
def _normalize_metric(metric) -> MetricKind:
    """Resolve a metric description into a `MetricKind` where possible.

    :param metric: `None`, a case-insensitive metric name, or any other object
    :return: `MetricKind.Cos` for `None`, the looked-up kind for a string,
        and the argument itself, unchanged, for anything else
    :raises KeyError: If a string does not name a known metric
    """
    if metric is None:
        return MetricKind.Cos

    # Non-string values are handed back as-is; the caller validates them.
    if not isinstance(metric, str):
        return metric

    by_name = {
        "cos": MetricKind.Cos,
        "cosine": MetricKind.Cos,
        "ip": MetricKind.IP,
        "dot": MetricKind.IP,
        "inner_product": MetricKind.IP,
        "l2sq": MetricKind.L2sq,
        "l2_sq": MetricKind.L2sq,
        "haversine": MetricKind.Haversine,
        "divergence": MetricKind.Divergence,
        "pearson": MetricKind.Pearson,
        "hamming": MetricKind.Hamming,
        "tanimoto": MetricKind.Tanimoto,
        "sorensen": MetricKind.Sorensen,
    }
    name = metric.lower()
    return by_name[name]
179
+
180
+
181
+ def _is_buffer(obj: Any) -> bool:
182
+ """Check if the object is a buffer-like object.
183
+ More portable than `hasattr(obj, "__buffer__")`, which requires Python 3.11+."""
184
+ try:
185
+ memoryview(obj)
186
+ return True
187
+ except TypeError:
188
+ return False
189
+
190
+
191
def _search_in_compiled(
    compiled_callable: Callable,
    vectors: np.ndarray,
    *,
    log: Union[str, bool],
    progress: Optional[ProgressCallback],
    **kwargs,
) -> Union[Matches, BatchMatches]:
    """Run a compiled search routine and wrap its output into match objects.

    :param compiled_callable: Routine called as `compiled_callable(vectors, **kwargs)`,
        with an extra `progress=` keyword when a callback or a progress bar is used
    :param vectors: One query vector or a matrix with one query per row
    :param log: Truthy to show a progress bar; a string is used as its title
    :param progress: Optional callback annotated as `(int, int) -> bool`
    :param kwargs: Forwarded to `compiled_callable` unchanged
    :return: `Matches` for a single query, `BatchMatches` otherwise
    :raises AssertionError: On a non-array input, wrong rank, or a badly annotated callback
    """
    assert isinstance(vectors, np.ndarray), "Expects a NumPy array"
    assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector"
    assert not progress or _match_signature(progress, [int, int], bool), "Invalid callback"

    # A lone vector is treated as a batch of one.
    if vectors.ndim == 1:
        vectors = vectors.reshape(1, len(vectors))
    count_vectors = vectors.shape[0]

    progress_bar = None
    progress_callback = progress
    if log:
        progress_bar = tqdm(
            desc=log if isinstance(log, str) else "Search",
            total=count_vectors,
            unit="vector",
        )

        def update_progress_bar(processed: int, total: int) -> bool:
            # `tqdm.update` takes an increment, while the callback gets an absolute count.
            progress_bar.update(processed - progress_bar.n)
            if progress:
                return progress(processed, total)
            return True

        progress_callback = update_progress_bar

    try:
        if progress_callback:
            tuple_ = compiled_callable(vectors, progress=progress_callback, **kwargs)
        else:
            tuple_ = compiled_callable(vectors, **kwargs)
    finally:
        # Close the bar even if the compiled routine raises, so it does not linger.
        if progress_bar is not None:
            progress_bar.close()

    batch_matches = BatchMatches(*tuple_)
    return batch_matches[0] if count_vectors == 1 else batch_matches
243
+
244
+
245
def _add_to_compiled(
    compiled,
    *,
    keys,
    vectors,
    copy: bool,
    threads: int,
    log: Union[str, bool],
    progress: Optional[ProgressCallback],
) -> Union[int, np.ndarray]:
    """Insert vectors into a compiled index through its `add_many` method.

    :param compiled: Object exposing `__len__` and `add_many`
    :param keys: `None` to number the vectors starting from `len(compiled)`,
        a single key, or an iterable with one key per vector
    :param vectors: One vector or a matrix with one vector per row
    :param copy: Forwarded to `add_many`
    :param threads: Forwarded to `add_many`
    :param log: Truthy to show a progress bar; a string is used as its title
    :param progress: Optional callback annotated as `(int, int) -> bool`
    :return: Array of the keys that were passed to `add_many`
    :raises AssertionError: On malformed vectors, callback, or a key count mismatch
    """
    assert isinstance(vectors, np.ndarray), "Expects a NumPy array"
    assert not progress or _match_signature(progress, [int, int], bool), "Invalid callback"
    assert vectors.ndim == 1 or vectors.ndim == 2, "Expects a matrix or vector"
    if vectors.ndim == 1:
        vectors = vectors.reshape(1, len(vectors))

    # Validate or generate the keys
    count_vectors = vectors.shape[0]
    if keys is None:
        start_id = len(compiled)
        keys = np.arange(start_id, start_id + count_vectors, dtype=Key)
    else:
        if not isinstance(keys, Iterable):
            assert count_vectors == 1, "Each vector must have a key"
            keys = [keys]
        keys = np.array(keys).astype(Key)

    assert len(keys) == count_vectors

    progress_bar = None
    progress_callback = progress
    if log:
        progress_bar = tqdm(
            desc=log if isinstance(log, str) else "Add",
            total=count_vectors,
            unit="vector",
        )

        def update_progress_bar(processed: int, total: int) -> bool:
            # `tqdm.update` takes an increment, while the callback gets an absolute count.
            progress_bar.update(processed - progress_bar.n)
            return progress(processed, total) if progress else True

        progress_callback = update_progress_bar

    try:
        compiled.add_many(keys, vectors, copy=copy, threads=threads, progress=progress_callback)
    finally:
        # Close the bar even if the insertion raises, so it does not linger.
        if progress_bar is not None:
            progress_bar.close()

    return keys
301
+
302
+
303
@dataclass
class Match:
    """One search hit: the key of the found entry and its distance to the query."""

    key: int
    distance: float

    def to_tuple(self) -> tuple:
        """Return the hit as a `(key, distance)` pair."""
        pair = (self.key, self.distance)
        return pair
312
+
313
+
314
@dataclass
class Matches:
    """Search results for a single query."""

    # Parallel arrays: `keys[i]` is the key of the i-th hit, `distances[i]` its distance.
    keys: np.ndarray
    distances: np.ndarray

    visited_members: int = 0
    computed_distances: int = 0

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, index: int) -> Match:
        """Return the hit at `index` as a `Match`.

        Both Python and NumPy integers are accepted, so positions computed with
        NumPy (for example by `np.argmin`) can be used directly.

        :raises IndexError: If `index` is not an integer or is past the last hit
        """
        if isinstance(index, (int, np.integer)) and index < len(self):
            return Match(
                key=self.keys[index],
                distance=self.distances[index],
            )
        raise IndexError(f"`index` must be an integer under {len(self)}")

    def to_list(self) -> List[tuple]:
        """Convert to list of (key, distance) tuples."""
        return [(int(key), float(distance)) for key, distance in zip(self.keys, self.distances)]

    def __repr__(self) -> str:
        return f"usearch.Matches({len(self)})"
342
+
343
+
344
@dataclass
class BatchMatches(Sequence):
    """Search results for multiple queries in batch operations.

    Unused positions in arrays contain sentinel values (default keys, max distances).
    Access individual results via indexing: batch_matches[i] returns valid matches only.

    Attributes:
        keys: 2D array of shape (n_queries, k) containing match keys
        distances: 2D array of shape (n_queries, k) containing distances
        counts: 1D array of shape (n_queries,) with actual number of matches per query
        visited_members: Total graph nodes visited during search
        computed_distances: Total distance computations performed
    """

    keys: np.ndarray
    distances: np.ndarray
    counts: np.ndarray

    visited_members: int = 0
    computed_distances: int = 0

    def __len__(self) -> int:
        return len(self.counts)

    def __getitem__(self, index: int) -> Matches:
        """Return the valid matches of the query at `index`.

        Rows are trimmed to `counts[index]` entries. The per-query statistics are
        the batch totals divided evenly among the queries. Both Python and NumPy
        integers are accepted as `index`.

        :raises IndexError: If `index` is not an integer or is past the last query
        """
        if isinstance(index, (int, np.integer)) and index < len(self):
            valid = self.counts[index]
            return Matches(
                keys=self.keys[index, :valid],
                distances=self.distances[index, :valid],
                visited_members=self.visited_members // len(self),
                computed_distances=self.computed_distances // len(self),
            )
        raise IndexError(f"`index` must be an integer under {len(self)}")

    def to_list(self) -> List[tuple]:
        """Flatten matches for all queries into a list of `(key, distance)` tuples."""
        return [match.to_tuple() for row in range(len(self)) for match in self[row]]

    def mean_recall(self, expected: np.ndarray, count: Optional[int] = None) -> float:
        """Measures recall [0, 1] as of `Matches` that contain the corresponding
        `expected` entry anywhere among results."""
        return self.count_matches(expected, count=count) / len(expected)

    def count_matches(self, expected: np.ndarray, count: Optional[int] = None) -> int:
        """Measures recall [0, len(expected)] as of `Matches` that contain the corresponding
        `expected` entry anywhere among results.

        :param expected: One expected key per query
        :param count: How many leading columns of `keys` to inspect, all of them by default
        :return: Number of queries whose expected key was found, as a Python `int`
        :raises AssertionError: If `expected` and the batch differ in length
        """
        assert len(expected) == len(self)
        if count is None:
            count = self.keys.shape[1]

        if count == 1:
            # Vectorized comparison of the top hits against the expectations.
            recall = np.sum(self.keys[:, 0] == expected)
        else:
            recall = 0
            for i in range(len(self)):
                recall += expected[i] in self.keys[i, :count]
        # Both paths yield a plain integer, regardless of the NumPy scalar types involved.
        return int(recall)

    def __repr__(self) -> str:
        return f"usearch.BatchMatches({np.sum(self.counts)} across {len(self)} queries)"
408
+
409
+
410
@dataclass
class Clustering:
    """Result of assigning queries to centroids found in an `Index`.

    Holds the index, the queried keys, and the `BatchMatches` whose `keys`
    array names the centroid every query was matched to.
    """

    def __init__(
        self,
        index: Index,
        matches: BatchMatches,
        queries: Optional[np.ndarray] = None,
    ) -> None:
        # Without explicit queries, the keys exported by the compiled index are used.
        self.index = index
        self.queries = index._compiled.get_keys_in_slice() if queries is None else queries
        self.matches = matches

    def __repr__(self) -> str:
        total = len(self.queries)
        return f"usearch.Clustering(for {total} queries)"

    @property
    def centroids_popularity(self) -> Tuple[np.ndarray, np.ndarray]:
        """Unique centroid keys together with the number of times each one occurs."""
        centroid_keys = self.matches.keys
        return np.unique(centroid_keys, return_counts=True)

    def members_of(self, centroid: Key) -> np.ndarray:
        """Select the queries whose matched key equals `centroid`."""
        belongs = self.matches.keys.flatten() == centroid
        return self.queries[belongs]

    def subcluster(self, centroid: Key, **clustering_kwargs) -> Clustering:
        """Cluster only the members of `centroid`, forwarding extra arguments to `index.cluster`."""
        return self.index.cluster(keys=self.members_of(centroid), **clustering_kwargs)

    def plot_centroids_popularity(self):
        """Plot the sorted centroid sizes on a logarithmic Y axis with Matplotlib."""
        from matplotlib import pyplot as plt

        sizes = self.centroids_popularity[1]
        ranks = np.arange(len(sizes))
        plt.yscale("log")
        plt.plot(sorted(sizes), ranks)
        plt.show()

    @property
    def network(self):
        """Build a complete NetworkX graph over the centroids.

        Nodes carry the centroid `size`; every pair of centroids is connected by
        an edge holding their `index.pairwise_distance` as `distance`.
        """
        import networkx as nx

        keys, sizes = self.centroids_popularity

        graph = nx.Graph()
        for position in range(len(keys)):
            graph.add_node(keys[position], size=sizes[position])

        for i in range(len(keys)):
            for j in range(i):
                i_key, j_key = keys[i], keys[j]
                distance = self.index.pairwise_distance(i_key, j_key)
                graph.add_edge(i_key, j_key, distance=distance)

        return graph
462
+
463
+
464
class IndexedKeys(Sequence):
    """View of all keys in the index."""

    def __init__(self, index: Index) -> None:
        self.index = index

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(
        self,
        offset_offsets_or_slice: Union[int, np.ndarray, slice],
    ) -> Union[Key, np.ndarray]:
        """Look up keys by position.

        :param offset_offsets_or_slice: A single offset (negative values count from
            the end), an iterable of offsets, or a slice with a step of one
        :return: One key for a single offset, an array of keys otherwise
        :raises ValueError: If a slice with a step other than one is passed
        :raises IndexError: If a single offset falls outside of the view
        """
        if isinstance(offset_offsets_or_slice, slice):
            start, stop, step = offset_offsets_or_slice.indices(len(self))
            if step != 1:
                raise ValueError("Slicing with a step is not supported")
            # The compiled call is given the starting offset and `stop - start`.
            return self.index._compiled.get_keys_in_slice(start, stop - start)

        elif isinstance(offset_offsets_or_slice, Iterable):
            offsets = np.array(offset_offsets_or_slice)
            return self.index._compiled.get_keys_at_offsets(offsets)

        else:
            offset = int(offset_offsets_or_slice)
            if offset < 0:
                offset += len(self)
            if offset < 0 or offset >= len(self):
                raise IndexError("Index out of range")
            return self.index._compiled.get_key_at_offset(offset)

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        """Export all keys as a NumPy array, `Key`-typed unless `dtype` is given.

        :param dtype: Requested element type, `Key` by default
        :param copy: Accepted for compatibility with the NumPy 2 `__array__`
            protocol, which passes this keyword. It is not consulted: the
            result always goes through `astype`, which returns a new array.
        """
        if dtype is None:
            dtype = Key
        return self.index._compiled.get_keys_in_slice().astype(dtype)
499
+
500
+
501
+ class Index:
502
+ """Fast approximate nearest neighbor search for dense vectors.
503
+
504
+ Supports various distance metrics (cosine, euclidean, inner product, etc.)
505
+ and automatic precision optimization. Vector keys must be integers.
506
+ All vectors must have the same dimensionality.
507
+
508
+ Example:
509
+ >>> index = Index(ndim=128, metric='cos')
510
+ >>> index.add(42, np.random.rand(128))
511
+ >>> matches = index.search(query_vector, count=10)
512
+ """
513
+
514
def __init__(
    self,
    *,  # All arguments must be named
    ndim: int = 0,
    metric: MetricLike = MetricKind.Cos,
    dtype: Optional[DTypeLike] = None,
    connectivity: Optional[int] = None,
    expansion_add: Optional[int] = None,
    expansion_search: Optional[int] = None,
    multi: bool = False,
    path: Optional[os.PathLike] = None,
    view: bool = False,
    enable_key_lookups: bool = True,
) -> None:
    """Create the compiled index and, if `path` already exists, load or view it.

    :param ndim: Number of vector dimensions
    :type ndim: int, defaults to 0
        Required for some metrics, pre-set for others. Haversine, for example,
        only applies to 2-dimensional latitude/longitude coordinates, while
        Angular (Cos) and Euclidean (L2sq) apply to any number of dimensions.

    :param metric: Distance function
    :type metric: MetricLike, defaults to MetricKind.Cos
        A metric name, a `MetricKind`, or a `CompiledMetric`.
        Possible `MetricKind` values: IP, Cos, L2sq, Haversine, Pearson,
        Hamming, Tanimoto, Sorensen.

    :param dtype: Scalar type for internal vector storage
    :type dtype: Optional[DTypeLike], defaults to None
        For continuous metrics can be: f16, f32, f64, or i8.
        If nothing is provided, the type is selected from the metric kind
        and the reported hardware acceleration: bitwise metrics get `b1`,
        others get `bf16` or `f16` when accelerated, and `f32` otherwise.
        The representation in use can be checked later with `index.dtype`.

    :param connectivity: Connections per node in HNSW
    :type connectivity: Optional[int], defaults to None
        The original paper calls it "M". `None` selects `DEFAULT_CONNECTIVITY`.

    :param expansion_add: Traversal depth on insertions
    :type expansion_add: Optional[int], defaults to None
        The original paper calls it "efConstruction".
        `None` selects `DEFAULT_EXPANSION_ADD`.

    :param expansion_search: Traversal depth on queries
    :type expansion_search: Optional[int], defaults to None
        The original paper calls it "ef".
        `None` selects `DEFAULT_EXPANSION_SEARCH`.

    :param multi: Allow multiple vectors with the same key
    :type multi: bool, defaults to False
    :param path: Where to store the index
    :type path: Optional[os.PathLike], defaults to None
        If the path exists, its contents are loaded (or viewed) right away.
    :param view: Are we simply viewing an immutable index
    :type view: bool, defaults to False
    :param enable_key_lookups: Forwarded to the compiled index as is
    :type enable_key_lookups: bool, defaults to True

    :raises AssertionError: If a hyper-parameter is not an integer
    :raises ValueError: If `metric` resolves to an unsupported object
    """
    # Substitute the library defaults for the omitted hyper-parameters.
    connectivity = DEFAULT_CONNECTIVITY if connectivity is None else connectivity
    expansion_add = DEFAULT_EXPANSION_ADD if expansion_add is None else expansion_add
    expansion_search = DEFAULT_EXPANSION_SEARCH if expansion_search is None else expansion_search

    assert isinstance(connectivity, int), "Expects integer `connectivity`"
    assert isinstance(expansion_add, int), "Expects integer `expansion_add`"
    assert isinstance(expansion_search, int), "Expects integer `expansion_search`"

    resolved = _normalize_metric(metric)
    if isinstance(resolved, MetricKind):
        # Built-in metric: no user function is attached.
        self._metric_kind = resolved
        self._metric_jit = None
        self._metric_pointer = 0
        self._metric_signature = MetricSignature.ArrayArraySize
    elif isinstance(resolved, CompiledMetric):
        # User-supplied metric: unpack its description.
        self._metric_jit = resolved
        self._metric_kind = resolved.kind
        self._metric_pointer = resolved.pointer
        self._metric_signature = resolved.signature
    else:
        raise ValueError("The `metric` must be a `CompiledMetric` or a `MetricKind`")

    # Validate, that the right scalar type is defined
    scalar_kind = _normalize_dtype(dtype, ndim, self._metric_kind)
    self._compiled = _CompiledIndex(
        ndim=ndim,
        dtype=scalar_kind,
        connectivity=connectivity,
        expansion_add=expansion_add,
        expansion_search=expansion_search,
        multi=multi,
        enable_key_lookups=enable_key_lookups,
        metric_kind=self._metric_kind,
        metric_pointer=self._metric_pointer,
        metric_signature=self._metric_signature,
    )

    self.path = path
    has_saved_state = path is not None and os.path.exists(path)
    if has_saved_state:
        if view:
            self.view(path)
        else:
            self.load(path)
626
+
627
@staticmethod
def metadata(path_or_buffer: PathOrBuffer) -> Optional[dict]:
    """Read the metadata of a serialized index without constructing an `Index`.

    :param path_or_buffer: Filesystem path or a buffer-like object
    :type path_or_buffer: PathOrBuffer
    :return: Metadata produced by the native module, or `None` if the path does not exist
    :rtype: Optional[dict]

    Errors raised while reading the metadata propagate to the caller unchanged.
    """
    if _is_buffer(path_or_buffer):
        return _index_dense_metadata_from_buffer(path_or_buffer)

    path = os.fspath(path_or_buffer)
    if not os.path.exists(path):
        return None
    return _index_dense_metadata_from_path(path)
639
+
640
@staticmethod
def restore(path_or_buffer: PathOrBuffer, view: bool = False, **kwargs) -> Optional[Index]:
    """Rebuild an `Index` from a serialized one, configured from its metadata.

    :param path_or_buffer: Filesystem path or a buffer-like object
    :type path_or_buffer: PathOrBuffer
    :param view: Call `view` on the source instead of `load`
    :type view: bool, defaults to False
    :param kwargs: Extra arguments for the `Index` constructor
    :return: The restored index, or `None` if no metadata could be obtained
    :rtype: Optional[Index]
    """
    meta = Index.metadata(path_or_buffer)
    if not meta:
        return None

    # Dimensions, scalar kind and metric kind come from the stored metadata.
    restored = Index(
        ndim=meta["dimensions"],
        dtype=meta["kind_scalar"],
        metric=meta["kind_metric"],
        **kwargs,
    )

    attach = restored.view if view else restored.load
    attach(path_or_buffer)
    return restored
658
+
659
def __len__(self) -> int:
    """Length reported by the compiled index."""
    return len(self._compiled)
661
+
662
def add(
    self,
    keys: KeyOrKeysLike,
    vectors: VectorOrVectorsLike,
    *,
    copy: bool = True,
    threads: int = 0,
    log: Union[str, bool] = False,
    progress: Optional[ProgressCallback] = None,
) -> Union[int, np.ndarray]:
    """Inserts one or more vectors into the index.

    For maximal performance the `keys` and `vectors`
    should conform to the Python's "buffer protocol" spec.

    To index a single entry:
        keys: int, vectors: np.ndarray.
    To index many entries:
        keys: np.ndarray, vectors: np.ndarray.

    When working with extremely large indexes, you may want to
    pass `copy=False`, if you can guarantee the lifetime of the
    primary vectors store during the process of construction.

    :param keys: Unique identifier(s) for passed vectors
    :type keys: Optional[KeyOrKeysLike], can be `None`
        With `None`, sequential keys are generated, starting from the
        current length of the index.
    :param vectors: Vector or a row-major matrix
    :type vectors: VectorOrVectorsLike
        The underlying helper asserts that this is a 1D or 2D NumPy array.
    :param copy: Should the index store a copy of vectors
    :type copy: bool, defaults to True
    :param threads: Optimal number of cores to use
    :type threads: int, defaults to 0
    :param log: Whether to print the progress bar
    :type log: Union[str, bool], defaults to False
        A string is used as the title of the progress bar.
    :param progress: Callback to report stats of the progress and control it
    :type progress: Optional[ProgressCallback], defaults to None
    :return: Inserted key or keys
    :type: Union[int, np.ndarray]
        The helper returns the keys as an array, even for a single entry.
    """
    return _add_to_compiled(
        self._compiled,
        keys=keys,
        vectors=vectors,
        copy=copy,
        threads=threads,
        log=log,
        progress=progress,
    )
710
+
711
def search(
    self,
    vectors: VectorOrVectorsLike,
    count: int = 10,
    radius: float = math.inf,
    *,
    threads: int = 0,
    exact: bool = False,
    log: Union[str, bool] = False,
    progress: Optional[ProgressCallback] = None,
) -> Union[Matches, BatchMatches]:
    """Performs approximate nearest neighbors search for one or more queries.

    When searching with batch queries, returns BatchMatches that pre-allocates arrays
    for the requested `count` size. If fewer matches exist than requested (e.g., when
    count > index size), use individual query access via batch_matches[i] to get only
    valid results, or check batch_matches.counts to see actual result counts per query.

    :param vectors: Query vector or vectors.
    :type vectors: VectorOrVectorsLike
    :param count: Upper count on the number of matches to find
    :type count: int, defaults to 10
        When count > index size, only available vectors will be returned.
        For BatchMatches, unused positions contain sentinel values.
    :param radius: Accepted, but not forwarded to the compiled search by this method
    :type radius: float, defaults to math.inf
    :param threads: Optimal number of cores to use
    :type threads: int, defaults to 0
    :param exact: Perform exhaustive linear-time exact search
    :type exact: bool, defaults to False
    :param log: Whether to print the progress bar, default to False
    :type log: Union[str, bool], optional
    :param progress: Callback to report stats of the progress and control it
    :type progress: Optional[ProgressCallback], defaults to None
    :return: Matches for one or more queries
    :rtype: Union[Matches, BatchMatches]
        For single queries: Matches with only valid results
        For batch queries: BatchMatches - use indexing for individual results
    """

    # NOTE(review): `radius` is part of the signature but is absent from the call below.
    return _search_in_compiled(
        self._compiled.search_many,
        vectors,
        # Batch scheduling:
        log=log,
        # Search constraints:
        count=count,
        exact=exact,
        threads=threads,
        progress=progress,
    )
760
+
761
def contains(self, keys: KeyOrKeysLike) -> Union[bool, np.ndarray]:
    """Check the presence of one key or of every key in an iterable.

    :param keys: A single key or an iterable of keys
    :type keys: KeyOrKeysLike
    :return: Result of the compiled `contains_one` for a single key,
        or of `contains_many` for an iterable
    :rtype: Union[bool, np.ndarray]
    """
    if not isinstance(keys, Iterable):
        return self._compiled.contains_one(int(keys))
    many = np.array(keys, dtype=Key)
    return self._compiled.contains_many(many)
766
+
767
def __contains__(self, keys: KeyOrKeysLike) -> Union[bool, np.ndarray]:
    """Support for the `in` operator; delegates to `contains`."""
    return self.contains(keys)
769
+
770
def count(self, keys: KeyOrKeysLike) -> Union[int, np.ndarray]:
    """Count the entries stored under one key or under every key in an iterable.

    :param keys: A single key or an iterable of keys
    :type keys: KeyOrKeysLike
    :return: Result of the compiled `count_one` for a single key,
        or of `count_many` for an iterable
    :rtype: Union[int, np.ndarray]
    """
    if not isinstance(keys, Iterable):
        return self._compiled.count_one(int(keys))
    many = np.array(keys, dtype=Key)
    return self._compiled.count_many(many)
775
+
776
def get(
    self,
    keys: KeyOrKeysLike,
    dtype: Optional[DTypeLike] = None,
) -> Union[Optional[np.ndarray], Tuple[Optional[np.ndarray]]]:
    """Retrieve the vectors stored under one or more keys.

    For a single key the result is the first entry of the compiled lookup,
    and `None` entries are passed through as is. For several keys the result
    is either one array or a list with an entry per key, depending on what
    the compiled lookup returns.

    :param keys: One or more keys to lookup
    :type keys: KeyOrKeysLike
    :param dtype: Representation to export the vectors in
    :type dtype: Optional[DTypeLike], defaults to None
        When omitted, the `dtype` of the index is used, and `f32` is
        substituted if that type has no NumPy counterpart.
    :return: One or more keys lookup results
    :rtype: Union[Optional[np.ndarray], Tuple[Optional[np.ndarray]]]
    :raises NotImplementedError: If an explicitly requested `dtype` has no NumPy counterpart
    """
    if dtype:
        dtype = _normalize_dtype(dtype)
        view_dtype = _to_numpy_dtype(dtype)
        if view_dtype is None:
            raise NotImplementedError("The requested representation type is not supported by NumPy")
    else:
        dtype = self.dtype
        view_dtype = _to_numpy_dtype(dtype)
        if view_dtype is None:
            # No NumPy counterpart for the storage type, so export as `f32`.
            dtype, view_dtype = ScalarKind.F32, np.float32

    single = not isinstance(keys, Iterable)
    if single:
        keys = [keys]
    if isinstance(keys, np.ndarray):
        keys = keys.astype(Key)
    else:
        keys = np.array(keys, dtype=Key)

    def reinterpret(vectors):
        # Missing entries stay `None`; the rest are re-viewed with the NumPy type.
        return vectors if vectors is None else vectors.view(view_dtype)

    fetched = self._compiled.get_many(keys, dtype)
    if isinstance(fetched, np.ndarray):
        fetched = reinterpret(fetched)
    else:
        fetched = [reinterpret(entry) for entry in fetched]
    return fetched[0] if single else fetched
821
+
822
def __getitem__(self, keys: KeyOrKeysLike) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    """Looks up one or more keys from the `Index`, retrieving corresponding vectors.

    Equivalent to calling `get(keys)` with the default `dtype`.

    :param keys: One or more keys to lookup
    :type keys: KeyOrKeysLike
    :return: One or more keys lookup results, exactly as returned by `get`
    :rtype: Union[Optional[np.ndarray], Tuple[Optional[np.ndarray]]]
    """
    return self.get(keys)
836
+
837
def remove(
    self,
    keys: KeyOrKeysLike,
    *,
    compact: bool = False,
    threads: int = 0,
) -> Union[int, np.ndarray]:
    """Removes one or more vectors from the index.

    When working with extremely large indexes, you may want to
    mark some entries deleted, instead of rebuilding a filtered index.
    In other cases, rebuilding is the recommended approach.

    :param keys: A single key or an iterable of keys to remove
    :type keys: KeyOrKeysLike
    :param compact: Removes links to removed nodes (expensive), defaults to False
    :type compact: bool, optional
    :param threads: Optimal number of cores to use, defaults to 0
    :type threads: int, optional
    :return: Result of the compiled `remove_one` for a single key,
        or of `remove_many` for an iterable
    :type: Union[int, np.ndarray]
    """
    if isinstance(keys, Iterable):
        many = np.array(keys, dtype=Key)
        return self._compiled.remove_many(many, compact=compact, threads=threads)
    return self._compiled.remove_one(keys, compact=compact, threads=threads)
864
+
865
def __delitem__(self, keys: KeyOrKeysLike) -> Union[int, np.ndarray]:
    """Support for `del index[keys]`; delegates to `remove` with default options."""
    return self.remove(keys)
867
+
868
def rename(
    self,
    from_: KeyOrKeysLike,
    to: KeyOrKeysLike,
) -> Union[int, np.ndarray]:
    """Rename existing member vector or vectors.

    May be used in iterative clustering procedures, where one would iteratively
    relabel every vector with the name of the cluster an entry belongs to, until
    the system converges.

    Three combinations are dispatched to the compiled index: one key to one key,
    many keys to one key, and many keys to many keys.

    :param from_: One or more keys to be renamed
    :type from_: KeyOrKeysLike
    :param to: New name or names (of identical length as `from_`)
    :type to: KeyOrKeysLike
    :return: Number of vectors that were found and renamed
    :rtype: int
    """
    if not isinstance(from_, Iterable):
        return self._compiled.rename_one_to_one(int(from_), int(to))

    sources = np.array(from_, dtype=Key)
    if not isinstance(to, Iterable):
        return self._compiled.rename_many_to_one(sources, int(to))

    targets = np.array(to, dtype=Key)
    return self._compiled.rename_many_to_many(sources, targets)
897
+
898
@property
def jit(self) -> bool:
    """Tell whether the index was given a `CompiledMetric`.

    :return: True, if the provided `metric` was JIT-ed, i.e. a `CompiledMetric`
        is stored in the index; False for built-in metric kinds
    :rtype: bool
    """
    return self._metric_jit is not None
905
+
906
@property
def hardware_acceleration(self) -> str:
    """Describes the kind of hardware-acceleration support used in this instance.

    This indicates the type of hardware acceleration that is available and
    being utilized for the current index configuration, including the metric
    kind and number of dimensions. The value is read from the compiled index.

    :return: "auto" if no hardware acceleration is available, otherwise an ISA subset name.
    :rtype: str
    """
    return self._compiled.hardware_acceleration
918
+
919
@property
def size(self) -> int:
    """Returns the number of vectors currently indexed, as reported by the compiled index.

    :return: The number of vectors in the index.
    :rtype: int
    """
    return self._compiled.size
927
+
928
@property
def ndim(self) -> int:
    """Dimensionality of the indexed vectors, as reported by the native index.

    :return: The dimensionality of vectors in the index.
    :rtype: int
    """
    native = self._compiled
    return native.ndim
936
+
937
@property
def serialized_length(self) -> int:
    """Number of bytes needed to serialize the index, as reported natively.

    :return: The serialized length of the index in bytes.
    :rtype: int
    """
    native = self._compiled
    return native.serialized_length
945
+
946
@property
def metric_kind(self) -> Union[MetricKind, CompiledMetric]:
    """Kind of the metric used for distance calculations.

    When a JIT-compiled metric is stored, its `kind` field is reported;
    otherwise the stored built-in metric kind is returned.

    :return: The metric kind used in the index.
    :rtype: Union[MetricKind, CompiledMetric]
    """
    jit_metric = self._metric_jit
    if jit_metric:
        return jit_metric.kind
    return self._metric_kind
954
+
955
@property
def metric(self) -> Union[MetricKind, CompiledMetric]:
    """Metric object used for distance calculations.

    The stored JIT-compiled metric takes precedence; without one, the
    stored built-in metric kind is returned.

    :return: The metric used in the index.
    :rtype: Union[MetricKind, CompiledMetric]
    """
    jit_metric = self._metric_jit
    if jit_metric:
        return jit_metric
    return self._metric_kind
963
+
964
@metric.setter
def metric(self, metric: MetricLike):
    """Sets a new metric for the index.

    The metric is normalized first, then the native index is switched to it.
    Once the native call has returned, the cached `_metric_kind` and
    `_metric_jit` attributes are refreshed, so that the `metric`,
    `metric_kind`, and `jit` properties report the new metric instead of the
    one the index was created with.

    :param metric: The new metric to be used.
    :type metric: MetricLike
    :raises ValueError: If the metric is not of type `CompiledMetric` or `MetricKind`.
    """
    metric = _normalize_metric(metric)
    if isinstance(metric, MetricKind):
        metric_jit = None
        metric_kind = metric
        metric_pointer = 0
        metric_signature = MetricSignature.ArrayArraySize
    elif isinstance(metric, CompiledMetric):
        metric_jit = metric
        metric_kind = metric.kind
        metric_pointer = metric.pointer
        metric_signature = metric.signature
    else:
        raise ValueError("The `metric` must be a `CompiledMetric` or a `MetricKind`")

    result = self._compiled.change_metric(
        metric_kind=metric_kind,
        metric_pointer=metric_pointer,
        metric_signature=metric_signature,
    )

    # Update the Python-side cache only after the native switch succeeded;
    # if `change_metric` raises, the getters keep describing the old metric.
    self._metric_kind = metric_kind
    self._metric_jit = metric_jit
    return result
989
+
990
@property
def dtype(self) -> ScalarKind:
    """Scalar type of the stored vectors, as reported by the native index.

    :return: The data type of the vectors.
    :rtype: ScalarKind
    """
    native = self._compiled
    return native.dtype
998
+
999
@property
def connectivity(self) -> int:
    """Connectivity parameter of the index, as reported by the native index.

    This parameter controls how many neighbors each node in the graph is connected to.

    :return: The connectivity of the index.
    :rtype: int
    """
    native = self._compiled
    return native.connectivity
1009
+
1010
@property
def capacity(self) -> int:
    """Current capacity of the index, as reported by the native index.

    This indicates the maximum number of vectors that can be indexed without reallocation.

    :return: The capacity of the index.
    :rtype: int
    """
    native = self._compiled
    return native.capacity
1020
+
1021
@property
def memory_usage(self) -> int:
    """Memory footprint of the index in bytes, as reported by the native index.

    :return: The memory usage of the index.
    :rtype: int
    """
    native = self._compiled
    return native.memory_usage
1029
+
1030
@property
def expansion_add(self) -> int:
    """Expansion parameter applied on insertion, as reported by the native index.

    This parameter controls how many candidates are considered when adding new vectors to the index.

    :return: The expansion parameter for additions.
    :rtype: int
    """
    native = self._compiled
    return native.expansion_add
1040
+
1041
@property
def expansion_search(self) -> int:
    """Expansion parameter applied on lookup, as reported by the native index.

    This parameter controls how many candidates are considered when searching in the index.

    :return: The expansion parameter for searches.
    :rtype: int
    """
    native = self._compiled
    return native.expansion_search
1051
+
1052
@expansion_add.setter
def expansion_add(self, v: int):
    """Forwards a new insertion-time expansion value to the native index.

    :param v: The new expansion parameter for additions.
    :type v: int
    """
    native = self._compiled
    native.expansion_add = v
1060
+
1061
@expansion_search.setter
def expansion_search(self, v: int):
    """Forwards a new search-time expansion value to the native index.

    :param v: The new expansion parameter for searches.
    :type v: int
    """
    native = self._compiled
    native.expansion_search = v
1069
+
1070
def save(
    self,
    path_or_buffer: Union[str, os.PathLike, NoneType] = None,
    progress: Optional[ProgressCallback] = None,
) -> Optional[bytes]:
    """Serializes the index into a file, or into an in-memory buffer.

    When `path_or_buffer` is omitted, `self.path` is used as the destination.
    When both are `None`, the index is serialized into a buffer, which is
    returned to the caller.

    :param path_or_buffer: The path or buffer where the index will be saved.
    :type path_or_buffer: Union[str, os.PathLike, NoneType], optional
    :param progress: A callback function for progress tracking.
    :type progress: Optional[ProgressCallback], optional
    :return: The index data as bytes if saving to a buffer, otherwise None.
    :rtype: Optional[bytes]
    """
    assert not progress or _match_signature(progress, [int, int], bool), "Invalid callback signature"

    destination = self.path if path_or_buffer is None else path_or_buffer
    if destination is not None:
        self._compiled.save_index_to_path(os.fspath(destination), progress)
        return None

    # No explicit target and no default path: hand the bytes back instead.
    return self._compiled.save_index_to_buffer(progress)
1093
+
1094
def load(
    self,
    path_or_buffer: Union[PathOrBuffer, NoneType] = None,
    progress: Optional[ProgressCallback] = None,
):
    """Loads the index contents from a file or from a buffer.

    When `path_or_buffer` is omitted, `self.path` is used as the source.

    :param path_or_buffer: The path or buffer from which the index will be loaded.
    :type path_or_buffer: Union[str, os.PathLike, BytesLike, NoneType], optional
    :param progress: A callback function for progress tracking.
    :type progress: Optional[ProgressCallback], optional
    :raises ValueError: If neither `path_or_buffer` nor `self.path` is defined.
    :raises FileNotFoundError: If the source is a path that does not exist.
    """
    assert not progress or _match_signature(progress, [int, int], bool), "Invalid callback signature"

    source = self.path if path_or_buffer is None else path_or_buffer
    if source is None:
        raise ValueError("path_or_buffer is required")

    if _is_buffer(source):
        self._compiled.load_index_from_buffer(source, progress)
        return

    file_path = os.fspath(source)
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    self._compiled.load_index_from_path(file_path, progress)
1123
+
1124
def view(
    self,
    path_or_buffer: Union[PathOrBuffer, NoneType] = None,
    progress: Optional[ProgressCallback] = None,
):
    """Maps the index from a file or buffer without loading it into memory.

    When `path_or_buffer` is omitted, `self.path` is used as the source.
    Unlike `load`, the existence of a path is not checked here before it is
    handed to the native index.

    :param path_or_buffer: The path or buffer to map the index from.
    :type path_or_buffer: Union[str, os.PathLike, bytes, bytearray, NoneType], optional
    :param progress: A callback function for progress tracking.
    :type progress: Optional[ProgressCallback], optional
    :raises ValueError: If neither `path_or_buffer` nor `self.path` is defined.
    """
    assert not progress or _match_signature(progress, [int, int], bool), "Invalid callback signature"

    source = self.path if path_or_buffer is None else path_or_buffer
    if source is None:
        raise ValueError("path_or_buffer is required")

    if not _is_buffer(source):
        self._compiled.view_index_from_path(os.fspath(source), progress)
    else:
        self._compiled.view_index_from_buffer(source, progress)
1148
+
1149
def clear(self):
    """Erases all vectors from the index, preserving the allocated space for future insertions.

    The work is delegated to the native index's `clear`.
    """
    native = self._compiled
    native.clear()
1152
+
1153
def reset(self):
    """Erases all data from the index, closes any open files, and returns allocated memory to the OS.

    Does nothing when the instance has no `_compiled` attribute, so it is
    safe to call on an object whose native index was never attached.
    """
    if hasattr(self, "_compiled"):
        self._compiled.reset()
1158
+
1159
def __del__(self):
    """Finalizer: releases the index resources by calling `reset`."""
    release = self.reset
    release()
1162
+
1163
def copy(self) -> Index:
    """Creates a copy of the current index.

    A new `Index` is first constructed from this index's settings, and its
    native handle is then replaced with a copy of this index's native handle.

    :return: A new instance of the Index class with the same configuration and data.
    :rtype: Index
    """
    settings = {
        "ndim": self.ndim,
        "metric": self.metric,
        "dtype": self.dtype,
        "connectivity": self.connectivity,
        "expansion_add": self.expansion_add,
        "expansion_search": self.expansion_search,
        "path": self.path,
    }
    duplicate = Index(**settings)
    duplicate._compiled = self._compiled.copy()
    return duplicate
1180
+
1181
def join(
    self,
    other: Index,
    max_proposals: int = 0,
    exact: bool = False,
    progress: Optional[ProgressCallback] = None,
) -> Dict[Key, Key]:
    """Performs "Semantic Join" or pairwise matching between `self` & `other` index.

    Is different from `search`, as no collisions are allowed in resulting pairs.
    Uses the concept of "Stable Marriages" from Combinatorics, famous for the 2012
    Nobel Prize in Economics. The matching itself is done by the native index.

    :param other: Another index.
    :type other: Index
    :param max_proposals: Limit on candidates evaluated per vector, defaults to 0
    :type max_proposals: int, optional
    :param exact: Controls if underlying `search` should be exact, defaults to False
    :type exact: bool, optional
    :param progress: Callback to report stats of the progress and control it
    :type progress: Optional[ProgressCallback], defaults to None
    :return: Mapping from keys of `self` to keys of `other`
    :rtype: Dict[Key, Key]
    """
    assert not progress or _match_signature(progress, [int, int], bool), "Invalid callback signature"

    native_other = other._compiled
    pairs = self._compiled.join(
        other=native_other,
        max_proposals=max_proposals,
        exact=exact,
        progress=progress,
    )
    return pairs
1212
+
1213
def cluster(
    self,
    *,
    vectors: Optional[np.ndarray] = None,
    keys: Optional[np.ndarray] = None,
    min_count: Optional[int] = None,
    max_count: Optional[int] = None,
    threads: int = 0,
    log: Union[str, bool] = False,
    progress: Optional[ProgressCallback] = None,
) -> Clustering:
    """
    Clusters already indexed or provided `vectors`, mapping them to various centroids.

    Exactly one source is used: the provided `vectors`, the provided member
    `keys`, or - when neither is given - all keys obtained from the native index.

    :param vectors: External vectors to cluster; mutually exclusive with `keys`
    :type vectors: Optional[np.ndarray]
    :param keys: Keys of already indexed members to cluster
    :type keys: Optional[np.ndarray]
    :param min_count: Forwarded to the native routine; `None` is passed as 0
    :type min_count: Optional[int], defaults to None
    :param max_count: Forwarded to the native routine; `None` is passed as 0
    :type max_count: Optional[int], defaults to None

    :param threads: Optimal number of cores to use,
    :type threads: int, defaults to 0
    :param log: Accepted for signature compatibility; not used by this method
    :type log: Union[str, bool], defaults to False
    :param progress: Callback to report stats of the progress and control it
    :type progress: Optional[ProgressCallback], defaults to None
    :return: Clustering built from the native results and the keys used
    :rtype: Clustering
    """
    assert not progress or _match_signature(progress, [int, int], bool), "Invalid callback signature"

    min_count = 0 if min_count is None else min_count
    max_count = 0 if max_count is None else max_count
    shared_options = dict(
        min_count=min_count,
        max_count=max_count,
        threads=threads,
        progress=progress,
    )

    if vectors is None:
        # Cluster indexed members: fall back to every key in the index.
        if keys is None:
            keys = self._compiled.get_keys_in_slice()
        if not isinstance(keys, np.ndarray):
            keys = np.array(keys)
        keys = keys.astype(Key)
        results = self._compiled.cluster_keys(keys, **shared_options)
    else:
        assert keys is None, "You can either cluster vectors or member keys"
        results = self._compiled.cluster_vectors(vectors, **shared_options)

    return Clustering(self, BatchMatches(*results), keys)
1273
+
1274
def pairwise_distance(self, left: KeyOrKeysLike, right: KeyOrKeysLike) -> Union[np.ndarray, float]:
    """Computes the distance between indexed entries identified by their keys.

    If `left` and `right` are single keys, returns the distance between them.
    If `left` and `right` are arrays of keys, returns a matrix of pairwise distances.
    Mixing a single key with an iterable of keys is rejected by an assertion.

    :param left: A single key or an iterable of keys.
    :type left: KeyOrKeysLike
    :param right: A single key or an iterable of keys.
    :type right: KeyOrKeysLike
    :return: Pairwise distance(s) between the provided keys.
    :rtype: Union[np.ndarray, float]
    """
    left_is_batch = isinstance(left, Iterable)
    right_is_batch = isinstance(right, Iterable)
    assert left_is_batch == right_is_batch

    if left_is_batch:
        left_keys = np.array(left).astype(Key)
        right_keys = np.array(right).astype(Key)
        return self._compiled.pairwise_distances(left_keys, right_keys)

    return self._compiled.pairwise_distance(int(left), int(right))
1295
+
1296
@property
def keys(self) -> IndexedKeys:
    """View over all keys currently indexed, wrapped in `IndexedKeys`.

    :return: All indexed keys.
    :rtype: IndexedKeys
    """
    keys_view = IndexedKeys(self)
    return keys_view
1304
+
1305
@property
def vectors(self) -> np.ndarray:
    """Retrieves all vectors associated with the indexed keys.

    Implemented as a `get` call over the `keys` property.

    :return: Array of vectors.
    :rtype: np.ndarray
    """
    all_keys = self.keys
    return self.get(all_keys)
1313
+
1314
@property
def max_level(self) -> int:
    """Maximum level in the multi-level graph, as reported by the native index.

    :return: The maximum level in the graph.
    :rtype: int
    """
    native = self._compiled
    return native.max_level
1322
+
1323
@property
def nlevels(self) -> int:
    """Number of levels in the multi-level graph.

    Computed as the native `max_level` plus one.

    :return: Number of levels in the graph.
    :rtype: int
    """
    top_level = self._compiled.max_level
    return top_level + 1
1331
+
1332
@property
def multi(self) -> bool:
    """Indicates whether the index supports multi-value entries, as reported natively.

    :return: True if the index supports multi-value entries, False otherwise.
    :rtype: bool
    """
    native = self._compiled
    return native.multi
1340
+
1341
@property
def stats(self) -> _CompiledIndexStats:
    """Accumulated statistics for the entire multi-level graph, read from the native index.

    :return: Statistics for the entire multi-level graph.
    :rtype: _CompiledIndexStats

    Statistics:
        - `nodes` (int): Number of nodes in the graph.
        - `edges` (int): Number of edges in the graph.
        - `max_edges` (int): Maximum possible number of edges in the graph.
        - `allocated_bytes` (int): Memory allocated for the graph.
    """
    native = self._compiled
    return native.stats
1355
+
1356
@property
def levels_stats(self) -> List[_CompiledIndexStats]:
    """Accumulated statistics for each level of the graph, read from the native index.

    :return: List of statistics for each level of the graph.
    :rtype: List[_CompiledIndexStats]

    Statistics for each level:
        - `nodes` (int): Number of nodes in the level.
        - `edges` (int): Number of edges in the level.
        - `max_edges` (int): Maximum possible number of edges in the level.
        - `allocated_bytes` (int): Memory allocated for the level.
    """
    native = self._compiled
    return native.levels_stats
1370
+
1371
def level_stats(self, level: int) -> _CompiledIndexStats:
    """Statistics for one specific level of the graph, queried from the native index.

    :param level: The level for which to retrieve statistics.
    :type level: int
    :return: Statistics for the specified level.
    :rtype: _CompiledIndexStats

    Statistics:
        - `nodes` (int): Number of nodes in the level.
        - `edges` (int): Number of edges in the level.
        - `max_edges` (int): Maximum possible number of edges in the level.
        - `allocated_bytes` (int): Memory allocated for the level.
    """
    native = self._compiled
    return native.level_stats(level)
1386
+
1387
@property
def specs(self) -> Dict[str, Union[str, int, bool]]:
    """Returns the specifications of the index.

    A dictionary is returned in every case, as the annotation promises.
    When the instance has no `_compiled` attribute, the dictionary holds
    only the `"type"` entry, set to `"usearch.Index(failed)"`, so callers
    can keep using mapping operations such as `specs["type"]` or `.get(...)`.

    :return: Dictionary of index specifications.
    :rtype: Dict[str, Union[str, int, bool]]
    """
    if not hasattr(self, "_compiled"):
        # Keep the same marker text used by `__repr__`, but as a mapping.
        return {"type": "usearch.Index(failed)"}
    return {
        "type": "usearch.Index",
        "ndim": self.ndim,
        "multi": self.multi,
        "connectivity": self.connectivity,
        "expansion_add": self.expansion_add,
        "expansion_search": self.expansion_search,
        "size": self.size,
        "jit": self.jit,
        "hardware_acceleration": self.hardware_acceleration,
        "metric_kind": self.metric_kind,
        "dtype": self.dtype,
        "path": self.path,
        "compiled_with_openmp": USES_OPENMP,
        "compiled_with_simsimd": USES_SIMSIMD,
        "compiled_with_native_f16": USES_FP16LIB,
    }
1413
+
1414
def __repr__(self) -> str:
    """Builds the one-line textual summary of the index.

    Falls back to a fixed marker string when the instance has no `_compiled`
    attribute.

    :return: String representation of the index.
    :rtype: str
    """
    if not hasattr(self, "_compiled"):
        return "usearch.Index(failed)"
    return (
        f"usearch.Index({self.dtype} x {self.ndim}, {self.metric_kind}, "
        f"multi: {self.multi}, connectivity: {self.connectivity}, "
        f"expansion: {self.expansion_add} & {self.expansion_search}, "
        f"{len(self):,} vectors in {self.nlevels} levels, "
        f"{self.hardware_acceleration} hardware acceleration)"
    )
1438
+
1439
def __repr_pretty__(self) -> str:
    """Builds the multi-line report describing configuration, binary, and state.

    Falls back to a fixed marker string when the instance has no `_compiled`
    attribute. One extra line is emitted per graph level with its node count.

    :return: Pretty-printed string representation of the index.
    :rtype: str
    """
    if not hasattr(self, "_compiled"):
        return "usearch.Index(failed)"

    per_level = []
    for level in range(self.nlevels):
        per_level.append(f"--- {level}. {self.level_stats(level).nodes:,} nodes")

    report = [
        "usearch.Index",
        "- config",
        f"-- data type: {self.dtype}",
        f"-- dimensions: {self.ndim}",
        f"-- metric: {self.metric_kind}",
        f"-- multi: {self.multi}",
        f"-- connectivity: {self.connectivity}",
        f"-- expansion on addition :{self.expansion_add} candidates",
        f"-- expansion on search: {self.expansion_search} candidates",
        "- binary",
        f"-- uses OpenMP: {USES_OPENMP}",
        f"-- uses SimSIMD: {USES_SIMSIMD}",
        f"-- supports half-precision: {USES_FP16LIB}",
        f"-- uses hardware acceleration: {self.hardware_acceleration}",
        "- state",
        f"-- size: {self.size:,} vectors",
        f"-- memory usage: {self.memory_usage:,} bytes",
        f"-- max level: {self.max_level}",
    ]
    report.extend(per_level)
    return "\n".join(report)
1472
+
1473
def _repr_pretty_(self, printer, cycle):
    """Pretty-printing hook for interactive environments.

    Writes the text produced by `__repr_pretty__` through `printer.text`.
    The `cycle` flag is accepted but not consulted.

    :param printer: The pretty printer instance.
    :type printer: Any
    :param cycle: Cycle flag indicating recursion.
    :type cycle: bool
    """
    rendered = self.__repr_pretty__()
    printer.text(rendered)
1482
+
1483
+
1484
class Indexes:
    """Group of several indexes exposed through a single native `Indexes` handle.

    Members can be supplied as in-memory `Index` objects or as file paths,
    either at construction time or later through `merge` / `merge_path`.
    """

    def __init__(
        self,
        indexes: Iterable[Index] = (),
        paths: Iterable[os.PathLike] = (),
        view: bool = False,
        threads: int = 0,
    ) -> None:
        """Creates the native container and registers the initial members.

        :param indexes: In-memory indexes to merge, defaults to none
        :type indexes: Iterable[Index], optional
        :param paths: Paths of serialized indexes to merge, defaults to none.
            Any iterable of `str` or path-like objects is accepted; every
            entry is converted with `os.fspath`, matching `merge_path`.
        :type paths: Iterable[os.PathLike], optional
        :param view: Forwarded to the native `merge_paths`, defaults to False
        :type view: bool, optional
        :param threads: Forwarded to the native `merge_paths`, defaults to 0
        :type threads: int, optional
        """
        self._compiled = _CompiledIndexes()
        for index in indexes:
            self._compiled.merge(index._compiled)
        # Materialize into a list of plain paths, so generators and
        # `pathlib.Path` objects are handled like in `merge_path`.
        plain_paths = [os.fspath(path) for path in paths]
        self._compiled.merge_paths(plain_paths, view=view, threads=threads)

    def merge(self, index: Index):
        """Registers one more in-memory index in the native container.

        :param index: Index whose native handle is merged
        :type index: Index
        """
        self._compiled.merge(index._compiled)

    def merge_path(self, path: os.PathLike):
        """Registers one more serialized index, identified by its path.

        :param path: Location of the serialized index
        :type path: os.PathLike
        """
        self._compiled.merge_path(os.fspath(path))

    def __len__(self) -> int:
        """Returns the length reported by the native container."""
        return self._compiled.__len__()

    def search(
        self,
        vectors,
        count: int = 10,
        *,
        threads: int = 0,
        exact: bool = False,
        progress: Optional[ProgressCallback] = None,
    ):
        """Searches across the member indexes via the native `search_many`.

        The call is routed through `_search_in_compiled` with logging disabled.

        :param vectors: Query vector or vectors
        :param count: Upper limit on the number of matches, defaults to 10
        :type count: int, optional
        :param threads: Optimal number of cores to use, defaults to 0
        :type threads: int, optional
        :param exact: Forwarded as the `exact` search constraint, defaults to False
        :type exact: bool, optional
        :param progress: Callback to report stats of the progress and control it
        :type progress: Optional[ProgressCallback], defaults to None
        :return: Whatever `_search_in_compiled` produces for the queries
        """
        return _search_in_compiled(
            self._compiled.search_many,
            vectors,
            # Batch scheduling:
            log=False,
            # Search constraints:
            count=count,
            exact=exact,
            threads=threads,
            progress=progress,
        )
1526
+
1527
+
1528
def search(
    dataset: np.ndarray,
    query: np.ndarray,
    count: int = 10,
    metric: MetricLike = MetricKind.Cos,
    *,
    exact: bool = False,
    threads: int = 0,
    log: Union[str, bool] = False,
    progress: Optional[ProgressCallback] = None,
) -> Union[Matches, BatchMatches]:
    """Shortcut for search, that can avoid index construction. Particularly useful for
    tiny datasets, where brute-force exact search works fast enough.

    With `exact=False` a temporary `Index` is built over `dataset` and then
    queried. With `exact=True` no index is built and the native exact search
    routine is invoked on the raw matrix instead.

    :param dataset: Row-major matrix.
    :type dataset: np.ndarray
    :param query: Query vector or vectors (also row-major), to find in `dataset`.
    :type query: np.ndarray

    :param count: Upper count on the number of matches to find, defaults to 10
    :type count: int, optional

    :param metric: Distance function
    :type metric: MetricLike, defaults to MetricKind.Cos
        Kind of the distance function, or the Numba `cfunc` JIT-compiled object.
        Possible `MetricKind` values: IP, Cos, L2sq, Haversine, Pearson,
        Hamming, Tanimoto, Sorensen.

    :param threads: Optimal number of cores to use, defaults to 0
    :type threads: int, optional
    :param exact: Perform exhaustive linear-time exact search, defaults to False
    :type exact: bool, optional
    :param log: Whether to print the progress bar, default to False
    :type log: Union[str, bool], optional
    :param progress: Callback to report stats of the progress and control it
    :type progress: Optional[ProgressCallback], defaults to None
    :raises ValueError: In exact mode, if the normalized metric is neither a
        `CompiledMetric` nor a `MetricKind`
    :return: Matches for one or more queries
    :rtype: Union[Matches, BatchMatches]
    """
    assert not progress or _match_signature(progress, [int, int], bool), "Invalid callback signature"
    assert dataset.ndim == 2, "Dataset must be a matrix, with a vector in each row"

    if exact:
        normalized = _normalize_metric(metric)
        if isinstance(normalized, MetricKind):
            kind = normalized
            pointer = 0
            signature_kind = MetricSignature.ArrayArraySize
        elif isinstance(normalized, CompiledMetric):
            kind = normalized.kind
            pointer = normalized.pointer
            signature_kind = normalized.signature
        else:
            raise ValueError("The `metric` must be a `CompiledMetric` or a `MetricKind`")

        def brute_force_batch(query, **kwargs):
            # Queries are cast to the dataset's scalar type before the native call.
            assert dataset.shape[1] == query.shape[1], "Number of dimensions differs"
            if dataset.dtype != query.dtype:
                query = query.astype(dataset.dtype)

            return _exact_search(
                dataset,
                query,
                metric_kind=kind,
                metric_signature=signature_kind,
                metric_pointer=pointer,
                **kwargs,
            )

        return _search_in_compiled(
            brute_force_batch,
            query,
            # Batch scheduling:
            log=log,
            # Search constraints:
            count=count,
            threads=threads,
            progress=progress,
        )

    # Approximate path: build a throwaway index over the dataset and query it.
    temporary = Index(
        ndim=dataset.shape[1],
        metric=metric,
        dtype=dataset.dtype,
    )
    temporary.add(
        None,
        dataset,
        threads=threads,
        log=log,
        progress=progress,
    )
    return temporary.search(
        query,
        count,
        threads=threads,
        log=log,
        progress=progress,
    )
1627
+
1628
+
1629
def kmeans(
    X,
    k,
    metric: str = "l2sq",
    dtype: str = "bf16",
    max_iterations: int = 300,
    inertia_threshold: float = 1e-4,
    max_seconds: float = 60.0,
    min_shifts: float = 0.01,
    seed: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Performs KMeans clustering on a dataset using the USearch library with mixed-precision support.

    The dataset `X` is split into `k` clusters by the native routine, which iteratively assigns
    points to the nearest centroids and updates the centroids from the points assigned to them.
    Mixed-precision types are supported, as well as early termination based on the number of
    iterations, inertia threshold, maximum runtime, and minimum point shifts.

    Parameters
    ----------
    X : numpy.ndarray
        The input data, where each row represents a data point and each column represents a feature.
    k : int
        The number of clusters to form.
    metric : str, optional
        The distance metric used to calculate the distance between points and centroids.
        Default is "l2sq" (squared Euclidean distance). Cosine "cos" distance is also supported.
    dtype : str, optional
        The data type used for clustering calculations. Default is "bf16" (Brain Float 16).
        Other supported types include "f32" (float32) and "f64" (float64), "f16" (float16),
        "i8" (int8), and b1 (boolean) bit-packed vectors.
    max_iterations : int, optional
        The maximum number of iterations the algorithm should run. Default is 300.
    inertia_threshold : float, optional
        The threshold for inertia (sum of squared distances to centroids) to terminate early.
        When the change in inertia between iterations falls below this value, the algorithm stops.
        Default is 1e-4.
    max_seconds : float, optional
        The maximum allowable runtime for the algorithm in seconds. If exceeded, the algorithm
        terminates early. Default is 60.0 seconds.
    min_shifts : float, optional
        The minimum fraction of points that must change their assigned cluster between iterations
        to continue. If fewer than this fraction of points change clusters, the algorithm terminates.
        Default is 0.01 (1% of the total points).
    seed : int, optional
        The random seed used to initialize the centroids. Default is None, in which case
        a random unsigned 64-bit value is drawn with NumPy.

    Returns
    -------
    assignments : numpy.ndarray
        An array containing the index of the assigned cluster for each point in the dataset.
    distances : numpy.ndarray
        An array containing the distance of each point to its assigned cluster centroid.
    centroids : numpy.ndarray
        The final centroids of the clusters.

    Raises
    ------
    ValueError
        If any of the input parameters are invalid, such as the number of clusters being greater
        than the number of data points. Validation happens in the helpers and the native routine
        called from here, not in this function's own body.

    Example
    -------
    >>> X = np.random.rand(100, 10)
    >>> k = 5
    >>> assignments, distances, centroids = usearch.index.kmeans(X, k)
    """
    metric_kind = _normalize_metric(metric)
    scalar_kind = _normalize_dtype(dtype, ndim=X.shape[1], metric=metric_kind)

    # Generating a 64-bit unsigned integer in NumPy may be somewhat tricky.
    if seed is None:
        seed = np.random.default_rng().integers(0, 2**64, dtype=np.uint64)

    labels, point_distances, centers = _kmeans(
        X,
        k,
        metric_kind=metric_kind,
        max_iterations=max_iterations,
        max_seconds=max_seconds,
        min_shifts=min_shifts,
        inertia_threshold=inertia_threshold,
        dtype=scalar_kind,
        seed=seed,
    )
    return labels, point_distances, centers