faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexHNSW.h>
11
9
 
12
10
  #include <omp.h>
@@ -17,43 +15,26 @@
17
15
  #include <cstdlib>
18
16
  #include <cstring>
19
17
 
18
+ #include <limits>
19
+ #include <memory>
20
20
  #include <queue>
21
+ #include <random>
21
22
  #include <unordered_set>
22
23
 
23
- #include <stdint.h>
24
24
  #include <sys/stat.h>
25
25
  #include <sys/types.h>
26
+ #include <cstdint>
26
27
 
27
28
  #include <faiss/Index2Layer.h>
28
29
  #include <faiss/IndexFlat.h>
29
30
  #include <faiss/IndexIVFPQ.h>
30
31
  #include <faiss/impl/AuxIndexStructures.h>
31
32
  #include <faiss/impl/FaissAssert.h>
32
- #include <faiss/utils/Heap.h>
33
+ #include <faiss/impl/ResultHandler.h>
33
34
  #include <faiss/utils/distances.h>
34
35
  #include <faiss/utils/random.h>
35
36
  #include <faiss/utils/sorting.h>
36
37
 
37
- extern "C" {
38
-
39
- /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
40
-
41
- int sgemm_(
42
- const char* transa,
43
- const char* transb,
44
- FINTEGER* m,
45
- FINTEGER* n,
46
- FINTEGER* k,
47
- const float* alpha,
48
- const float* a,
49
- FINTEGER* lda,
50
- const float* b,
51
- FINTEGER* ldb,
52
- float* beta,
53
- float* c,
54
- FINTEGER* ldc);
55
- }
56
-
57
38
  namespace faiss {
58
39
 
59
40
  using MinimaxHeap = HNSW::MinimaxHeap;
@@ -68,35 +49,6 @@ HNSWStats hnsw_stats;
68
49
 
69
50
  namespace {
70
51
 
71
- /* Wrap the distance computer into one that negates the
72
- distances. This makes supporting INNER_PRODUCE search easier */
73
-
74
- struct NegativeDistanceComputer : DistanceComputer {
75
- /// owned by this
76
- DistanceComputer* basedis;
77
-
78
- explicit NegativeDistanceComputer(DistanceComputer* basedis)
79
- : basedis(basedis) {}
80
-
81
- void set_query(const float* x) override {
82
- basedis->set_query(x);
83
- }
84
-
85
- /// compute distance of vector i to current query
86
- float operator()(idx_t i) override {
87
- return -(*basedis)(i);
88
- }
89
-
90
- /// compute distance between two stored vectors
91
- float symmetric_dis(idx_t i, idx_t j) override {
92
- return -basedis->symmetric_dis(i, j);
93
- }
94
-
95
- virtual ~NegativeDistanceComputer() {
96
- delete basedis;
97
- }
98
- };
99
-
100
52
  DistanceComputer* storage_distance_computer(const Index* storage) {
101
53
  if (is_similarity_metric(storage->metric_type)) {
102
54
  return new NegativeDistanceComputer(storage->get_distance_computer());
@@ -175,7 +127,9 @@ void hnsw_add_vertices(
175
127
 
176
128
  int i1 = n;
177
129
 
178
- for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
130
+ for (int pt_level = hist.size() - 1;
131
+ pt_level >= !index_hnsw.init_level0;
132
+ pt_level--) {
179
133
  int i0 = i1 - hist[pt_level];
180
134
 
181
135
  if (verbose) {
@@ -192,9 +146,8 @@ void hnsw_add_vertices(
192
146
  {
193
147
  VisitedTable vt(ntotal);
194
148
 
195
- DistanceComputer* dis =
196
- storage_distance_computer(index_hnsw.storage);
197
- ScopeDeleter1<DistanceComputer> del(dis);
149
+ std::unique_ptr<DistanceComputer> dis(
150
+ storage_distance_computer(index_hnsw.storage));
198
151
  int prev_display =
199
152
  verbose && omp_get_thread_num() == 0 ? 0 : -1;
200
153
  size_t counter = 0;
@@ -212,7 +165,13 @@ void hnsw_add_vertices(
212
165
  continue;
213
166
  }
214
167
 
215
- hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);
168
+ hnsw.add_with_locks(
169
+ *dis,
170
+ pt_level,
171
+ pt_id,
172
+ locks,
173
+ vt,
174
+ index_hnsw.keep_max_size_level0 && (pt_level == 0));
216
175
 
217
176
  if (prev_display >= 0 && i - i0 > prev_display + 10000) {
218
177
  prev_display = i - i0;
@@ -232,7 +191,11 @@ void hnsw_add_vertices(
232
191
  }
233
192
  i1 = i0;
234
193
  }
235
- FAISS_ASSERT(i1 == 0);
194
+ if (index_hnsw.init_level0) {
195
+ FAISS_ASSERT(i1 == 0);
196
+ } else {
197
+ FAISS_ASSERT((i1 - hist[0]) == 0);
198
+ }
236
199
  }
237
200
  if (verbose) {
238
201
  printf("Done in %.3f ms\n", getmillisecs() - t0);
@@ -250,18 +213,10 @@ void hnsw_add_vertices(
250
213
  **************************************************************/
251
214
 
252
215
  IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
253
- : Index(d, metric),
254
- hnsw(M),
255
- own_fields(false),
256
- storage(nullptr),
257
- reconstruct_from_neighbors(nullptr) {}
216
+ : Index(d, metric), hnsw(M) {}
258
217
 
259
218
  IndexHNSW::IndexHNSW(Index* storage, int M)
260
- : Index(storage->d, storage->metric_type),
261
- hnsw(M),
262
- own_fields(false),
263
- storage(storage),
264
- reconstruct_from_neighbors(nullptr) {}
219
+ : Index(storage->d, storage->metric_type), hnsw(M), storage(storage) {}
265
220
 
266
221
  IndexHNSW::~IndexHNSW() {
267
222
  if (own_fields) {
@@ -278,18 +233,21 @@ void IndexHNSW::train(idx_t n, const float* x) {
278
233
  is_trained = true;
279
234
  }
280
235
 
281
- void IndexHNSW::search(
236
+ namespace {
237
+
238
+ template <class BlockResultHandler>
239
+ void hnsw_search(
240
+ const IndexHNSW* index,
282
241
  idx_t n,
283
242
  const float* x,
284
- idx_t k,
285
- float* distances,
286
- idx_t* labels,
287
- const SearchParameters* params_in) const {
288
- FAISS_THROW_IF_NOT(k > 0);
243
+ BlockResultHandler& bres,
244
+ const SearchParameters* params_in) {
289
245
  FAISS_THROW_IF_NOT_MSG(
290
- storage,
291
- "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
246
+ index->storage,
247
+ "No storage index, please use IndexHNSWFlat (or variants) "
248
+ "instead of IndexHNSW directly");
292
249
  const SearchParametersHNSW* params = nullptr;
250
+ const HNSW& hnsw = index->hnsw;
293
251
 
294
252
  int efSearch = hnsw.efSearch;
295
253
  if (params_in) {
@@ -297,63 +255,82 @@ void IndexHNSW::search(
297
255
  FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
298
256
  efSearch = params->efSearch;
299
257
  }
300
- size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
258
+ size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
301
259
 
302
- idx_t check_period =
303
- InterruptCallback::get_period_hint(hnsw.max_level * d * efSearch);
260
+ idx_t check_period = InterruptCallback::get_period_hint(
261
+ hnsw.max_level * index->d * efSearch);
304
262
 
305
263
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
306
264
  idx_t i1 = std::min(i0 + check_period, n);
307
265
 
308
- #pragma omp parallel
266
+ #pragma omp parallel if (i1 - i0 > 1)
309
267
  {
310
- VisitedTable vt(ntotal);
268
+ VisitedTable vt(index->ntotal);
269
+ typename BlockResultHandler::SingleResultHandler res(bres);
311
270
 
312
- DistanceComputer* dis = storage_distance_computer(storage);
313
- ScopeDeleter1<DistanceComputer> del(dis);
271
+ std::unique_ptr<DistanceComputer> dis(
272
+ storage_distance_computer(index->storage));
314
273
 
315
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
274
+ #pragma omp for reduction(+ : n1, n2, ndis, nhops) schedule(guided)
316
275
  for (idx_t i = i0; i < i1; i++) {
317
- idx_t* idxi = labels + i * k;
318
- float* simi = distances + i * k;
319
- dis->set_query(x + i * d);
276
+ res.begin(i);
277
+ dis->set_query(x + i * index->d);
320
278
 
321
- maxheap_heapify(k, simi, idxi);
322
- HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt, params);
279
+ HNSWStats stats = hnsw.search(*dis, res, vt, params);
323
280
  n1 += stats.n1;
324
281
  n2 += stats.n2;
325
- n3 += stats.n3;
326
282
  ndis += stats.ndis;
327
- nreorder += stats.nreorder;
328
- maxheap_reorder(k, simi, idxi);
329
-
330
- if (reconstruct_from_neighbors &&
331
- reconstruct_from_neighbors->k_reorder != 0) {
332
- int k_reorder = reconstruct_from_neighbors->k_reorder;
333
- if (k_reorder == -1 || k_reorder > k)
334
- k_reorder = k;
335
-
336
- nreorder += reconstruct_from_neighbors->compute_distances(
337
- k_reorder, idxi, x + i * d, simi);
338
-
339
- // sort top k_reorder
340
- maxheap_heapify(
341
- k_reorder, simi, idxi, simi, idxi, k_reorder);
342
- maxheap_reorder(k_reorder, simi, idxi);
343
- }
283
+ nhops += stats.nhops;
284
+ res.end();
344
285
  }
345
286
  }
346
287
  InterruptCallback::check();
347
288
  }
348
289
 
349
- if (is_similarity_metric(metric_type)) {
290
+ hnsw_stats.combine({n1, n2, ndis, nhops});
291
+ }
292
+
293
+ } // anonymous namespace
294
+
295
+ void IndexHNSW::search(
296
+ idx_t n,
297
+ const float* x,
298
+ idx_t k,
299
+ float* distances,
300
+ idx_t* labels,
301
+ const SearchParameters* params_in) const {
302
+ FAISS_THROW_IF_NOT(k > 0);
303
+
304
+ using RH = HeapBlockResultHandler<HNSW::C>;
305
+ RH bres(n, distances, labels, k);
306
+
307
+ hnsw_search(this, n, x, bres, params_in);
308
+
309
+ if (is_similarity_metric(this->metric_type)) {
350
310
  // we need to revert the negated distances
351
311
  for (size_t i = 0; i < k * n; i++) {
352
312
  distances[i] = -distances[i];
353
313
  }
354
314
  }
315
+ }
316
+
317
+ void IndexHNSW::range_search(
318
+ idx_t n,
319
+ const float* x,
320
+ float radius,
321
+ RangeSearchResult* result,
322
+ const SearchParameters* params) const {
323
+ using RH = RangeSearchBlockResultHandler<HNSW::C>;
324
+ RH bres(result, is_similarity_metric(metric_type) ? -radius : radius);
325
+
326
+ hnsw_search(this, n, x, bres, params);
355
327
 
356
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
328
+ if (is_similarity_metric(this->metric_type)) {
329
+ // we need to revert the negated distances
330
+ for (size_t i = 0; i < result->lims[result->nq]; i++) {
331
+ result->distances[i] = -result->distances[i];
332
+ }
333
+ }
357
334
  }
358
335
 
359
336
  void IndexHNSW::add(idx_t n, const float* x) {
@@ -381,8 +358,8 @@ void IndexHNSW::reconstruct(idx_t key, float* recons) const {
381
358
  void IndexHNSW::shrink_level_0_neighbors(int new_size) {
382
359
  #pragma omp parallel
383
360
  {
384
- DistanceComputer* dis = storage_distance_computer(storage);
385
- ScopeDeleter1<DistanceComputer> del(dis);
361
+ std::unique_ptr<DistanceComputer> dis(
362
+ storage_distance_computer(storage));
386
363
 
387
364
  #pragma omp for
388
365
  for (idx_t i = 0; i < ntotal; i++) {
@@ -423,45 +400,59 @@ void IndexHNSW::search_level_0(
423
400
  float* distances,
424
401
  idx_t* labels,
425
402
  int nprobe,
426
- int search_type) const {
403
+ int search_type,
404
+ const SearchParameters* params_in) const {
427
405
  FAISS_THROW_IF_NOT(k > 0);
428
406
  FAISS_THROW_IF_NOT(nprobe > 0);
429
407
 
408
+ const SearchParametersHNSW* params = nullptr;
409
+
410
+ if (params_in) {
411
+ params = dynamic_cast<const SearchParametersHNSW*>(params_in);
412
+ FAISS_THROW_IF_NOT_MSG(params, "params type invalid");
413
+ }
414
+
430
415
  storage_idx_t ntotal = hnsw.levels.size();
431
416
 
417
+ using RH = HeapBlockResultHandler<HNSW::C>;
418
+ RH bres(n, distances, labels, k);
419
+
432
420
  #pragma omp parallel
433
421
  {
434
422
  std::unique_ptr<DistanceComputer> qdis(
435
423
  storage_distance_computer(storage));
436
424
  HNSWStats search_stats;
437
425
  VisitedTable vt(ntotal);
426
+ RH::SingleResultHandler res(bres);
438
427
 
439
428
  #pragma omp for
440
429
  for (idx_t i = 0; i < n; i++) {
441
- idx_t* idxi = labels + i * k;
442
- float* simi = distances + i * k;
443
-
430
+ res.begin(i);
444
431
  qdis->set_query(x + i * d);
445
- maxheap_heapify(k, simi, idxi);
446
432
 
447
433
  hnsw.search_level_0(
448
434
  *qdis.get(),
449
- k,
450
- idxi,
451
- simi,
435
+ res,
452
436
  nprobe,
453
437
  nearest + i * nprobe,
454
438
  nearest_d + i * nprobe,
455
439
  search_type,
456
440
  search_stats,
457
- vt);
458
-
441
+ vt,
442
+ params);
443
+ res.end();
459
444
  vt.advance();
460
- maxheap_reorder(k, simi, idxi);
461
445
  }
462
446
  #pragma omp critical
463
447
  { hnsw_stats.combine(search_stats); }
464
448
  }
449
+ if (is_similarity_metric(this->metric_type)) {
450
+ // we need to revert the negated distances
451
+ #pragma omp parallel for
452
+ for (int64_t i = 0; i < k * n; i++) {
453
+ distances[i] = -distances[i];
454
+ }
455
+ }
465
456
  }
466
457
 
467
458
  void IndexHNSW::init_level_0_from_knngraph(
@@ -515,8 +506,8 @@ void IndexHNSW::init_level_0_from_entry_points(
515
506
  {
516
507
  VisitedTable vt(ntotal);
517
508
 
518
- DistanceComputer* dis = storage_distance_computer(storage);
519
- ScopeDeleter1<DistanceComputer> del(dis);
509
+ std::unique_ptr<DistanceComputer> dis(
510
+ storage_distance_computer(storage));
520
511
  std::vector<float> vec(storage->d);
521
512
 
522
513
  #pragma omp for schedule(dynamic)
@@ -551,8 +542,8 @@ void IndexHNSW::reorder_links() {
551
542
  std::vector<float> distances(M);
552
543
  std::vector<size_t> order(M);
553
544
  std::vector<storage_idx_t> tmp(M);
554
- DistanceComputer* dis = storage_distance_computer(storage);
555
- ScopeDeleter1<DistanceComputer> del(dis);
545
+ std::unique_ptr<DistanceComputer> dis(
546
+ storage_distance_computer(storage));
556
547
 
557
548
  #pragma omp for
558
549
  for (storage_idx_t i = 0; i < ntotal; i++) {
@@ -614,245 +605,16 @@ void IndexHNSW::link_singletons() {
614
605
  }
615
606
  }
616
607
 
617
- /**************************************************************
618
- * ReconstructFromNeighbors implementation
619
- **************************************************************/
620
-
621
- ReconstructFromNeighbors::ReconstructFromNeighbors(
622
- const IndexHNSW& index,
623
- size_t k,
624
- size_t nsq)
625
- : index(index), k(k), nsq(nsq) {
626
- M = index.hnsw.nb_neighbors(0);
627
- FAISS_ASSERT(k <= 256);
628
- code_size = k == 1 ? 0 : nsq;
629
- ntotal = 0;
630
- d = index.d;
631
- FAISS_ASSERT(d % nsq == 0);
632
- dsub = d / nsq;
633
- k_reorder = -1;
634
- }
635
-
636
- void ReconstructFromNeighbors::reconstruct(
637
- storage_idx_t i,
638
- float* x,
639
- float* tmp) const {
640
- const HNSW& hnsw = index.hnsw;
641
- size_t begin, end;
642
- hnsw.neighbor_range(i, 0, &begin, &end);
643
-
644
- if (k == 1 || nsq == 1) {
645
- const float* beta;
646
- if (k == 1) {
647
- beta = codebook.data();
648
- } else {
649
- int idx = codes[i];
650
- beta = codebook.data() + idx * (M + 1);
651
- }
652
-
653
- float w0 = beta[0]; // weight of image itself
654
- index.storage->reconstruct(i, tmp);
655
-
656
- for (int l = 0; l < d; l++)
657
- x[l] = w0 * tmp[l];
658
-
659
- for (size_t j = begin; j < end; j++) {
660
- storage_idx_t ji = hnsw.neighbors[j];
661
- if (ji < 0)
662
- ji = i;
663
- float w = beta[j - begin + 1];
664
- index.storage->reconstruct(ji, tmp);
665
- for (int l = 0; l < d; l++)
666
- x[l] += w * tmp[l];
667
- }
668
- } else if (nsq == 2) {
669
- int idx0 = codes[2 * i];
670
- int idx1 = codes[2 * i + 1];
671
-
672
- const float* beta0 = codebook.data() + idx0 * (M + 1);
673
- const float* beta1 = codebook.data() + (idx1 + k) * (M + 1);
674
-
675
- index.storage->reconstruct(i, tmp);
676
-
677
- float w0;
678
-
679
- w0 = beta0[0];
680
- for (int l = 0; l < dsub; l++)
681
- x[l] = w0 * tmp[l];
682
-
683
- w0 = beta1[0];
684
- for (int l = dsub; l < d; l++)
685
- x[l] = w0 * tmp[l];
686
-
687
- for (size_t j = begin; j < end; j++) {
688
- storage_idx_t ji = hnsw.neighbors[j];
689
- if (ji < 0)
690
- ji = i;
691
- index.storage->reconstruct(ji, tmp);
692
- float w;
693
- w = beta0[j - begin + 1];
694
- for (int l = 0; l < dsub; l++)
695
- x[l] += w * tmp[l];
696
-
697
- w = beta1[j - begin + 1];
698
- for (int l = dsub; l < d; l++)
699
- x[l] += w * tmp[l];
700
- }
701
- } else {
702
- std::vector<const float*> betas(nsq);
703
- {
704
- const float* b = codebook.data();
705
- const uint8_t* c = &codes[i * code_size];
706
- for (int sq = 0; sq < nsq; sq++) {
707
- betas[sq] = b + (*c++) * (M + 1);
708
- b += (M + 1) * k;
709
- }
710
- }
711
-
712
- index.storage->reconstruct(i, tmp);
713
- {
714
- int d0 = 0;
715
- for (int sq = 0; sq < nsq; sq++) {
716
- float w = *(betas[sq]++);
717
- int d1 = d0 + dsub;
718
- for (int l = d0; l < d1; l++) {
719
- x[l] = w * tmp[l];
720
- }
721
- d0 = d1;
722
- }
723
- }
724
-
725
- for (size_t j = begin; j < end; j++) {
726
- storage_idx_t ji = hnsw.neighbors[j];
727
- if (ji < 0)
728
- ji = i;
729
-
730
- index.storage->reconstruct(ji, tmp);
731
- int d0 = 0;
732
- for (int sq = 0; sq < nsq; sq++) {
733
- float w = *(betas[sq]++);
734
- int d1 = d0 + dsub;
735
- for (int l = d0; l < d1; l++) {
736
- x[l] += w * tmp[l];
737
- }
738
- d0 = d1;
739
- }
740
- }
741
- }
742
- }
743
-
744
- void ReconstructFromNeighbors::reconstruct_n(
745
- storage_idx_t n0,
746
- storage_idx_t ni,
747
- float* x) const {
748
- #pragma omp parallel
749
- {
750
- std::vector<float> tmp(index.d);
751
- #pragma omp for
752
- for (storage_idx_t i = 0; i < ni; i++) {
753
- reconstruct(n0 + i, x + i * index.d, tmp.data());
754
- }
755
- }
756
- }
757
-
758
- size_t ReconstructFromNeighbors::compute_distances(
759
- size_t n,
760
- const idx_t* shortlist,
761
- const float* query,
762
- float* distances) const {
763
- std::vector<float> tmp(2 * index.d);
764
- size_t ncomp = 0;
765
- for (int i = 0; i < n; i++) {
766
- if (shortlist[i] < 0)
767
- break;
768
- reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
769
- distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
770
- ncomp++;
771
- }
772
- return ncomp;
773
- }
774
-
775
- void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float* tmp1)
776
- const {
777
- const HNSW& hnsw = index.hnsw;
778
- size_t begin, end;
779
- hnsw.neighbor_range(i, 0, &begin, &end);
780
- size_t d = index.d;
781
-
782
- index.storage->reconstruct(i, tmp1);
783
-
784
- for (size_t j = begin; j < end; j++) {
785
- storage_idx_t ji = hnsw.neighbors[j];
786
- if (ji < 0)
787
- ji = i;
788
- index.storage->reconstruct(ji, tmp1 + (j - begin + 1) * d);
789
- }
790
- }
791
-
792
- /// called by add_codes
793
- void ReconstructFromNeighbors::estimate_code(
794
- const float* x,
795
- storage_idx_t i,
796
- uint8_t* code) const {
797
- // fill in tmp table with the neighbor values
798
- float* tmp1 = new float[d * (M + 1) + (d * k)];
799
- float* tmp2 = tmp1 + d * (M + 1);
800
- ScopeDeleter<float> del(tmp1);
801
-
802
- // collect coordinates of base
803
- get_neighbor_table(i, tmp1);
804
-
805
- for (size_t sq = 0; sq < nsq; sq++) {
806
- int d0 = sq * dsub;
807
-
808
- {
809
- FINTEGER ki = k, di = d, m1 = M + 1;
810
- FINTEGER dsubi = dsub;
811
- float zero = 0, one = 1;
812
-
813
- sgemm_("N",
814
- "N",
815
- &dsubi,
816
- &ki,
817
- &m1,
818
- &one,
819
- tmp1 + d0,
820
- &di,
821
- codebook.data() + sq * (m1 * k),
822
- &m1,
823
- &zero,
824
- tmp2,
825
- &dsubi);
826
- }
827
-
828
- float min = HUGE_VAL;
829
- int argmin = -1;
830
- for (size_t j = 0; j < k; j++) {
831
- float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub);
832
- if (dis < min) {
833
- min = dis;
834
- argmin = j;
835
- }
836
- }
837
- code[sq] = argmin;
838
- }
608
+ void IndexHNSW::permute_entries(const idx_t* perm) {
609
+ auto flat_storage = dynamic_cast<IndexFlatCodes*>(storage);
610
+ FAISS_THROW_IF_NOT_MSG(
611
+ flat_storage, "don't know how to permute this index");
612
+ flat_storage->permute_entries(perm);
613
+ hnsw.permute_entries(perm);
839
614
  }
840
615
 
841
- void ReconstructFromNeighbors::add_codes(size_t n, const float* x) {
842
- if (k == 1) { // nothing to encode
843
- ntotal += n;
844
- return;
845
- }
846
- codes.resize(codes.size() + code_size * n);
847
- #pragma omp parallel for
848
- for (int i = 0; i < n; i++) {
849
- estimate_code(
850
- x + i * index.d,
851
- ntotal + i,
852
- codes.data() + (ntotal + i) * code_size);
853
- }
854
- ntotal += n;
855
- FAISS_ASSERT(codes.size() == ntotal * code_size);
616
+ DistanceComputer* IndexHNSW::get_distance_computer() const {
617
+ return storage->get_distance_computer();
856
618
  }
857
619
 
858
620
  /**************************************************************
@@ -864,7 +626,10 @@ IndexHNSWFlat::IndexHNSWFlat() {
864
626
  }
865
627
 
866
628
  IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
867
- : IndexHNSW(new IndexFlat(d, metric), M) {
629
+ : IndexHNSW(
630
+ (metric == METRIC_L2) ? new IndexFlatL2(d)
631
+ : new IndexFlat(d, metric),
632
+ M) {
868
633
  own_fields = true;
869
634
  is_trained = true;
870
635
  }
@@ -873,10 +638,15 @@ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
873
638
  * IndexHNSWPQ implementation
874
639
  **************************************************************/
875
640
 
876
- IndexHNSWPQ::IndexHNSWPQ() {}
641
+ IndexHNSWPQ::IndexHNSWPQ() = default;
877
642
 
878
- IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M)
879
- : IndexHNSW(new IndexPQ(d, pq_m, 8), M) {
643
+ IndexHNSWPQ::IndexHNSWPQ(
644
+ int d,
645
+ int pq_m,
646
+ int M,
647
+ int pq_nbits,
648
+ MetricType metric)
649
+ : IndexHNSW(new IndexPQ(d, pq_m, pq_nbits, metric), M) {
880
650
  own_fields = true;
881
651
  is_trained = false;
882
652
  }
@@ -896,11 +666,11 @@ IndexHNSWSQ::IndexHNSWSQ(
896
666
  int M,
897
667
  MetricType metric)
898
668
  : IndexHNSW(new IndexScalarQuantizer(d, qtype, metric), M) {
899
- is_trained = false;
669
+ is_trained = this->storage->is_trained;
900
670
  own_fields = true;
901
671
  }
902
672
 
903
- IndexHNSWSQ::IndexHNSWSQ() {}
673
+ IndexHNSWSQ::IndexHNSWSQ() = default;
904
674
 
905
675
  /**************************************************************
906
676
  * IndexHNSW2Level implementation
@@ -916,7 +686,7 @@ IndexHNSW2Level::IndexHNSW2Level(
916
686
  is_trained = false;
917
687
  }
918
688
 
919
- IndexHNSW2Level::IndexHNSW2Level() {}
689
+ IndexHNSW2Level::IndexHNSW2Level() = default;
920
690
 
921
691
  namespace {
922
692
 
@@ -935,7 +705,6 @@ int search_from_candidates_2(
935
705
  int level,
936
706
  int nres_in = 0) {
937
707
  int nres = nres_in;
938
- int ndis = 0;
939
708
  for (int i = 0; i < candidates.size(); i++) {
940
709
  idx_t v1 = candidates.ids[i];
941
710
  FAISS_ASSERT(v1 >= 0);
@@ -958,7 +727,6 @@ int search_from_candidates_2(
958
727
  if (vt.visited[v1] == vt.visno + 1) {
959
728
  // nothing to do
960
729
  } else {
961
- ndis++;
962
730
  float d = qdis(v1);
963
731
  candidates.push(v1, d);
964
732
 
@@ -1004,7 +772,7 @@ void IndexHNSW2Level::search(
1004
772
  IndexHNSW::search(n, x, k, distances, labels);
1005
773
 
1006
774
  } else { // "mixed" search
1007
- size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
775
+ size_t n1 = 0, n2 = 0, ndis = 0, nhops = 0;
1008
776
 
1009
777
  const IndexIVFPQ* index_ivfpq =
1010
778
  dynamic_cast<const IndexIVFPQ*>(storage);
@@ -1030,13 +798,13 @@ void IndexHNSW2Level::search(
1030
798
  #pragma omp parallel
1031
799
  {
1032
800
  VisitedTable vt(ntotal);
1033
- DistanceComputer* dis = storage_distance_computer(storage);
1034
- ScopeDeleter1<DistanceComputer> del(dis);
801
+ std::unique_ptr<DistanceComputer> dis(
802
+ storage_distance_computer(storage));
1035
803
 
1036
- int candidates_size = hnsw.upper_beam;
804
+ constexpr int candidates_size = 1;
1037
805
  MinimaxHeap candidates(candidates_size);
1038
806
 
1039
- #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
807
+ #pragma omp for reduction(+ : n1, n2, ndis, nhops)
1040
808
  for (idx_t i = 0; i < n; i++) {
1041
809
  idx_t* idxi = labels + i * k;
1042
810
  float* simi = distances + i * k;
@@ -1058,7 +826,7 @@ void IndexHNSW2Level::search(
1058
826
 
1059
827
  candidates.clear();
1060
828
 
1061
- for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
829
+ for (int j = 0; j < k; j++) {
1062
830
  if (idxi[j] < 0)
1063
831
  break;
1064
832
  candidates.push(idxi[j], simi[j]);
@@ -1081,9 +849,8 @@ void IndexHNSW2Level::search(
1081
849
  k);
1082
850
  n1 += search_stats.n1;
1083
851
  n2 += search_stats.n2;
1084
- n3 += search_stats.n3;
1085
852
  ndis += search_stats.ndis;
1086
- nreorder += search_stats.nreorder;
853
+ nhops += search_stats.nhops;
1087
854
 
1088
855
  vt.advance();
1089
856
  vt.advance();
@@ -1092,7 +859,7 @@ void IndexHNSW2Level::search(
1092
859
  }
1093
860
  }
1094
861
 
1095
- hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
862
+ hnsw_stats.combine({n1, n2, ndis, nhops});
1096
863
  }
1097
864
  }
1098
865
 
@@ -1118,4 +885,86 @@ void IndexHNSW2Level::flip_to_ivf() {
1118
885
  delete storage2l;
1119
886
  }
1120
887
 
888
+ /**************************************************************
889
+ * IndexHNSWCagra implementation
890
+ **************************************************************/
891
+
892
+ IndexHNSWCagra::IndexHNSWCagra() {
893
+ is_trained = true;
894
+ }
895
+
896
+ IndexHNSWCagra::IndexHNSWCagra(int d, int M, MetricType metric)
897
+ : IndexHNSW(
898
+ (metric == METRIC_L2)
899
+ ? static_cast<IndexFlat*>(new IndexFlatL2(d))
900
+ : static_cast<IndexFlat*>(new IndexFlatIP(d)),
901
+ M) {
902
+ FAISS_THROW_IF_NOT_MSG(
903
+ ((metric == METRIC_L2) || (metric == METRIC_INNER_PRODUCT)),
904
+ "unsupported metric type for IndexHNSWCagra");
905
+ own_fields = true;
906
+ is_trained = true;
907
+ init_level0 = true;
908
+ keep_max_size_level0 = true;
909
+ }
910
+
911
+ void IndexHNSWCagra::add(idx_t n, const float* x) {
912
+ FAISS_THROW_IF_NOT_MSG(
913
+ !base_level_only,
914
+ "Cannot add vectors when base_level_only is set to True");
915
+
916
+ IndexHNSW::add(n, x);
917
+ }
918
+
919
+ void IndexHNSWCagra::search(
920
+ idx_t n,
921
+ const float* x,
922
+ idx_t k,
923
+ float* distances,
924
+ idx_t* labels,
925
+ const SearchParameters* params) const {
926
+ if (!base_level_only) {
927
+ IndexHNSW::search(n, x, k, distances, labels, params);
928
+ } else {
929
+ std::vector<storage_idx_t> nearest(n);
930
+ std::vector<float> nearest_d(n);
931
+
932
+ #pragma omp for
933
+ for (idx_t i = 0; i < n; i++) {
934
+ std::unique_ptr<DistanceComputer> dis(
935
+ storage_distance_computer(this->storage));
936
+ dis->set_query(x + i * d);
937
+ nearest[i] = -1;
938
+ nearest_d[i] = std::numeric_limits<float>::max();
939
+
940
+ std::random_device rd;
941
+ std::mt19937 gen(rd());
942
+ std::uniform_int_distribution<idx_t> distrib(0, this->ntotal - 1);
943
+
944
+ for (idx_t j = 0; j < num_base_level_search_entrypoints; j++) {
945
+ auto idx = distrib(gen);
946
+ auto distance = (*dis)(idx);
947
+ if (distance < nearest_d[i]) {
948
+ nearest[i] = idx;
949
+ nearest_d[i] = distance;
950
+ }
951
+ }
952
+ FAISS_THROW_IF_NOT_MSG(
953
+ nearest[i] >= 0, "Could not find a valid entrypoint.");
954
+ }
955
+
956
+ search_level_0(
957
+ n,
958
+ x,
959
+ k,
960
+ nearest.data(),
961
+ nearest_d.data(),
962
+ distances,
963
+ labels,
964
+ 1, // n_probes
965
+ 1, // search_type
966
+ params);
967
+ }
968
+ }
969
+
1121
970
  } // namespace faiss