faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -14,33 +14,31 @@
14
14
 
15
15
  #include <faiss/Index.h>
16
16
 
17
-
18
17
  namespace faiss {
19
18
 
20
19
  /** Index that stores the full vectors and performs exhaustive search */
21
- struct IndexFlat: Index {
22
-
20
+ struct IndexFlat : Index {
23
21
  /// database vectors, size ntotal * d
24
22
  std::vector<float> xb;
25
23
 
26
- explicit IndexFlat (idx_t d, MetricType metric = METRIC_L2);
24
+ explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2);
27
25
 
28
26
  void add(idx_t n, const float* x) override;
29
27
 
30
28
  void reset() override;
31
29
 
32
30
  void search(
33
- idx_t n,
34
- const float* x,
35
- idx_t k,
36
- float* distances,
37
- idx_t* labels) const override;
31
+ idx_t n,
32
+ const float* x,
33
+ idx_t k,
34
+ float* distances,
35
+ idx_t* labels) const override;
38
36
 
39
37
  void range_search(
40
- idx_t n,
41
- const float* x,
42
- float radius,
43
- RangeSearchResult* result) const override;
38
+ idx_t n,
39
+ const float* x,
40
+ float radius,
41
+ RangeSearchResult* result) const override;
44
42
 
45
43
  void reconstruct(idx_t key, float* recons) const override;
46
44
 
@@ -52,59 +50,51 @@ struct IndexFlat: Index {
52
50
  * @param distances
53
51
  * corresponding output distances, size n * k
54
52
  */
55
- void compute_distance_subset (
53
+ void compute_distance_subset(
56
54
  idx_t n,
57
- const float *x,
55
+ const float* x,
58
56
  idx_t k,
59
- float *distances,
60
- const idx_t *labels) const;
57
+ float* distances,
58
+ const idx_t* labels) const;
61
59
 
62
60
  /** remove some ids. NB that because of the structure of the
63
61
  * indexing structure, the semantics of this operation are
64
62
  * different from the usual ones: the new ids are shifted */
65
63
  size_t remove_ids(const IDSelector& sel) override;
66
64
 
67
- IndexFlat () {}
65
+ IndexFlat() {}
68
66
 
69
- DistanceComputer * get_distance_computer() const override;
67
+ DistanceComputer* get_distance_computer() const override;
70
68
 
71
69
  /* The standalone codec interface (just memcopies in this case) */
72
- size_t sa_code_size () const override;
73
-
74
- void sa_encode (idx_t n, const float *x,
75
- uint8_t *bytes) const override;
70
+ size_t sa_code_size() const override;
76
71
 
77
- void sa_decode (idx_t n, const uint8_t *bytes,
78
- float *x) const override;
72
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
79
73
 
74
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
80
75
  };
81
76
 
82
-
83
-
84
- struct IndexFlatIP:IndexFlat {
85
- explicit IndexFlatIP (idx_t d): IndexFlat (d, METRIC_INNER_PRODUCT) {}
86
- IndexFlatIP () {}
77
+ struct IndexFlatIP : IndexFlat {
78
+ explicit IndexFlatIP(idx_t d) : IndexFlat(d, METRIC_INNER_PRODUCT) {}
79
+ IndexFlatIP() {}
87
80
  };
88
81
 
89
-
90
- struct IndexFlatL2:IndexFlat {
91
- explicit IndexFlatL2 (idx_t d): IndexFlat (d, METRIC_L2) {}
92
- IndexFlatL2 () {}
82
+ struct IndexFlatL2 : IndexFlat {
83
+ explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
84
+ IndexFlatL2() {}
93
85
  };
94
86
 
95
-
96
-
97
87
  /// optimized version for 1D "vectors".
98
- struct IndexFlat1D:IndexFlatL2 {
88
+ struct IndexFlat1D : IndexFlatL2 {
99
89
  bool continuous_update; ///< is the permutation updated continuously?
100
90
 
101
91
  std::vector<idx_t> perm; ///< sorted database indices
102
92
 
103
- explicit IndexFlat1D (bool continuous_update=true);
93
+ explicit IndexFlat1D(bool continuous_update = true);
104
94
 
105
95
  /// if not continuous_update, call this between the last add and
106
96
  /// the first search
107
- void update_permutation ();
97
+ void update_permutation();
108
98
 
109
99
  void add(idx_t n, const float* x) override;
110
100
 
@@ -112,14 +102,13 @@ struct IndexFlat1D:IndexFlatL2 {
112
102
 
113
103
  /// Warn: the distances returned are L1 not L2
114
104
  void search(
115
- idx_t n,
116
- const float* x,
117
- idx_t k,
118
- float* distances,
119
- idx_t* labels) const override;
105
+ idx_t n,
106
+ const float* x,
107
+ idx_t k,
108
+ float* distances,
109
+ idx_t* labels) const override;
120
110
  };
121
111
 
122
-
123
- }
112
+ } // namespace faiss
124
113
 
125
114
  #endif
@@ -9,44 +9,51 @@
9
9
 
10
10
  #include <faiss/IndexHNSW.h>
11
11
 
12
-
13
- #include <cstdlib>
12
+ #include <omp.h>
14
13
  #include <cassert>
15
- #include <cstring>
16
- #include <cstdio>
17
14
  #include <cinttypes>
18
15
  #include <cmath>
19
- #include <omp.h>
16
+ #include <cstdio>
17
+ #include <cstdlib>
18
+ #include <cstring>
20
19
 
21
- #include <unordered_set>
22
20
  #include <queue>
21
+ #include <unordered_set>
23
22
 
24
- #include <sys/types.h>
25
- #include <sys/stat.h>
26
23
  #include <stdint.h>
24
+ #include <sys/stat.h>
25
+ #include <sys/types.h>
27
26
 
28
27
  #ifdef __SSE__
29
28
  #endif
30
29
 
31
- #include <faiss/utils/distances.h>
32
- #include <faiss/utils/random.h>
33
- #include <faiss/utils/Heap.h>
34
- #include <faiss/impl/FaissAssert.h>
30
+ #include <faiss/Index2Layer.h>
35
31
  #include <faiss/IndexFlat.h>
36
32
  #include <faiss/IndexIVFPQ.h>
37
- #include <faiss/Index2Layer.h>
38
33
  #include <faiss/impl/AuxIndexStructures.h>
39
-
34
+ #include <faiss/impl/FaissAssert.h>
35
+ #include <faiss/utils/Heap.h>
36
+ #include <faiss/utils/distances.h>
37
+ #include <faiss/utils/random.h>
40
38
 
41
39
  extern "C" {
42
40
 
43
41
  /* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
44
42
 
45
- int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
46
- n, FINTEGER *k, const float *alpha, const float *a,
47
- FINTEGER *lda, const float *b, FINTEGER *
48
- ldb, float *beta, float *c, FINTEGER *ldc);
49
-
43
+ int sgemm_(
44
+ const char* transa,
45
+ const char* transb,
46
+ FINTEGER* m,
47
+ FINTEGER* n,
48
+ FINTEGER* k,
49
+ const float* alpha,
50
+ const float* a,
51
+ FINTEGER* lda,
52
+ const float* b,
53
+ FINTEGER* ldb,
54
+ float* beta,
55
+ float* c,
56
+ FINTEGER* ldc);
50
57
  }
51
58
 
52
59
  namespace faiss {
@@ -64,42 +71,36 @@ HNSWStats hnsw_stats;
64
71
 
65
72
  namespace {
66
73
 
67
-
68
74
  /* Wrap the distance computer into one that negates the
69
75
  distances. This makes supporting INNER_PRODUCT search easier */
70
76
 
71
- struct NegativeDistanceComputer: DistanceComputer {
72
-
77
+ struct NegativeDistanceComputer : DistanceComputer {
73
78
  /// owned by this
74
- DistanceComputer *basedis;
79
+ DistanceComputer* basedis;
75
80
 
76
- explicit NegativeDistanceComputer(DistanceComputer *basedis):
77
- basedis(basedis)
78
- {}
81
+ explicit NegativeDistanceComputer(DistanceComputer* basedis)
82
+ : basedis(basedis) {}
79
83
 
80
- void set_query(const float *x) override {
84
+ void set_query(const float* x) override {
81
85
  basedis->set_query(x);
82
86
  }
83
87
 
84
- /// compute distance of vector i to current query
85
- float operator () (idx_t i) override {
88
+ /// compute distance of vector i to current query
89
+ float operator()(idx_t i) override {
86
90
  return -(*basedis)(i);
87
91
  }
88
92
 
89
- /// compute distance between two stored vectors
90
- float symmetric_dis (idx_t i, idx_t j) override {
93
+ /// compute distance between two stored vectors
94
+ float symmetric_dis(idx_t i, idx_t j) override {
91
95
  return -basedis->symmetric_dis(i, j);
92
96
  }
93
97
 
94
- virtual ~NegativeDistanceComputer ()
95
- {
98
+ virtual ~NegativeDistanceComputer() {
96
99
  delete basedis;
97
100
  }
98
-
99
101
  };
100
102
 
101
- DistanceComputer *storage_distance_computer(const Index *storage)
102
- {
103
+ DistanceComputer* storage_distance_computer(const Index* storage) {
103
104
  if (storage->metric_type == METRIC_INNER_PRODUCT) {
104
105
  return new NegativeDistanceComputer(storage->get_distance_computer());
105
106
  } else {
@@ -107,21 +108,23 @@ DistanceComputer *storage_distance_computer(const Index *storage)
107
108
  }
108
109
  }
109
110
 
110
-
111
-
112
- void hnsw_add_vertices(IndexHNSW &index_hnsw,
113
- size_t n0,
114
- size_t n, const float *x,
115
- bool verbose,
116
- bool preset_levels = false) {
111
+ void hnsw_add_vertices(
112
+ IndexHNSW& index_hnsw,
113
+ size_t n0,
114
+ size_t n,
115
+ const float* x,
116
+ bool verbose,
117
+ bool preset_levels = false) {
117
118
  size_t d = index_hnsw.d;
118
- HNSW & hnsw = index_hnsw.hnsw;
119
+ HNSW& hnsw = index_hnsw.hnsw;
119
120
  size_t ntotal = n0 + n;
120
121
  double t0 = getmillisecs();
121
122
  if (verbose) {
122
123
  printf("hnsw_add_vertices: adding %zd elements on top of %zd "
123
124
  "(preset_levels=%d)\n",
124
- n, n0, int(preset_levels));
125
+ n,
126
+ n0,
127
+ int(preset_levels));
125
128
  }
126
129
 
127
130
  if (n == 0) {
@@ -135,7 +138,7 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
135
138
  }
136
139
 
137
140
  std::vector<omp_lock_t> locks(ntotal);
138
- for(int i = 0; i < ntotal; i++)
141
+ for (int i = 0; i < ntotal; i++)
139
142
  omp_init_lock(&locks[i]);
140
143
 
141
144
  // add vectors from highest to lowest level
@@ -150,7 +153,7 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
150
153
  int pt_level = hnsw.levels[pt_id] - 1;
151
154
  while (pt_level >= hist.size())
152
155
  hist.push_back(0);
153
- hist[pt_level] ++;
156
+ hist[pt_level]++;
154
157
  }
155
158
 
156
159
  // accumulate
@@ -167,8 +170,8 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
167
170
  }
168
171
  }
169
172
 
170
- idx_t check_period = InterruptCallback::get_period_hint
171
- (max_level * index_hnsw.d * hnsw.efConstruction);
173
+ idx_t check_period = InterruptCallback::get_period_hint(
174
+ max_level * index_hnsw.d * hnsw.efConstruction);
172
175
 
173
176
  { // perform add
174
177
  RandomGenerator rng2(789);
@@ -179,8 +182,7 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
179
182
  int i0 = i1 - hist[pt_level];
180
183
 
181
184
  if (verbose) {
182
- printf("Adding %d elements at level %d\n",
183
- i1 - i0, pt_level);
185
+ printf("Adding %d elements at level %d\n", i1 - i0, pt_level);
184
186
  }
185
187
 
186
188
  // random permutation to get rid of dataset order bias
@@ -189,20 +191,21 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
189
191
 
190
192
  bool interrupt = false;
191
193
 
192
- #pragma omp parallel if(i1 > i0 + 100)
194
+ #pragma omp parallel if (i1 > i0 + 100)
193
195
  {
194
- VisitedTable vt (ntotal);
196
+ VisitedTable vt(ntotal);
195
197
 
196
- DistanceComputer *dis =
197
- storage_distance_computer (index_hnsw.storage);
198
+ DistanceComputer* dis =
199
+ storage_distance_computer(index_hnsw.storage);
198
200
  ScopeDeleter1<DistanceComputer> del(dis);
199
- int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
201
+ int prev_display =
202
+ verbose && omp_get_thread_num() == 0 ? 0 : -1;
200
203
  size_t counter = 0;
201
204
 
202
- #pragma omp for schedule(dynamic)
205
+ #pragma omp for schedule(dynamic)
203
206
  for (int i = i0; i < i1; i++) {
204
207
  storage_idx_t pt_id = order[i];
205
- dis->set_query (x + (pt_id - n0) * d);
208
+ dis->set_query(x + (pt_id - n0) * d);
206
209
 
207
210
  // cannot break
208
211
  if (interrupt) {
@@ -218,16 +221,15 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
218
221
  }
219
222
 
220
223
  if (counter % check_period == 0) {
221
- if (InterruptCallback::is_interrupted ()) {
224
+ if (InterruptCallback::is_interrupted()) {
222
225
  interrupt = true;
223
226
  }
224
227
  }
225
228
  counter++;
226
229
  }
227
-
228
230
  }
229
231
  if (interrupt) {
230
- FAISS_THROW_MSG ("computation interrupted");
232
+ FAISS_THROW_MSG("computation interrupted");
231
233
  }
232
234
  i1 = i0;
233
235
  }
@@ -237,36 +239,30 @@ void hnsw_add_vertices(IndexHNSW &index_hnsw,
237
239
  printf("Done in %.3f ms\n", getmillisecs() - t0);
238
240
  }
239
241
 
240
- for(int i = 0; i < ntotal; i++) {
242
+ for (int i = 0; i < ntotal; i++) {
241
243
  omp_destroy_lock(&locks[i]);
242
244
  }
243
245
  }
244
246
 
245
-
246
- } // namespace
247
-
248
-
249
-
247
+ } // namespace
250
248
 
251
249
  /**************************************************************
252
250
  * IndexHNSW implementation
253
251
  **************************************************************/
254
252
 
255
- IndexHNSW::IndexHNSW(int d, int M, MetricType metric):
256
- Index(d, metric),
257
- hnsw(M),
258
- own_fields(false),
259
- storage(nullptr),
260
- reconstruct_from_neighbors(nullptr)
261
- {}
262
-
263
- IndexHNSW::IndexHNSW(Index *storage, int M):
264
- Index(storage->d, storage->metric_type),
265
- hnsw(M),
266
- own_fields(false),
267
- storage(storage),
268
- reconstruct_from_neighbors(nullptr)
269
- {}
253
+ IndexHNSW::IndexHNSW(int d, int M, MetricType metric)
254
+ : Index(d, metric),
255
+ hnsw(M),
256
+ own_fields(false),
257
+ storage(nullptr),
258
+ reconstruct_from_neighbors(nullptr) {}
259
+
260
+ IndexHNSW::IndexHNSW(Index* storage, int M)
261
+ : Index(storage->d, storage->metric_type),
262
+ hnsw(M),
263
+ own_fields(false),
264
+ storage(storage),
265
+ reconstruct_from_neighbors(nullptr) {}
270
266
 
271
267
  IndexHNSW::~IndexHNSW() {
272
268
  if (own_fields) {
@@ -274,68 +270,75 @@ IndexHNSW::~IndexHNSW() {
274
270
  }
275
271
  }
276
272
 
277
- void IndexHNSW::train(idx_t n, const float* x)
278
- {
279
- FAISS_THROW_IF_NOT_MSG(storage,
280
- "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
273
+ void IndexHNSW::train(idx_t n, const float* x) {
274
+ FAISS_THROW_IF_NOT_MSG(
275
+ storage,
276
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
281
277
  // hnsw structure does not require training
282
- storage->train (n, x);
278
+ storage->train(n, x);
283
279
  is_trained = true;
284
280
  }
285
281
 
286
- void IndexHNSW::search (idx_t n, const float *x, idx_t k,
287
- float *distances, idx_t *labels) const
282
+ void IndexHNSW::search(
283
+ idx_t n,
284
+ const float* x,
285
+ idx_t k,
286
+ float* distances,
287
+ idx_t* labels) const
288
288
 
289
289
  {
290
- FAISS_THROW_IF_NOT_MSG(storage,
291
- "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
290
+ FAISS_THROW_IF_NOT(k > 0);
291
+
292
+ FAISS_THROW_IF_NOT_MSG(
293
+ storage,
294
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
292
295
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
293
296
 
294
- idx_t check_period = InterruptCallback::get_period_hint (
295
- hnsw.max_level * d * hnsw.efSearch);
297
+ idx_t check_period = InterruptCallback::get_period_hint(
298
+ hnsw.max_level * d * hnsw.efSearch);
296
299
 
297
300
  for (idx_t i0 = 0; i0 < n; i0 += check_period) {
298
301
  idx_t i1 = std::min(i0 + check_period, n);
299
302
 
300
303
  #pragma omp parallel
301
304
  {
302
- VisitedTable vt (ntotal);
305
+ VisitedTable vt(ntotal);
303
306
 
304
- DistanceComputer *dis = storage_distance_computer(storage);
307
+ DistanceComputer* dis = storage_distance_computer(storage);
305
308
  ScopeDeleter1<DistanceComputer> del(dis);
306
309
 
307
- #pragma omp for reduction (+ : n1, n2, n3, ndis, nreorder)
308
- for(idx_t i = i0; i < i1; i++) {
309
- idx_t * idxi = labels + i * k;
310
- float * simi = distances + i * k;
310
+ #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
311
+ for (idx_t i = i0; i < i1; i++) {
312
+ idx_t* idxi = labels + i * k;
313
+ float* simi = distances + i * k;
311
314
  dis->set_query(x + i * d);
312
315
 
313
- maxheap_heapify (k, simi, idxi);
316
+ maxheap_heapify(k, simi, idxi);
314
317
  HNSWStats stats = hnsw.search(*dis, k, idxi, simi, vt);
315
318
  n1 += stats.n1;
316
319
  n2 += stats.n2;
317
320
  n3 += stats.n3;
318
321
  ndis += stats.ndis;
319
322
  nreorder += stats.nreorder;
320
- maxheap_reorder (k, simi, idxi);
323
+ maxheap_reorder(k, simi, idxi);
321
324
 
322
325
  if (reconstruct_from_neighbors &&
323
326
  reconstruct_from_neighbors->k_reorder != 0) {
324
327
  int k_reorder = reconstruct_from_neighbors->k_reorder;
325
- if (k_reorder == -1 || k_reorder > k) k_reorder = k;
328
+ if (k_reorder == -1 || k_reorder > k)
329
+ k_reorder = k;
326
330
 
327
331
  nreorder += reconstruct_from_neighbors->compute_distances(
328
- k_reorder, idxi, x + i * d, simi);
332
+ k_reorder, idxi, x + i * d, simi);
329
333
 
330
334
  // sort top k_reorder
331
- maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder);
332
- maxheap_reorder (k_reorder, simi, idxi);
335
+ maxheap_heapify(
336
+ k_reorder, simi, idxi, simi, idxi, k_reorder);
337
+ maxheap_reorder(k_reorder, simi, idxi);
333
338
  }
334
-
335
339
  }
336
-
337
340
  }
338
- InterruptCallback::check ();
341
+ InterruptCallback::check();
339
342
  }
340
343
 
341
344
  if (metric_type == METRIC_INNER_PRODUCT) {
@@ -348,42 +351,36 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
348
351
  hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
349
352
  }
350
353
 
351
-
352
- void IndexHNSW::add(idx_t n, const float *x)
353
- {
354
- FAISS_THROW_IF_NOT_MSG(storage,
355
- "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
354
+ void IndexHNSW::add(idx_t n, const float* x) {
355
+ FAISS_THROW_IF_NOT_MSG(
356
+ storage,
357
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
356
358
  FAISS_THROW_IF_NOT(is_trained);
357
359
  int n0 = ntotal;
358
360
  storage->add(n, x);
359
361
  ntotal = storage->ntotal;
360
362
 
361
- hnsw_add_vertices (*this, n0, n, x, verbose,
362
- hnsw.levels.size() == ntotal);
363
+ hnsw_add_vertices(*this, n0, n, x, verbose, hnsw.levels.size() == ntotal);
363
364
  }
364
365
 
365
- void IndexHNSW::reset()
366
- {
366
+ void IndexHNSW::reset() {
367
367
  hnsw.reset();
368
368
  storage->reset();
369
369
  ntotal = 0;
370
370
  }
371
371
 
372
- void IndexHNSW::reconstruct (idx_t key, float* recons) const
373
- {
372
+ void IndexHNSW::reconstruct(idx_t key, float* recons) const {
374
373
  storage->reconstruct(key, recons);
375
374
  }
376
375
 
377
- void IndexHNSW::shrink_level_0_neighbors(int new_size)
378
- {
376
+ void IndexHNSW::shrink_level_0_neighbors(int new_size) {
379
377
  #pragma omp parallel
380
378
  {
381
- DistanceComputer *dis = storage_distance_computer(storage);
379
+ DistanceComputer* dis = storage_distance_computer(storage);
382
380
  ScopeDeleter1<DistanceComputer> del(dis);
383
381
 
384
382
  #pragma omp for
385
383
  for (idx_t i = 0; i < ntotal; i++) {
386
-
387
384
  size_t begin, end;
388
385
  hnsw.neighbor_range(i, 0, &begin, &end);
389
386
 
@@ -391,15 +388,16 @@ void IndexHNSW::shrink_level_0_neighbors(int new_size)
391
388
 
392
389
  for (size_t j = begin; j < end; j++) {
393
390
  int v1 = hnsw.neighbors[j];
394
- if (v1 < 0) break;
391
+ if (v1 < 0)
392
+ break;
395
393
  initial_list.emplace(dis->symmetric_dis(i, v1), v1);
396
394
 
397
395
  // initial_list.emplace(qdis(v1), v1);
398
396
  }
399
397
 
400
398
  std::vector<NodeDistFarther> shrunk_list;
401
- HNSW::shrink_neighbor_list(*dis, initial_list,
402
- shrunk_list, new_size);
399
+ HNSW::shrink_neighbor_list(
400
+ *dis, initial_list, shrunk_list, new_size);
403
401
 
404
402
  for (size_t j = begin; j < end; j++) {
405
403
  if (j - begin < shrunk_list.size())
@@ -409,44 +407,50 @@ void IndexHNSW::shrink_level_0_neighbors(int new_size)
409
407
  }
410
408
  }
411
409
  }
412
-
413
410
  }
414
411
 
415
412
  void IndexHNSW::search_level_0(
416
- idx_t n, const float *x, idx_t k,
417
- const storage_idx_t *nearest, const float *nearest_d,
418
- float *distances, idx_t *labels, int nprobe,
419
- int search_type) const
420
- {
413
+ idx_t n,
414
+ const float* x,
415
+ idx_t k,
416
+ const storage_idx_t* nearest,
417
+ const float* nearest_d,
418
+ float* distances,
419
+ idx_t* labels,
420
+ int nprobe,
421
+ int search_type) const {
422
+ FAISS_THROW_IF_NOT(k > 0);
423
+ FAISS_THROW_IF_NOT(nprobe > 0);
421
424
 
422
425
  storage_idx_t ntotal = hnsw.levels.size();
423
426
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
424
427
 
425
428
  #pragma omp parallel
426
429
  {
427
- DistanceComputer *qdis = storage_distance_computer(storage);
430
+ DistanceComputer* qdis = storage_distance_computer(storage);
428
431
  ScopeDeleter1<DistanceComputer> del(qdis);
429
432
 
430
- VisitedTable vt (ntotal);
433
+ VisitedTable vt(ntotal);
431
434
 
432
- #pragma omp for reduction (+ : n1, n2, n3, ndis, nreorder)
433
- for(idx_t i = 0; i < n; i++) {
434
- idx_t * idxi = labels + i * k;
435
- float * simi = distances + i * k;
435
+ #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
436
+ for (idx_t i = 0; i < n; i++) {
437
+ idx_t* idxi = labels + i * k;
438
+ float* simi = distances + i * k;
436
439
 
437
440
  qdis->set_query(x + i * d);
438
- maxheap_heapify (k, simi, idxi);
441
+ maxheap_heapify(k, simi, idxi);
439
442
 
440
443
  if (search_type == 1) {
441
-
442
444
  int nres = 0;
443
445
 
444
- for(int j = 0; j < nprobe; j++) {
446
+ for (int j = 0; j < nprobe; j++) {
445
447
  storage_idx_t cj = nearest[i * nprobe + j];
446
448
 
447
- if (cj < 0) break;
449
+ if (cj < 0)
450
+ break;
448
451
 
449
- if (vt.get(cj)) continue;
452
+ if (vt.get(cj))
453
+ continue;
450
454
 
451
455
  int candidates_size = std::max(hnsw.efSearch, int(k));
452
456
  MinimaxHeap candidates(candidates_size);
@@ -455,45 +459,46 @@ void IndexHNSW::search_level_0(
455
459
 
456
460
  HNSWStats search_stats;
457
461
  nres = hnsw.search_from_candidates(
458
- *qdis, k, idxi, simi,
459
- candidates, vt, search_stats, 0, nres
460
- );
462
+ *qdis,
463
+ k,
464
+ idxi,
465
+ simi,
466
+ candidates,
467
+ vt,
468
+ search_stats,
469
+ 0,
470
+ nres);
461
471
  n1 += search_stats.n1;
462
472
  n2 += search_stats.n2;
463
473
  n3 += search_stats.n3;
464
474
  ndis += search_stats.ndis;
465
475
  nreorder += search_stats.nreorder;
466
-
467
476
  }
468
477
  } else if (search_type == 2) {
469
-
470
478
  int candidates_size = std::max(hnsw.efSearch, int(k));
471
479
  candidates_size = std::max(candidates_size, nprobe);
472
480
 
473
481
  MinimaxHeap candidates(candidates_size);
474
- for(int j = 0; j < nprobe; j++) {
482
+ for (int j = 0; j < nprobe; j++) {
475
483
  storage_idx_t cj = nearest[i * nprobe + j];
476
484
 
477
- if (cj < 0) break;
485
+ if (cj < 0)
486
+ break;
478
487
  candidates.push(cj, nearest_d[i * nprobe + j]);
479
488
  }
480
489
 
481
490
  HNSWStats search_stats;
482
491
  hnsw.search_from_candidates(
483
- *qdis, k, idxi, simi,
484
- candidates, vt, search_stats, 0
485
- );
492
+ *qdis, k, idxi, simi, candidates, vt, search_stats, 0);
486
493
  n1 += search_stats.n1;
487
494
  n2 += search_stats.n2;
488
495
  n3 += search_stats.n3;
489
496
  ndis += search_stats.ndis;
490
497
  nreorder += search_stats.nreorder;
491
-
492
498
  }
493
499
  vt.advance();
494
500
 
495
- maxheap_reorder (k, simi, idxi);
496
-
501
+ maxheap_reorder(k, simi, idxi);
497
502
  }
498
503
  }
499
504
 
@@ -501,13 +506,14 @@ void IndexHNSW::search_level_0(
501
506
  }
502
507
 
503
508
  void IndexHNSW::init_level_0_from_knngraph(
504
- int k, const float *D, const idx_t *I)
505
- {
506
- int dest_size = hnsw.nb_neighbors (0);
509
+ int k,
510
+ const float* D,
511
+ const idx_t* I) {
512
+ int dest_size = hnsw.nb_neighbors(0);
507
513
 
508
514
  #pragma omp parallel for
509
515
  for (idx_t i = 0; i < ntotal; i++) {
510
- DistanceComputer *qdis = storage_distance_computer(storage);
516
+ DistanceComputer* qdis = storage_distance_computer(storage);
511
517
  std::vector<float> vec(d);
512
518
  storage->reconstruct(i, vec.data());
513
519
  qdis->set_query(vec.data());
@@ -516,8 +522,10 @@ void IndexHNSW::init_level_0_from_knngraph(
516
522
 
517
523
  for (size_t j = 0; j < k; j++) {
518
524
  int v1 = I[i * k + j];
519
- if (v1 == i) continue;
520
- if (v1 < 0) break;
525
+ if (v1 == i)
526
+ continue;
527
+ if (v1 < 0)
528
+ break;
521
529
  initial_list.emplace(D[i * k + j], v1);
522
530
  }
523
531
 
@@ -536,35 +544,31 @@ void IndexHNSW::init_level_0_from_knngraph(
536
544
  }
537
545
  }
538
546
 
539
-
540
-
541
547
  void IndexHNSW::init_level_0_from_entry_points(
542
- int n, const storage_idx_t *points,
543
- const storage_idx_t *nearests)
544
- {
545
-
548
+ int n,
549
+ const storage_idx_t* points,
550
+ const storage_idx_t* nearests) {
546
551
  std::vector<omp_lock_t> locks(ntotal);
547
- for(int i = 0; i < ntotal; i++)
552
+ for (int i = 0; i < ntotal; i++)
548
553
  omp_init_lock(&locks[i]);
549
554
 
550
555
  #pragma omp parallel
551
556
  {
552
- VisitedTable vt (ntotal);
557
+ VisitedTable vt(ntotal);
553
558
 
554
- DistanceComputer *dis = storage_distance_computer(storage);
559
+ DistanceComputer* dis = storage_distance_computer(storage);
555
560
  ScopeDeleter1<DistanceComputer> del(dis);
556
561
  std::vector<float> vec(storage->d);
557
562
 
558
- #pragma omp for schedule(dynamic)
563
+ #pragma omp for schedule(dynamic)
559
564
  for (int i = 0; i < n; i++) {
560
565
  storage_idx_t pt_id = points[i];
561
566
  storage_idx_t nearest = nearests[i];
562
- storage->reconstruct (pt_id, vec.data());
563
- dis->set_query (vec.data());
567
+ storage->reconstruct(pt_id, vec.data());
568
+ dis->set_query(vec.data());
564
569
 
565
- hnsw.add_links_starting_from(*dis, pt_id,
566
- nearest, (*dis)(nearest),
567
- 0, locks.data(), vt);
570
+ hnsw.add_links_starting_from(
571
+ *dis, pt_id, nearest, (*dis)(nearest), 0, locks.data(), vt);
568
572
 
569
573
  if (verbose && i % 10000 == 0) {
570
574
  printf(" %d / %d\r", i, n);
@@ -576,25 +580,23 @@ void IndexHNSW::init_level_0_from_entry_points(
576
580
  printf("\n");
577
581
  }
578
582
 
579
- for(int i = 0; i < ntotal; i++)
583
+ for (int i = 0; i < ntotal; i++)
580
584
  omp_destroy_lock(&locks[i]);
581
585
  }
582
586
 
583
- void IndexHNSW::reorder_links()
584
- {
587
+ void IndexHNSW::reorder_links() {
585
588
  int M = hnsw.nb_neighbors(0);
586
589
 
587
590
  #pragma omp parallel
588
591
  {
589
- std::vector<float> distances (M);
590
- std::vector<size_t> order (M);
591
- std::vector<storage_idx_t> tmp (M);
592
- DistanceComputer *dis = storage_distance_computer(storage);
592
+ std::vector<float> distances(M);
593
+ std::vector<size_t> order(M);
594
+ std::vector<storage_idx_t> tmp(M);
595
+ DistanceComputer* dis = storage_distance_computer(storage);
593
596
  ScopeDeleter1<DistanceComputer> del(dis);
594
597
 
595
598
  #pragma omp for
596
- for(storage_idx_t i = 0; i < ntotal; i++) {
597
-
599
+ for (storage_idx_t i = 0; i < ntotal; i++) {
598
600
  size_t begin, end;
599
601
  hnsw.neighbor_range(i, 0, &begin, &end);
600
602
 
@@ -605,21 +607,18 @@ void IndexHNSW::reorder_links()
605
607
  break;
606
608
  }
607
609
  distances[j - begin] = dis->symmetric_dis(i, nj);
608
- tmp [j - begin] = nj;
610
+ tmp[j - begin] = nj;
609
611
  }
610
612
 
611
- fvec_argsort (end - begin, distances.data(), order.data());
613
+ fvec_argsort(end - begin, distances.data(), order.data());
612
614
  for (size_t j = begin; j < end; j++) {
613
615
  hnsw.neighbors[j] = tmp[order[j - begin]];
614
616
  }
615
617
  }
616
-
617
618
  }
618
619
  }
619
620
 
620
-
621
- void IndexHNSW::link_singletons()
622
- {
621
+ void IndexHNSW::link_singletons() {
623
622
  printf("search for singletons\n");
624
623
 
625
624
  std::vector<bool> seen(ntotal);
@@ -629,7 +628,8 @@ void IndexHNSW::link_singletons()
629
628
  hnsw.neighbor_range(i, 0, &begin, &end);
630
629
  for (size_t j = begin; j < end; j++) {
631
630
  storage_idx_t ni = hnsw.neighbors[j];
632
- if (ni >= 0) seen[ni] = true;
631
+ if (ni >= 0)
632
+ seen[ni] = true;
633
633
  }
634
634
  }
635
635
 
@@ -645,27 +645,25 @@ void IndexHNSW::link_singletons()
645
645
  }
646
646
 
647
647
  printf(" Found %d / %" PRId64 " singletons (%d appear in a level above)\n",
648
- n_sing, ntotal, n_sing_l1);
648
+ n_sing,
649
+ ntotal,
650
+ n_sing_l1);
649
651
 
650
- std::vector<float>recons(singletons.size() * d);
652
+ std::vector<float> recons(singletons.size() * d);
651
653
  for (int i = 0; i < singletons.size(); i++) {
652
-
653
654
  FAISS_ASSERT(!"not implemented");
654
-
655
655
  }
656
-
657
-
658
656
  }
659
657
 
660
-
661
658
  /**************************************************************
662
659
  * ReconstructFromNeighbors implementation
663
660
  **************************************************************/
664
661
 
665
-
666
662
  ReconstructFromNeighbors::ReconstructFromNeighbors(
667
- const IndexHNSW & index, size_t k, size_t nsq):
668
- index(index), k(k), nsq(nsq) {
663
+ const IndexHNSW& index,
664
+ size_t k,
665
+ size_t nsq)
666
+ : index(index), k(k), nsq(nsq) {
669
667
  M = index.hnsw.nb_neighbors(0);
670
668
  FAISS_ASSERT(k <= 256);
671
669
  code_size = k == 1 ? 0 : nsq;
@@ -676,16 +674,16 @@ ReconstructFromNeighbors::ReconstructFromNeighbors(
676
674
  k_reorder = -1;
677
675
  }
678
676
 
679
- void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp) const
680
- {
681
-
682
-
683
- const HNSW & hnsw = index.hnsw;
677
+ void ReconstructFromNeighbors::reconstruct(
678
+ storage_idx_t i,
679
+ float* x,
680
+ float* tmp) const {
681
+ const HNSW& hnsw = index.hnsw;
684
682
  size_t begin, end;
685
683
  hnsw.neighbor_range(i, 0, &begin, &end);
686
684
 
687
685
  if (k == 1 || nsq == 1) {
688
- const float * beta;
686
+ const float* beta;
689
687
  if (k == 1) {
690
688
  beta = codebook.data();
691
689
  } else {
@@ -700,9 +698,9 @@ void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp
700
698
  x[l] = w0 * tmp[l];
701
699
 
702
700
  for (size_t j = begin; j < end; j++) {
703
-
704
701
  storage_idx_t ji = hnsw.neighbors[j];
705
- if (ji < 0) ji = i;
702
+ if (ji < 0)
703
+ ji = i;
706
704
  float w = beta[j - begin + 1];
707
705
  index.storage->reconstruct(ji, tmp);
708
706
  for (int l = 0; l < d; l++)
@@ -712,8 +710,8 @@ void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp
712
710
  int idx0 = codes[2 * i];
713
711
  int idx1 = codes[2 * i + 1];
714
712
 
715
- const float *beta0 = codebook.data() + idx0 * (M + 1);
716
- const float *beta1 = codebook.data() + (idx1 + k) * (M + 1);
713
+ const float* beta0 = codebook.data() + idx0 * (M + 1);
714
+ const float* beta1 = codebook.data() + (idx1 + k) * (M + 1);
717
715
 
718
716
  index.storage->reconstruct(i, tmp);
719
717
 
@@ -729,7 +727,8 @@ void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp
729
727
 
730
728
  for (size_t j = begin; j < end; j++) {
731
729
  storage_idx_t ji = hnsw.neighbors[j];
732
- if (ji < 0) ji = i;
730
+ if (ji < 0)
731
+ ji = i;
733
732
  index.storage->reconstruct(ji, tmp);
734
733
  float w;
735
734
  w = beta0[j - begin + 1];
@@ -741,10 +740,10 @@ void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp
741
740
  x[l] += w * tmp[l];
742
741
  }
743
742
  } else {
744
- std::vector<const float *> betas(nsq);
743
+ std::vector<const float*> betas(nsq);
745
744
  {
746
- const float *b = codebook.data();
747
- const uint8_t *c = &codes[i * code_size];
745
+ const float* b = codebook.data();
746
+ const uint8_t* c = &codes[i * code_size];
748
747
  for (int sq = 0; sq < nsq; sq++) {
749
748
  betas[sq] = b + (*c++) * (M + 1);
750
749
  b += (M + 1) * k;
@@ -766,7 +765,8 @@ void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp
766
765
 
767
766
  for (size_t j = begin; j < end; j++) {
768
767
  storage_idx_t ji = hnsw.neighbors[j];
769
- if (ji < 0) ji = i;
768
+ if (ji < 0)
769
+ ji = i;
770
770
 
771
771
  index.storage->reconstruct(ji, tmp);
772
772
  int d0 = 0;
@@ -782,10 +782,10 @@ void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp
782
782
  }
783
783
  }
784
784
 
785
- void ReconstructFromNeighbors::reconstruct_n(storage_idx_t n0,
786
- storage_idx_t ni,
787
- float *x) const
788
- {
785
+ void ReconstructFromNeighbors::reconstruct_n(
786
+ storage_idx_t n0,
787
+ storage_idx_t ni,
788
+ float* x) const {
789
789
  #pragma omp parallel
790
790
  {
791
791
  std::vector<float> tmp(index.d);
@@ -797,13 +797,15 @@ void ReconstructFromNeighbors::reconstruct_n(storage_idx_t n0,
797
797
  }
798
798
 
799
799
  size_t ReconstructFromNeighbors::compute_distances(
800
- size_t n, const idx_t *shortlist,
801
- const float *query, float *distances) const
802
- {
800
+ size_t n,
801
+ const idx_t* shortlist,
802
+ const float* query,
803
+ float* distances) const {
803
804
  std::vector<float> tmp(2 * index.d);
804
805
  size_t ncomp = 0;
805
806
  for (int i = 0; i < n; i++) {
806
- if (shortlist[i] < 0) break;
807
+ if (shortlist[i] < 0)
808
+ break;
807
809
  reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
808
810
  distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
809
811
  ncomp++;
@@ -811,9 +813,9 @@ size_t ReconstructFromNeighbors::compute_distances(
811
813
  return ncomp;
812
814
  }
813
815
 
814
- void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float *tmp1) const
815
- {
816
- const HNSW & hnsw = index.hnsw;
816
+ void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float* tmp1)
817
+ const {
818
+ const HNSW& hnsw = index.hnsw;
817
819
  size_t begin, end;
818
820
  hnsw.neighbor_range(i, 0, &begin, &end);
819
821
  size_t d = index.d;
@@ -822,25 +824,24 @@ void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float *tmp1)
822
824
 
823
825
  for (size_t j = begin; j < end; j++) {
824
826
  storage_idx_t ji = hnsw.neighbors[j];
825
- if (ji < 0) ji = i;
827
+ if (ji < 0)
828
+ ji = i;
826
829
  index.storage->reconstruct(ji, tmp1 + (j - begin + 1) * d);
827
830
  }
828
-
829
831
  }
830
832
 
831
-
832
833
  /// called by add_codes
833
834
  void ReconstructFromNeighbors::estimate_code(
834
- const float *x, storage_idx_t i, uint8_t *code) const
835
- {
836
-
835
+ const float* x,
836
+ storage_idx_t i,
837
+ uint8_t* code) const {
837
838
  // fill in tmp table with the neighbor values
838
- float *tmp1 = new float[d * (M + 1) + (d * k)];
839
- float *tmp2 = tmp1 + d * (M + 1);
839
+ float* tmp1 = new float[d * (M + 1) + (d * k)];
840
+ float* tmp2 = tmp1 + d * (M + 1);
840
841
  ScopeDeleter<float> del(tmp1);
841
842
 
842
843
  // collect coordinates of base
843
- get_neighbor_table (i, tmp1);
844
+ get_neighbor_table(i, tmp1);
844
845
 
845
846
  for (size_t sq = 0; sq < nsq; sq++) {
846
847
  int d0 = sq * dsub;
@@ -850,10 +851,19 @@ void ReconstructFromNeighbors::estimate_code(
850
851
  FINTEGER dsubi = dsub;
851
852
  float zero = 0, one = 1;
852
853
 
853
- sgemm_ ("N", "N", &dsubi, &ki, &m1, &one,
854
- tmp1 + d0, &di,
855
- codebook.data() + sq * (m1 * k), &m1,
856
- &zero, tmp2, &dsubi);
854
+ sgemm_("N",
855
+ "N",
856
+ &dsubi,
857
+ &ki,
858
+ &m1,
859
+ &one,
860
+ tmp1 + d0,
861
+ &di,
862
+ codebook.data() + sq * (m1 * k),
863
+ &m1,
864
+ &zero,
865
+ tmp2,
866
+ &dsubi);
857
867
  }
858
868
 
859
869
  float min = HUGE_VAL;
@@ -867,11 +877,9 @@ void ReconstructFromNeighbors::estimate_code(
867
877
  }
868
878
  code[sq] = argmin;
869
879
  }
870
-
871
880
  }
872
881
 
873
- void ReconstructFromNeighbors::add_codes(size_t n, const float *x)
874
- {
882
+ void ReconstructFromNeighbors::add_codes(size_t n, const float* x) {
875
883
  if (k == 1) { // nothing to encode
876
884
  ntotal += n;
877
885
  return;
@@ -879,98 +887,94 @@ void ReconstructFromNeighbors::add_codes(size_t n, const float *x)
879
887
  codes.resize(codes.size() + code_size * n);
880
888
  #pragma omp parallel for
881
889
  for (int i = 0; i < n; i++) {
882
- estimate_code(x + i * index.d, ntotal + i,
883
- codes.data() + (ntotal + i) * code_size);
890
+ estimate_code(
891
+ x + i * index.d,
892
+ ntotal + i,
893
+ codes.data() + (ntotal + i) * code_size);
884
894
  }
885
895
  ntotal += n;
886
- FAISS_ASSERT (codes.size() == ntotal * code_size);
896
+ FAISS_ASSERT(codes.size() == ntotal * code_size);
887
897
  }
888
898
 
889
-
890
899
  /**************************************************************
891
900
  * IndexHNSWFlat implementation
892
901
  **************************************************************/
893
902
 
894
-
895
- IndexHNSWFlat::IndexHNSWFlat()
896
- {
903
+ IndexHNSWFlat::IndexHNSWFlat() {
897
904
  is_trained = true;
898
905
  }
899
906
 
900
- IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric):
901
- IndexHNSW(new IndexFlat(d, metric), M)
902
- {
907
+ IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
908
+ : IndexHNSW(new IndexFlat(d, metric), M) {
903
909
  own_fields = true;
904
910
  is_trained = true;
905
911
  }
906
912
 
907
-
908
913
  /**************************************************************
909
914
  * IndexHNSWPQ implementation
910
915
  **************************************************************/
911
916
 
912
-
913
917
  IndexHNSWPQ::IndexHNSWPQ() {}
914
918
 
915
- IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M):
916
- IndexHNSW(new IndexPQ(d, pq_m, 8), M)
917
- {
919
+ IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M)
920
+ : IndexHNSW(new IndexPQ(d, pq_m, 8), M) {
918
921
  own_fields = true;
919
922
  is_trained = false;
920
923
  }
921
924
 
922
- void IndexHNSWPQ::train(idx_t n, const float* x)
923
- {
924
- IndexHNSW::train (n, x);
925
- (dynamic_cast<IndexPQ*> (storage))->pq.compute_sdc_table();
925
+ void IndexHNSWPQ::train(idx_t n, const float* x) {
926
+ IndexHNSW::train(n, x);
927
+ (dynamic_cast<IndexPQ*>(storage))->pq.compute_sdc_table();
926
928
  }
927
929
 
928
-
929
930
  /**************************************************************
930
931
  * IndexHNSWSQ implementation
931
932
  **************************************************************/
932
933
 
933
-
934
- IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M,
935
- MetricType metric):
936
- IndexHNSW (new IndexScalarQuantizer (d, qtype, metric), M)
937
- {
934
+ IndexHNSWSQ::IndexHNSWSQ(
935
+ int d,
936
+ ScalarQuantizer::QuantizerType qtype,
937
+ int M,
938
+ MetricType metric)
939
+ : IndexHNSW(new IndexScalarQuantizer(d, qtype, metric), M) {
938
940
  is_trained = false;
939
941
  own_fields = true;
940
942
  }
941
943
 
942
944
  IndexHNSWSQ::IndexHNSWSQ() {}
943
945
 
944
-
945
946
  /**************************************************************
946
947
  * IndexHNSW2Level implementation
947
948
  **************************************************************/
948
949
 
949
-
950
- IndexHNSW2Level::IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M):
951
- IndexHNSW (new Index2Layer (quantizer, nlist, m_pq), M)
952
- {
950
+ IndexHNSW2Level::IndexHNSW2Level(
951
+ Index* quantizer,
952
+ size_t nlist,
953
+ int m_pq,
954
+ int M)
955
+ : IndexHNSW(new Index2Layer(quantizer, nlist, m_pq), M) {
953
956
  own_fields = true;
954
957
  is_trained = false;
955
958
  }
956
959
 
957
960
  IndexHNSW2Level::IndexHNSW2Level() {}
958
961
 
959
-
960
962
  namespace {
961
963
 
962
-
963
964
  // same as search_from_candidates but uses v
964
965
  // visno -> is in result list
965
966
  // visno + 1 -> in result list + in candidates
966
- int search_from_candidates_2(const HNSW & hnsw,
967
- DistanceComputer & qdis, int k,
968
- idx_t *I, float * D,
969
- MinimaxHeap &candidates,
970
- VisitedTable &vt,
971
- HNSWStats &stats,
972
- int level, int nres_in = 0)
973
- {
967
+ int search_from_candidates_2(
968
+ const HNSW& hnsw,
969
+ DistanceComputer& qdis,
970
+ int k,
971
+ idx_t* I,
972
+ float* D,
973
+ MinimaxHeap& candidates,
974
+ VisitedTable& vt,
975
+ HNSWStats& stats,
976
+ int level,
977
+ int nres_in = 0) {
974
978
  int nres = nres_in;
975
979
  int ndis = 0;
976
980
  for (int i = 0; i < candidates.size(); i++) {
@@ -990,7 +994,8 @@ int search_from_candidates_2(const HNSW & hnsw,
990
994
 
991
995
  for (size_t j = begin; j < end; j++) {
992
996
  int v1 = hnsw.neighbors[j];
993
- if (v1 < 0) break;
997
+ if (v1 < 0)
998
+ break;
994
999
  if (vt.visited[v1] == vt.visno + 1) {
995
1000
  // nothing to do
996
1001
  } else {
@@ -1001,9 +1006,9 @@ int search_from_candidates_2(const HNSW & hnsw,
1001
1006
  // never seen before --> add to heap
1002
1007
  if (vt.visited[v1] < vt.visno) {
1003
1008
  if (nres < k) {
1004
- faiss::maxheap_push (++nres, D, I, d, v1);
1009
+ faiss::maxheap_push(++nres, D, I, d, v1);
1005
1010
  } else if (d < D[0]) {
1006
- faiss::maxheap_replace_top (nres, D, I, d, v1);
1011
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
1007
1012
  }
1008
1013
  }
1009
1014
  vt.visited[v1] = vt.visno + 1;
@@ -1016,65 +1021,76 @@ int search_from_candidates_2(const HNSW & hnsw,
1016
1021
  }
1017
1022
  }
1018
1023
 
1019
- stats.n1 ++;
1024
+ stats.n1++;
1020
1025
  if (candidates.size() == 0)
1021
- stats.n2 ++;
1026
+ stats.n2++;
1022
1027
 
1023
1028
  return nres;
1024
1029
  }
1025
1030
 
1031
+ } // namespace
1026
1032
 
1027
- } // namespace
1033
+ void IndexHNSW2Level::search(
1034
+ idx_t n,
1035
+ const float* x,
1036
+ idx_t k,
1037
+ float* distances,
1038
+ idx_t* labels) const {
1039
+ FAISS_THROW_IF_NOT(k > 0);
1028
1040
 
1029
- void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
1030
- float *distances, idx_t *labels) const
1031
- {
1032
1041
  if (dynamic_cast<const Index2Layer*>(storage)) {
1033
- IndexHNSW::search (n, x, k, distances, labels);
1042
+ IndexHNSW::search(n, x, k, distances, labels);
1034
1043
 
1035
1044
  } else { // "mixed" search
1036
1045
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
1037
1046
 
1038
- const IndexIVFPQ *index_ivfpq =
1039
- dynamic_cast<const IndexIVFPQ*>(storage);
1047
+ const IndexIVFPQ* index_ivfpq =
1048
+ dynamic_cast<const IndexIVFPQ*>(storage);
1040
1049
 
1041
1050
  int nprobe = index_ivfpq->nprobe;
1042
1051
 
1043
1052
  std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
1044
1053
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
1045
1054
 
1046
- index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(),
1047
- coarse_assign.get());
1055
+ index_ivfpq->quantizer->search(
1056
+ n, x, nprobe, coarse_dis.get(), coarse_assign.get());
1048
1057
 
1049
- index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(),
1050
- coarse_dis.get(), distances, labels,
1051
- false);
1058
+ index_ivfpq->search_preassigned(
1059
+ n,
1060
+ x,
1061
+ k,
1062
+ coarse_assign.get(),
1063
+ coarse_dis.get(),
1064
+ distances,
1065
+ labels,
1066
+ false);
1052
1067
 
1053
1068
  #pragma omp parallel
1054
1069
  {
1055
- VisitedTable vt (ntotal);
1056
- DistanceComputer *dis = storage_distance_computer(storage);
1070
+ VisitedTable vt(ntotal);
1071
+ DistanceComputer* dis = storage_distance_computer(storage);
1057
1072
  ScopeDeleter1<DistanceComputer> del(dis);
1058
1073
 
1059
1074
  int candidates_size = hnsw.upper_beam;
1060
1075
  MinimaxHeap candidates(candidates_size);
1061
1076
 
1062
- #pragma omp for reduction (+ : n1, n2, n3, ndis, nreorder)
1063
- for(idx_t i = 0; i < n; i++) {
1064
- idx_t * idxi = labels + i * k;
1065
- float * simi = distances + i * k;
1077
+ #pragma omp for reduction(+ : n1, n2, n3, ndis, nreorder)
1078
+ for (idx_t i = 0; i < n; i++) {
1079
+ idx_t* idxi = labels + i * k;
1080
+ float* simi = distances + i * k;
1066
1081
  dis->set_query(x + i * d);
1067
1082
 
1068
1083
  // mark all inverted list elements as visited
1069
1084
 
1070
1085
  for (int j = 0; j < nprobe; j++) {
1071
1086
  idx_t key = coarse_assign[j + i * nprobe];
1072
- if (key < 0) break;
1073
- size_t list_length = index_ivfpq->get_list_size (key);
1074
- const idx_t * ids = index_ivfpq->invlists->get_ids (key);
1087
+ if (key < 0)
1088
+ break;
1089
+ size_t list_length = index_ivfpq->get_list_size(key);
1090
+ const idx_t* ids = index_ivfpq->invlists->get_ids(key);
1075
1091
 
1076
1092
  for (int jj = 0; jj < list_length; jj++) {
1077
- vt.set (ids[jj]);
1093
+ vt.set(ids[jj]);
1078
1094
  }
1079
1095
  }
1080
1096
 
@@ -1084,23 +1100,29 @@ void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
1084
1100
  int search_policy = 2;
1085
1101
 
1086
1102
  if (search_policy == 1) {
1087
-
1088
- for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
1089
- if (idxi[j] < 0) break;
1090
- candidates.push (idxi[j], simi[j]);
1103
+ for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1104
+ if (idxi[j] < 0)
1105
+ break;
1106
+ candidates.push(idxi[j], simi[j]);
1091
1107
  // search_from_candidates adds them back
1092
1108
  idxi[j] = -1;
1093
1109
  simi[j] = HUGE_VAL;
1094
1110
  }
1095
1111
 
1096
1112
  // reorder from sorted to heap
1097
- maxheap_heapify (k, simi, idxi, simi, idxi, k);
1113
+ maxheap_heapify(k, simi, idxi, simi, idxi, k);
1098
1114
 
1099
1115
  HNSWStats search_stats;
1100
1116
  hnsw.search_from_candidates(
1101
- *dis, k, idxi, simi,
1102
- candidates, vt, search_stats, 0, k
1103
- );
1117
+ *dis,
1118
+ k,
1119
+ idxi,
1120
+ simi,
1121
+ candidates,
1122
+ vt,
1123
+ search_stats,
1124
+ 0,
1125
+ k);
1104
1126
  n1 += search_stats.n1;
1105
1127
  n2 += search_stats.n2;
1106
1128
  n3 += search_stats.n3;
@@ -1110,63 +1132,65 @@ void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
1110
1132
  vt.advance();
1111
1133
 
1112
1134
  } else if (search_policy == 2) {
1113
-
1114
- for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
1115
- if (idxi[j] < 0) break;
1116
- candidates.push (idxi[j], simi[j]);
1135
+ for (int j = 0; j < hnsw.upper_beam && j < k; j++) {
1136
+ if (idxi[j] < 0)
1137
+ break;
1138
+ candidates.push(idxi[j], simi[j]);
1117
1139
  }
1118
1140
 
1119
1141
  // reorder from sorted to heap
1120
- maxheap_heapify (k, simi, idxi, simi, idxi, k);
1142
+ maxheap_heapify(k, simi, idxi, simi, idxi, k);
1121
1143
 
1122
1144
  HNSWStats search_stats;
1123
- search_from_candidates_2 (
1124
- hnsw, *dis, k, idxi, simi,
1125
- candidates, vt, search_stats, 0, k);
1145
+ search_from_candidates_2(
1146
+ hnsw,
1147
+ *dis,
1148
+ k,
1149
+ idxi,
1150
+ simi,
1151
+ candidates,
1152
+ vt,
1153
+ search_stats,
1154
+ 0,
1155
+ k);
1126
1156
  n1 += search_stats.n1;
1127
1157
  n2 += search_stats.n2;
1128
1158
  n3 += search_stats.n3;
1129
1159
  ndis += search_stats.ndis;
1130
1160
  nreorder += search_stats.nreorder;
1131
1161
 
1132
- vt.advance ();
1133
- vt.advance ();
1134
-
1162
+ vt.advance();
1163
+ vt.advance();
1135
1164
  }
1136
1165
 
1137
- maxheap_reorder (k, simi, idxi);
1166
+ maxheap_reorder(k, simi, idxi);
1138
1167
  }
1139
1168
  }
1140
1169
 
1141
1170
  hnsw_stats.combine({n1, n2, n3, ndis, nreorder});
1142
1171
  }
1143
-
1144
-
1145
1172
  }
1146
1173
 
1174
+ void IndexHNSW2Level::flip_to_ivf() {
1175
+ Index2Layer* storage2l = dynamic_cast<Index2Layer*>(storage);
1147
1176
 
1148
- void IndexHNSW2Level::flip_to_ivf ()
1149
- {
1150
- Index2Layer *storage2l =
1151
- dynamic_cast<Index2Layer*>(storage);
1177
+ FAISS_THROW_IF_NOT(storage2l);
1152
1178
 
1153
- FAISS_THROW_IF_NOT (storage2l);
1154
-
1155
- IndexIVFPQ * index_ivfpq =
1156
- new IndexIVFPQ (storage2l->q1.quantizer,
1157
- d, storage2l->q1.nlist,
1158
- storage2l->pq.M, 8);
1179
+ IndexIVFPQ* index_ivfpq = new IndexIVFPQ(
1180
+ storage2l->q1.quantizer,
1181
+ d,
1182
+ storage2l->q1.nlist,
1183
+ storage2l->pq.M,
1184
+ 8);
1159
1185
  index_ivfpq->pq = storage2l->pq;
1160
1186
  index_ivfpq->is_trained = storage2l->is_trained;
1161
1187
  index_ivfpq->precompute_table();
1162
1188
  index_ivfpq->own_fields = storage2l->q1.own_fields;
1163
1189
  storage2l->transfer_to_IVFPQ(*index_ivfpq);
1164
- index_ivfpq->make_direct_map (true);
1190
+ index_ivfpq->make_direct_map(true);
1165
1191
 
1166
1192
  storage = index_ivfpq;
1167
1193
  delete storage2l;
1168
-
1169
1194
  }
1170
1195
 
1171
-
1172
- } // namespace faiss
1196
+ } // namespace faiss