faiss 0.2.0 → 0.2.1

Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
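The diff excerpts that follow cover two of the vendored sources: the IndexHNSW.h header and the IndexIVF.cpp implementation. Much of the change is mechanical reformatting (clang-format style pointer placement and argument wrapping), but the IndexIVF hunks also add argument checks: search() now requires k > 0 and clamps nprobe to nlist. The sketch below is not part of the diff; it is a minimal illustration, using only the standard Faiss C++ API, of the search path those checks guard.

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <random>
#include <vector>

int main() {
    const int d = 64;          // vector dimension
    const size_t nlist = 100;  // number of inverted lists
    faiss::IndexFlatL2 quantizer(d);
    faiss::IndexIVFFlat index(&quantizer, d, nlist, faiss::METRIC_L2);

    // random training/database vectors
    std::vector<float> xb(10000 * d);
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> dis;
    for (auto& v : xb) v = dis(rng);

    index.train(10000, xb.data());
    index.add(10000, xb.data());

    // With the updated IndexIVF::search, an nprobe larger than nlist is
    // clamped to nlist instead of probing nonexistent lists, and k must be > 0.
    index.nprobe = 1000;
    const int k = 5;
    std::vector<faiss::Index::idx_t> labels(k);
    std::vector<float> distances(k);
    index.search(1, xb.data(), k, distances.data(), labels.data());
    return 0;
}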
@@ -11,13 +11,12 @@
 
  #include <vector>
 
- #include <faiss/impl/HNSW.h>
  #include <faiss/IndexFlat.h>
  #include <faiss/IndexPQ.h>
  #include <faiss/IndexScalarQuantizer.h>
+ #include <faiss/impl/HNSW.h>
  #include <faiss/utils/utils.h>
 
-
  namespace faiss {
 
  struct IndexHNSW;
@@ -26,9 +25,9 @@ struct ReconstructFromNeighbors {
  typedef Index::idx_t idx_t;
  typedef HNSW::storage_idx_t storage_idx_t;
 
- const IndexHNSW & index;
- size_t M; // number of neighbors
- size_t k; // number of codebook entries
+ const IndexHNSW& index;
+ size_t M; // number of neighbors
+ size_t k; // number of codebook entries
  size_t nsq; // number of subvectors
  size_t code_size;
  int k_reorder; // nb to reorder. -1 = all
@@ -39,35 +38,37 @@ struct ReconstructFromNeighbors {
  size_t ntotal;
  size_t d, dsub; // derived values
 
- explicit ReconstructFromNeighbors(const IndexHNSW& index,
- size_t k=256, size_t nsq=1);
+ explicit ReconstructFromNeighbors(
+ const IndexHNSW& index,
+ size_t k = 256,
+ size_t nsq = 1);
 
  /// codes must be added in the correct order and the IndexHNSW
  /// must be populated and sorted
- void add_codes(size_t n, const float *x);
+ void add_codes(size_t n, const float* x);
 
- size_t compute_distances(size_t n, const idx_t *shortlist,
- const float *query, float *distances) const;
+ size_t compute_distances(
+ size_t n,
+ const idx_t* shortlist,
+ const float* query,
+ float* distances) const;
 
  /// called by add_codes
- void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const;
+ void estimate_code(const float* x, storage_idx_t i, uint8_t* code) const;
 
  /// called by compute_distances
- void reconstruct(storage_idx_t i, float *x, float *tmp) const;
+ void reconstruct(storage_idx_t i, float* x, float* tmp) const;
 
- void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const;
+ void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float* x) const;
 
  /// get the M+1 -by-d table for neighbor coordinates for vector i
- void get_neighbor_table(storage_idx_t i, float *out) const;
-
+ void get_neighbor_table(storage_idx_t i, float* out) const;
  };
 
-
  /** The HNSW index is a normal random-access index with a HNSW
  * link structure built on top */
 
  struct IndexHNSW : Index {
-
  typedef HNSW::storage_idx_t storage_idx_t;
 
  // the link strcuture
@@ -75,27 +76,31 @@ struct IndexHNSW : Index {
 
  // the sequential storage
  bool own_fields;
- Index *storage;
+ Index* storage;
 
- ReconstructFromNeighbors *reconstruct_from_neighbors;
+ ReconstructFromNeighbors* reconstruct_from_neighbors;
 
- explicit IndexHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2);
- explicit IndexHNSW (Index *storage, int M = 32);
+ explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
+ explicit IndexHNSW(Index* storage, int M = 32);
 
  ~IndexHNSW() override;
 
- void add(idx_t n, const float *x) override;
+ void add(idx_t n, const float* x) override;
 
  /// Trains the storage if needed
  void train(idx_t n, const float* x) override;
 
  /// entry point for search
- void search (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels) const override;
+ void search(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels) const override;
 
  void reconstruct(idx_t key, float* recons) const override;
 
- void reset () override;
+ void reset() override;
 
  void shrink_level_0_neighbors(int size);
 
@@ -105,19 +110,25 @@ struct IndexHNSW : Index {
  * @param search_type 1:perform one search per nprobe, 2: enqueue
  * all entry points
  */
- void search_level_0(idx_t n, const float *x, idx_t k,
- const storage_idx_t *nearest, const float *nearest_d,
- float *distances, idx_t *labels, int nprobe = 1,
- int search_type = 1) const;
+ void search_level_0(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ const storage_idx_t* nearest,
+ const float* nearest_d,
+ float* distances,
+ idx_t* labels,
+ int nprobe = 1,
+ int search_type = 1) const;
 
  /// alternative graph building
- void init_level_0_from_knngraph(
- int k, const float *D, const idx_t *I);
+ void init_level_0_from_knngraph(int k, const float* D, const idx_t* I);
 
  /// alternative graph building
  void init_level_0_from_entry_points(
- int npt, const storage_idx_t *points,
- const storage_idx_t *nearests);
+ int npt,
+ const storage_idx_t* points,
+ const storage_idx_t* nearests);
 
  // reorder links from nearest to farthest
  void reorder_links();
@@ -125,7 +136,6 @@ struct IndexHNSW : Index {
  void link_singletons();
  };
 
-
  /** Flat index topped with with a HNSW structure to access elements
  * more efficiently.
  */
@@ -149,22 +159,28 @@ struct IndexHNSWPQ : IndexHNSW {
  */
  struct IndexHNSWSQ : IndexHNSW {
  IndexHNSWSQ();
- IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M, MetricType metric = METRIC_L2);
+ IndexHNSWSQ(
+ int d,
+ ScalarQuantizer::QuantizerType qtype,
+ int M,
+ MetricType metric = METRIC_L2);
  };
 
  /** 2-level code structure with fast random access
  */
  struct IndexHNSW2Level : IndexHNSW {
  IndexHNSW2Level();
- IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M);
+ IndexHNSW2Level(Index* quantizer, size_t nlist, int m_pq, int M);
 
  void flip_to_ivf();
 
  /// entry point for search
- void search (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels) const override;
-
+ void search(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels) const override;
  };
 
-
- } // namespace faiss
+ } // namespace faiss
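The remaining hunks are from the vendored IndexIVF.cpp. Besides formatting, they split vector addition into add_with_ids(), which runs the coarse quantizer, and a new add_core() overload that accepts precomputed coarse assignments. The sketch below is not taken from the diff; it is a hedged illustration of how that entry point can be used, assuming a trained faiss::IndexIVFFlat, and the helper name is ours.

#include <faiss/IndexIVFFlat.h>
#include <vector>

// Add n vectors whose coarse (level-1) assignments were computed separately,
// for example in a batched step elsewhere. add_core() skips the internal
// quantizer->assign() call that add_with_ids() would otherwise perform.
void add_with_precomputed_assignments(
        faiss::IndexIVFFlat& index,
        faiss::Index::idx_t n,
        const float* xb) {
    std::vector<faiss::Index::idx_t> coarse(n);
    index.quantizer->assign(n, xb, coarse.data()); // or reuse existing assignments
    index.add_core(n, xb, /*xids=*/nullptr, coarse.data());
}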
@@ -9,7 +9,6 @@
 
  #include <faiss/IndexIVF.h>
 
-
  #include <omp.h>
  #include <mutex>
 
@@ -18,12 +17,12 @@
  #include <cstdio>
  #include <memory>
 
- #include <faiss/utils/utils.h>
  #include <faiss/utils/hamming.h>
+ #include <faiss/utils/utils.h>
 
- #include <faiss/impl/FaissAssert.h>
  #include <faiss/IndexFlat.h>
  #include <faiss/impl/AuxIndexStructures.h>
+ #include <faiss/impl/FaissAssert.h>
 
  namespace faiss {
 
@@ -34,99 +33,97 @@ using ScopedCodes = InvertedLists::ScopedCodes;
  * Level1Quantizer implementation
  ******************************************/
 
-
- Level1Quantizer::Level1Quantizer (Index * quantizer, size_t nlist):
- quantizer (quantizer),
- nlist (nlist),
- quantizer_trains_alone (0),
- own_fields (false),
- clustering_index (nullptr)
- {
+ Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
+ : quantizer(quantizer),
+ nlist(nlist),
+ quantizer_trains_alone(0),
+ own_fields(false),
+ clustering_index(nullptr) {
  // here we set a low # iterations because this is typically used
  // for large clusterings (nb this is not used for the MultiIndex,
  // for which quantizer_trains_alone = true)
  cp.niter = 10;
  }
 
- Level1Quantizer::Level1Quantizer ():
- quantizer (nullptr),
- nlist (0),
- quantizer_trains_alone (0), own_fields (false),
- clustering_index (nullptr)
- {}
+ Level1Quantizer::Level1Quantizer()
+ : quantizer(nullptr),
+ nlist(0),
+ quantizer_trains_alone(0),
+ own_fields(false),
+ clustering_index(nullptr) {}
 
- Level1Quantizer::~Level1Quantizer ()
- {
- if (own_fields) delete quantizer;
+ Level1Quantizer::~Level1Quantizer() {
+ if (own_fields)
+ delete quantizer;
  }
 
- void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
- {
+ void Level1Quantizer::train_q1(
+ size_t n,
+ const float* x,
+ bool verbose,
+ MetricType metric_type) {
  size_t d = quantizer->d;
  if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
  if (verbose)
- printf ("IVF quantizer does not need training.\n");
+ printf("IVF quantizer does not need training.\n");
  } else if (quantizer_trains_alone == 1) {
  if (verbose)
- printf ("IVF quantizer trains alone...\n");
- quantizer->train (n, x);
+ printf("IVF quantizer trains alone...\n");
+ quantizer->train(n, x);
  quantizer->verbose = verbose;
- FAISS_THROW_IF_NOT_MSG (quantizer->ntotal == nlist,
- "nlist not consistent with quantizer size");
+ FAISS_THROW_IF_NOT_MSG(
+ quantizer->ntotal == nlist,
+ "nlist not consistent with quantizer size");
  } else if (quantizer_trains_alone == 0) {
  if (verbose)
- printf ("Training level-1 quantizer on %zd vectors in %zdD\n",
- n, d);
+ printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
 
- Clustering clus (d, nlist, cp);
+ Clustering clus(d, nlist, cp);
  quantizer->reset();
  if (clustering_index) {
- clus.train (n, x, *clustering_index);
- quantizer->add (nlist, clus.centroids.data());
+ clus.train(n, x, *clustering_index);
+ quantizer->add(nlist, clus.centroids.data());
  } else {
- clus.train (n, x, *quantizer);
+ clus.train(n, x, *quantizer);
  }
  quantizer->is_trained = true;
  } else if (quantizer_trains_alone == 2) {
  if (verbose) {
- printf (
- "Training L2 quantizer on %zd vectors in %zdD%s\n",
- n, d,
- clustering_index ? "(user provided index)" : "");
+ printf("Training L2 quantizer on %zd vectors in %zdD%s\n",
+ n,
+ d,
+ clustering_index ? "(user provided index)" : "");
  }
  // also accept spherical centroids because in that case
  // L2 and IP are equivalent
- FAISS_THROW_IF_NOT (
- metric_type == METRIC_L2 ||
- (metric_type == METRIC_INNER_PRODUCT && cp.spherical)
- );
+ FAISS_THROW_IF_NOT(
+ metric_type == METRIC_L2 ||
+ (metric_type == METRIC_INNER_PRODUCT && cp.spherical));
 
- Clustering clus (d, nlist, cp);
+ Clustering clus(d, nlist, cp);
  if (!clustering_index) {
- IndexFlatL2 assigner (d);
+ IndexFlatL2 assigner(d);
  clus.train(n, x, assigner);
  } else {
  clus.train(n, x, *clustering_index);
  }
  if (verbose)
- printf ("Adding centroids to quantizer\n");
- quantizer->add (nlist, clus.centroids.data());
+ printf("Adding centroids to quantizer\n");
+ quantizer->add(nlist, clus.centroids.data());
  }
  }
 
- size_t Level1Quantizer::coarse_code_size () const
- {
+ size_t Level1Quantizer::coarse_code_size() const {
  size_t nl = nlist - 1;
  size_t nbyte = 0;
  while (nl > 0) {
- nbyte ++;
+ nbyte++;
  nl >>= 8;
  }
  return nbyte;
  }
 
- void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
- {
+ void Level1Quantizer::encode_listno(Index::idx_t list_no, uint8_t* code) const {
  // little endian
  size_t nl = nlist - 1;
  while (nl > 0) {
@@ -136,8 +133,7 @@ void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
  }
  }
 
- Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
- {
+ Index::idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
  size_t nl = nlist - 1;
  int64_t list_no = 0;
  int nbit = 0;
@@ -146,161 +142,184 @@ Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
  nbit += 8;
  nl >>= 8;
  }
- FAISS_THROW_IF_NOT (list_no >= 0 && list_no < nlist);
+ FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
  return list_no;
  }
 
-
-
  /*****************************************
  * IndexIVF implementation
  ******************************************/
 
-
- IndexIVF::IndexIVF (Index * quantizer, size_t d,
- size_t nlist, size_t code_size,
- MetricType metric):
- Index (d, metric),
- Level1Quantizer (quantizer, nlist),
- invlists (new ArrayInvertedLists (nlist, code_size)),
- own_invlists (true),
- code_size (code_size),
- nprobe (1),
- max_codes (0),
- parallel_mode (0)
- {
- FAISS_THROW_IF_NOT (d == quantizer->d);
+ IndexIVF::IndexIVF(
+ Index* quantizer,
+ size_t d,
+ size_t nlist,
+ size_t code_size,
+ MetricType metric)
+ : Index(d, metric),
+ Level1Quantizer(quantizer, nlist),
+ invlists(new ArrayInvertedLists(nlist, code_size)),
+ own_invlists(true),
+ code_size(code_size),
+ nprobe(1),
+ max_codes(0),
+ parallel_mode(0) {
+ FAISS_THROW_IF_NOT(d == quantizer->d);
  is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
  // Spherical by default if the metric is inner_product
  if (metric_type == METRIC_INNER_PRODUCT) {
  cp.spherical = true;
  }
-
  }
 
- IndexIVF::IndexIVF ():
- invlists (nullptr), own_invlists (false),
- code_size (0),
- nprobe (1), max_codes (0), parallel_mode (0)
- {}
+ IndexIVF::IndexIVF()
+ : invlists(nullptr),
+ own_invlists(false),
+ code_size(0),
+ nprobe(1),
+ max_codes(0),
+ parallel_mode(0) {}
 
- void IndexIVF::add (idx_t n, const float * x)
- {
- add_with_ids (n, x, nullptr);
+ void IndexIVF::add(idx_t n, const float* x) {
+ add_with_ids(n, x, nullptr);
  }
 
+ void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
+ std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
+ quantizer->assign(n, x, coarse_idx.get());
+ add_core(n, x, xids, coarse_idx.get());
+ }
 
- void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
- {
+ void IndexIVF::add_core(
+ idx_t n,
+ const float* x,
+ const idx_t* xids,
+ const idx_t* coarse_idx) {
  // do some blocking to avoid excessive allocs
  idx_t bs = 65536;
  if (n > bs) {
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
- idx_t i1 = std::min (n, i0 + bs);
+ idx_t i1 = std::min(n, i0 + bs);
  if (verbose) {
- printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n", i0, i1);
+ printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n",
+ i0,
+ i1);
  }
- add_with_ids (i1 - i0, x + i0 * d,
- xids ? xids + i0 : nullptr);
+ add_core(
+ i1 - i0,
+ x + i0 * d,
+ xids ? xids + i0 : nullptr,
+ coarse_idx + i0);
  }
  return;
  }
+ FAISS_THROW_IF_NOT(coarse_idx);
+ FAISS_THROW_IF_NOT(is_trained);
+ direct_map.check_can_add(xids);
 
- FAISS_THROW_IF_NOT (is_trained);
- direct_map.check_can_add (xids);
-
- std::unique_ptr<idx_t []> idx(new idx_t[n]);
- quantizer->assign (n, x, idx.get());
  size_t nadd = 0, nminus1 = 0;
 
  for (size_t i = 0; i < n; i++) {
- if (idx[i] < 0) nminus1++;
+ if (coarse_idx[i] < 0)
+ nminus1++;
  }
 
- std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
- encode_vectors (n, x, idx.get(), flat_codes.get());
+ std::unique_ptr<uint8_t[]> flat_codes(new uint8_t[n * code_size]);
+ encode_vectors(n, x, coarse_idx, flat_codes.get());
 
  DirectMapAdd dm_adder(direct_map, n, xids);
 
- #pragma omp parallel reduction(+: nadd)
+ #pragma omp parallel reduction(+ : nadd)
  {
  int nt = omp_get_num_threads();
  int rank = omp_get_thread_num();
 
  // each thread takes care of a subset of lists
  for (size_t i = 0; i < n; i++) {
- idx_t list_no = idx [i];
+ idx_t list_no = coarse_idx[i];
  if (list_no >= 0 && list_no % nt == rank) {
  idx_t id = xids ? xids[i] : ntotal + i;
- size_t ofs = invlists->add_entry (
- list_no, id,
- flat_codes.get() + i * code_size
- );
+ size_t ofs = invlists->add_entry(
+ list_no, id, flat_codes.get() + i * code_size);
 
- dm_adder.add (i, list_no, ofs);
+ dm_adder.add(i, list_no, ofs);
 
  nadd++;
  } else if (rank == 0 && list_no == -1) {
- dm_adder.add (i, -1, 0);
+ dm_adder.add(i, -1, 0);
  }
  }
  }
 
-
  if (verbose) {
- printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n", nadd, n, nminus1);
+ printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n",
+ nadd,
+ n,
+ nminus1);
  }
 
  ntotal += n;
  }
 
- void IndexIVF::make_direct_map (bool b)
- {
+ void IndexIVF::make_direct_map(bool b) {
  if (b) {
- direct_map.set_type (DirectMap::Array, invlists, ntotal);
+ direct_map.set_type(DirectMap::Array, invlists, ntotal);
  } else {
- direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
+ direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
  }
  }
 
-
-
- void IndexIVF::set_direct_map_type (DirectMap::Type type)
- {
- direct_map.set_type (type, invlists, ntotal);
+ void IndexIVF::set_direct_map_type(DirectMap::Type type) {
+ direct_map.set_type(type, invlists, ntotal);
  }
 
  /** It is a sad fact of software that a conceptually simple function like this
  * becomes very complex when you factor in several ways of parallelizing +
  * interrupt/error handling + collecting stats + min/max collection. The
  * codepath that is used 95% of time is the one for parallel_mode = 0 */
- void IndexIVF::search (idx_t n, const float *x, idx_t k,
- float *distances, idx_t *labels) const
- {
+ void IndexIVF::search(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels) const {
+ FAISS_THROW_IF_NOT(k > 0);
 
+ const size_t nprobe = std::min(nlist, this->nprobe);
+ FAISS_THROW_IF_NOT(nprobe > 0);
 
  // search function for a subset of queries
- auto sub_search_func = [this, k]
- (idx_t n, const float *x, float *distances, idx_t *labels,
- IndexIVFStats *ivf_stats) {
-
+ auto sub_search_func = [this, k, nprobe](
+ idx_t n,
+ const float* x,
+ float* distances,
+ idx_t* labels,
+ IndexIVFStats* ivf_stats) {
  std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
 
  double t0 = getmillisecs();
- quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
 
  double t1 = getmillisecs();
- invlists->prefetch_lists (idx.get(), n * nprobe);
-
- search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
- distances, labels, false, nullptr, ivf_stats);
+ invlists->prefetch_lists(idx.get(), n * nprobe);
+
+ search_preassigned(
+ n,
+ x,
+ k,
+ idx.get(),
+ coarse_dis.get(),
+ distances,
+ labels,
+ false,
+ nullptr,
+ ivf_stats);
  double t2 = getmillisecs();
  ivf_stats->quantization_time += t1 - t0;
  ivf_stats->search_time += t2 - t0;
  };
 
-
  if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
  int nt = std::min(omp_get_max_threads(), int(n));
  std::vector<IndexIVFStats> stats(nt);
@@ -308,18 +327,19 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
  std::string exception_string;
 
  #pragma omp parallel for if (nt > 1)
- for(idx_t slice = 0; slice < nt; slice++) {
+ for (idx_t slice = 0; slice < nt; slice++) {
  IndexIVFStats local_stats;
  idx_t i0 = n * slice / nt;
  idx_t i1 = n * (slice + 1) / nt;
  if (i1 > i0) {
  try {
  sub_search_func(
- i1 - i0, x + i0 * d,
- distances + i0 * k, labels + i0 * k,
- &stats[slice]
- );
- } catch(const std::exception & e) {
+ i1 - i0,
+ x + i0 * d,
+ distances + i0 * k,
+ labels + i0 * k,
+ &stats[slice]);
+ } catch (const std::exception& e) {
  std::lock_guard<std::mutex> lock(exception_mutex);
  exception_string = e.what();
  }
@@ -327,32 +347,38 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
  }
 
  if (!exception_string.empty()) {
- FAISS_THROW_MSG (exception_string.c_str());
+ FAISS_THROW_MSG(exception_string.c_str());
  }
 
  // collect stats
- for(idx_t slice = 0; slice < nt; slice++) {
+ for (idx_t slice = 0; slice < nt; slice++) {
  indexIVF_stats.add(stats[slice]);
  }
  } else {
- // handle paralellization at level below (or don't run in parallel at all)
+ // handle paralellization at level below (or don't run in parallel at
+ // all)
  sub_search_func(n, x, distances, labels, &indexIVF_stats);
  }
-
-
  }
 
-
- void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
- const idx_t *keys,
- const float *coarse_dis ,
- float *distances, idx_t *labels,
- bool store_pairs,
- const IVFSearchParameters *params,
- IndexIVFStats *ivf_stats) const
- {
- long nprobe = params ? params->nprobe : this->nprobe;
- long max_codes = params ? params->max_codes : this->max_codes;
+ void IndexIVF::search_preassigned(
+ idx_t n,
+ const float* x,
+ idx_t k,
+ const idx_t* keys,
+ const float* coarse_dis,
+ float* distances,
+ idx_t* labels,
+ bool store_pairs,
+ const IVFSearchParameters* params,
+ IndexIVFStats* ivf_stats) const {
+ FAISS_THROW_IF_NOT(k > 0);
+
+ idx_t nprobe = params ? params->nprobe : this->nprobe;
+ nprobe = std::min((idx_t)nlist, nprobe);
+ FAISS_THROW_IF_NOT(nprobe > 0);
+
+ idx_t max_codes = params ? params->max_codes : this->max_codes;
 
  size_t nlistv = 0, ndis = 0, nheap = 0;
 
@@ -366,15 +392,15 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
  bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
 
- bool do_parallel = omp_get_max_threads() >= 2 && (
- pmode == 0 ? false :
- pmode == 3 ? n > 1 :
- pmode == 1 ? nprobe > 1 :
- nprobe * n > 1);
+ bool do_parallel = omp_get_max_threads() >= 2 &&
+ (pmode == 0 ? false
+ : pmode == 3 ? n > 1
+ : pmode == 1 ? nprobe > 1
+ : nprobe * n > 1);
 
- #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
+ #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
  {
- InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
+ InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
  ScopeDeleter1<InvertedListScanner> del(scanner);
 
  /*****************************************************
@@ -385,49 +411,52 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
 
  // intialize + reorder a result heap
 
- auto init_result = [&](float *simi, idx_t *idxi) {
- if (!do_heap_init) return;
+ auto init_result = [&](float* simi, idx_t* idxi) {
+ if (!do_heap_init)
+ return;
  if (metric_type == METRIC_INNER_PRODUCT) {
- heap_heapify<HeapForIP> (k, simi, idxi);
+ heap_heapify<HeapForIP>(k, simi, idxi);
  } else {
- heap_heapify<HeapForL2> (k, simi, idxi);
+ heap_heapify<HeapForL2>(k, simi, idxi);
  }
  };
 
- auto add_local_results = [&](
- const float * local_dis, const idx_t * local_idx,
- float *simi, idx_t *idxi)
- {
+ auto add_local_results = [&](const float* local_dis,
+ const idx_t* local_idx,
+ float* simi,
+ idx_t* idxi) {
  if (metric_type == METRIC_INNER_PRODUCT) {
- heap_addn<HeapForIP>
- (k, simi, idxi, local_dis, local_idx, k);
+ heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
  } else {
- heap_addn<HeapForL2>
- (k, simi, idxi, local_dis, local_idx, k);
+ heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
  }
  };
 
- auto reorder_result = [&] (float *simi, idx_t *idxi) {
- if (!do_heap_init) return;
+ auto reorder_result = [&](float* simi, idx_t* idxi) {
+ if (!do_heap_init)
+ return;
  if (metric_type == METRIC_INNER_PRODUCT) {
- heap_reorder<HeapForIP> (k, simi, idxi);
+ heap_reorder<HeapForIP>(k, simi, idxi);
  } else {
- heap_reorder<HeapForL2> (k, simi, idxi);
+ heap_reorder<HeapForL2>(k, simi, idxi);
  }
  };
 
  // single list scan using the current scanner (with query
  // set porperly) and storing results in simi and idxi
- auto scan_one_list = [&] (idx_t key, float coarse_dis_i,
- float *simi, idx_t *idxi) {
-
+ auto scan_one_list = [&](idx_t key,
+ float coarse_dis_i,
+ float* simi,
+ idx_t* idxi) {
  if (key < 0) {
  // not enough centroids for multiprobe
  return (size_t)0;
  }
- FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist,
- "Invalid key=%" PRId64 " nlist=%zd\n",
- key, nlist);
+ FAISS_THROW_IF_NOT_FMT(
+ key < (idx_t)nlist,
+ "Invalid key=%" PRId64 " nlist=%zd\n",
+ key,
+ nlist);
 
  size_t list_size = invlists->list_size(key);
 
@@ -436,28 +465,28 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
  return (size_t)0;
  }
 
- scanner->set_list (key, coarse_dis_i);
+ scanner->set_list(key, coarse_dis_i);
 
  nlistv++;
 
  try {
- InvertedLists::ScopedCodes scodes (invlists, key);
+ InvertedLists::ScopedCodes scodes(invlists, key);
 
  std::unique_ptr<InvertedLists::ScopedIds> sids;
- const Index::idx_t * ids = nullptr;
+ const Index::idx_t* ids = nullptr;
 
- if (!store_pairs) {
- sids.reset (new InvertedLists::ScopedIds (invlists, key));
+ if (!store_pairs) {
+ sids.reset(new InvertedLists::ScopedIds(invlists, key));
  ids = sids->get();
  }
 
- nheap += scanner->scan_codes (list_size, scodes.get(),
- ids, simi, idxi, k);
+ nheap += scanner->scan_codes(
+ list_size, scodes.get(), ids, simi, idxi, k);
 
- } catch(const std::exception & e) {
+ } catch (const std::exception& e) {
  std::lock_guard<std::mutex> lock(exception_mutex);
  exception_string =
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
+ demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
  interrupt = true;
  return size_t(0);
  }
@@ -470,31 +499,28 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
  ****************************************************/
 
  if (pmode == 0 || pmode == 3) {
-
  #pragma omp for
  for (idx_t i = 0; i < n; i++) {
-
  if (interrupt) {
  continue;
  }
 
  // loop over queries
- scanner->set_query (x + i * d);
- float * simi = distances + i * k;
- idx_t * idxi = labels + i * k;
+ scanner->set_query(x + i * d);
+ float* simi = distances + i * k;
+ idx_t* idxi = labels + i * k;
 
- init_result (simi, idxi);
+ init_result(simi, idxi);
 
- long nscan = 0;
+ idx_t nscan = 0;
 
  // loop over probes
  for (size_t ik = 0; ik < nprobe; ik++) {
-
- nscan += scan_one_list (
- keys [i * nprobe + ik],
- coarse_dis[i * nprobe + ik],
- simi, idxi
- );
+ nscan += scan_one_list(
+ keys[i * nprobe + ik],
+ coarse_dis[i * nprobe + ik],
+ simi,
+ idxi);
 
  if (max_codes && nscan >= max_codes) {
  break;
@@ -502,54 +528,55 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
  }
 
  ndis += nscan;
- reorder_result (simi, idxi);
+ reorder_result(simi, idxi);
 
- if (InterruptCallback::is_interrupted ()) {
+ if (InterruptCallback::is_interrupted()) {
  interrupt = true;
  }
 
  } // parallel for
  } else if (pmode == 1) {
- std::vector <idx_t> local_idx (k);
- std::vector <float> local_dis (k);
+ std::vector<idx_t> local_idx(k);
+ std::vector<float> local_dis(k);
 
  for (size_t i = 0; i < n; i++) {
- scanner->set_query (x + i * d);
- init_result (local_dis.data(), local_idx.data());
+ scanner->set_query(x + i * d);
+ init_result(local_dis.data(), local_idx.data());
 
  #pragma omp for schedule(dynamic)
- for (long ik = 0; ik < nprobe; ik++) {
- ndis += scan_one_list
- (keys [i * nprobe + ik],
- coarse_dis[i * nprobe + ik],
- local_dis.data(), local_idx.data());
+ for (idx_t ik = 0; ik < nprobe; ik++) {
+ ndis += scan_one_list(
+ keys[i * nprobe + ik],
+ coarse_dis[i * nprobe + ik],
+ local_dis.data(),
+ local_idx.data());
 
  // can't do the test on max_codes
  }
  // merge thread-local results
 
- float * simi = distances + i * k;
- idx_t * idxi = labels + i * k;
+ float* simi = distances + i * k;
+ idx_t* idxi = labels + i * k;
  #pragma omp single
- init_result (simi, idxi);
+ init_result(simi, idxi);
 
  #pragma omp barrier
  #pragma omp critical
  {
- add_local_results (local_dis.data(), local_idx.data(),
- simi, idxi);
+ add_local_results(
+ local_dis.data(), local_idx.data(), simi, idxi);
  }
  #pragma omp barrier
  #pragma omp single
- reorder_result (simi, idxi);
+ reorder_result(simi, idxi);
  }
  } else if (pmode == 2) {
- std::vector <idx_t> local_idx (k);
- std::vector <float> local_dis (k);
+ std::vector<idx_t> local_idx(k);
+ std::vector<float> local_dis(k);
 
  #pragma omp single
  for (int64_t i = 0; i < n; i++) {
- init_result (distances + i * k, labels + i * k);
+ init_result(distances + i * k, labels + i * k);
  }
 
  #pragma omp for schedule(dynamic)
@@ -557,33 +584,37 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
  size_t i = ij / nprobe;
  size_t j = ij % nprobe;
 
- scanner->set_query (x + i * d);
- init_result (local_dis.data(), local_idx.data());
- ndis += scan_one_list (
- keys [ij], coarse_dis[ij],
- local_dis.data(), local_idx.data());
+ scanner->set_query(x + i * d);
+ init_result(local_dis.data(), local_idx.data());
+ ndis += scan_one_list(
+ keys[ij],
+ coarse_dis[ij],
+ local_dis.data(),
+ local_idx.data());
  #pragma omp critical
  {
- add_local_results (local_dis.data(), local_idx.data(),
- distances + i * k, labels + i * k);
+ add_local_results(
+ local_dis.data(),
+ local_idx.data(),
+ distances + i * k,
+ labels + i * k);
  }
  }
  #pragma omp single
  for (int64_t i = 0; i < n; i++) {
- reorder_result (distances + i * k, labels + i * k);
+ reorder_result(distances + i * k, labels + i * k);
  }
  } else {
- FAISS_THROW_FMT ("parallel_mode %d not supported\n",
- pmode);
+ FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
  }
  } // parallel section
 
  if (interrupt) {
  if (!exception_string.empty()) {
- FAISS_THROW_FMT ("search interrupted with: %s",
- exception_string.c_str());
+ FAISS_THROW_FMT(
+ "search interrupted with: %s", exception_string.c_str());
  } else {
- FAISS_THROW_MSG ("computation interrupted");
+ FAISS_THROW_MSG("computation interrupted");
  }
  }
 
@@ -595,38 +626,49 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
  }
  }
 
-
-
-
- void IndexIVF::range_search (idx_t nx, const float *x, float radius,
- RangeSearchResult *result) const
- {
- std::unique_ptr<idx_t[]> keys (new idx_t[nx * nprobe]);
- std::unique_ptr<float []> coarse_dis (new float[nx * nprobe]);
+ void IndexIVF::range_search(
+ idx_t nx,
+ const float* x,
+ float radius,
+ RangeSearchResult* result) const {
+ const size_t nprobe = std::min(nlist, this->nprobe);
+ std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
+ std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
 
  double t0 = getmillisecs();
- quantizer->search (nx, x, nprobe, coarse_dis.get (), keys.get ());
+ quantizer->search(nx, x, nprobe, coarse_dis.get(), keys.get());
  indexIVF_stats.quantization_time += getmillisecs() - t0;
 
  t0 = getmillisecs();
- invlists->prefetch_lists (keys.get(), nx * nprobe);
-
- range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
- result, false, nullptr, &indexIVF_stats);
+ invlists->prefetch_lists(keys.get(), nx * nprobe);
+
+ range_search_preassigned(
+ nx,
+ x,
+ radius,
+ keys.get(),
+ coarse_dis.get(),
+ result,
+ false,
+ nullptr,
+ &indexIVF_stats);
 
  indexIVF_stats.search_time += getmillisecs() - t0;
  }
 
- void IndexIVF::range_search_preassigned (
- idx_t nx, const float *x, float radius,
- const idx_t *keys, const float *coarse_dis,
- RangeSearchResult *result,
- bool store_pairs,
- const IVFSearchParameters *params,
- IndexIVFStats *stats) const
- {
- long nprobe = params ? params->nprobe : this->nprobe;
- long max_codes = params ? params->max_codes : this->max_codes;
+ void IndexIVF::range_search_preassigned(
+ idx_t nx,
+ const float* x,
+ float radius,
+ const idx_t* keys,
+ const float* coarse_dis,
+ RangeSearchResult* result,
+ bool store_pairs,
+ const IVFSearchParameters* params,
+ IndexIVFStats* stats) const {
+ idx_t nprobe = params ? params->nprobe : this->nprobe;
+ nprobe = std::min((idx_t)nlist, nprobe);
+ idx_t max_codes = params ? params->max_codes : this->max_codes;
 
  size_t nlistv = 0, ndis = 0;
 
@@ -634,119 +676,116 @@ void IndexIVF::range_search_preassigned (
  std::mutex exception_mutex;
  std::string exception_string;
 
- std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
+ std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
 
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
  // don't start parallel section if single query
- bool do_parallel = omp_get_max_threads() >= 2 && (
- pmode == 3 ? false :
- pmode == 0 ? nx > 1 :
- pmode == 1 ? nprobe > 1 :
- nprobe * nx > 1);
+ bool do_parallel = omp_get_max_threads() >= 2 &&
+ (pmode == 3 ? false
+ : pmode == 0 ? nx > 1
+ : pmode == 1 ? nprobe > 1
+ : nprobe * nx > 1);
 
- #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis)
+ #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
  {
  RangeSearchPartialResult pres(result);
- std::unique_ptr<InvertedListScanner> scanner
- (get_InvertedListScanner(store_pairs));
- FAISS_THROW_IF_NOT (scanner.get ());
+ std::unique_ptr<InvertedListScanner> scanner(
+ get_InvertedListScanner(store_pairs));
+ FAISS_THROW_IF_NOT(scanner.get());
  all_pres[omp_get_thread_num()] = &pres;
 
  // prepare the list scanning function
 
- auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres) {
-
- idx_t key = keys[i * nprobe + ik]; /* select the list */
- if (key < 0) return;
- FAISS_THROW_IF_NOT_FMT (
- key < (idx_t) nlist,
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
- key, ik, nlist);
+ auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
+ idx_t key = keys[i * nprobe + ik]; /* select the list */
+ if (key < 0)
+ return;
+ FAISS_THROW_IF_NOT_FMT(
+ key < (idx_t)nlist,
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
+ key,
+ ik,
+ nlist);
  const size_t list_size = invlists->list_size(key);
 
- if (list_size == 0) return;
+ if (list_size == 0)
+ return;
 
  try {
+ InvertedLists::ScopedCodes scodes(invlists, key);
+ InvertedLists::ScopedIds ids(invlists, key);
 
- InvertedLists::ScopedCodes scodes (invlists, key);
- InvertedLists::ScopedIds ids (invlists, key);
-
- scanner->set_list (key, coarse_dis[i * nprobe + ik]);
+ scanner->set_list(key, coarse_dis[i * nprobe + ik]);
  nlistv++;
  ndis += list_size;
- scanner->scan_codes_range (list_size, scodes.get(),
- ids.get(), radius, qres);
+ scanner->scan_codes_range(
+ list_size, scodes.get(), ids.get(), radius, qres);
 
- } catch(const std::exception & e) {
+ } catch (const std::exception& e) {
  std::lock_guard<std::mutex> lock(exception_mutex);
  exception_string =
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
+ demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
  interrupt = true;
  }
-
  };
 
  if (parallel_mode == 0) {
-
  #pragma omp for
  for (idx_t i = 0; i < nx; i++) {
- scanner->set_query (x + i * d);
+ scanner->set_query(x + i * d);
 
- RangeQueryResult & qres = pres.new_result (i);
+ RangeQueryResult& qres = pres.new_result(i);
 
  for (size_t ik = 0; ik < nprobe; ik++) {
- scan_list_func (i, ik, qres);
+ scan_list_func(i, ik, qres);
  }
-
  }
 
  } else if (parallel_mode == 1) {
-
  for (size_t i = 0; i < nx; i++) {
- scanner->set_query (x + i * d);
+ scanner->set_query(x + i * d);
 
- RangeQueryResult & qres = pres.new_result (i);
+ RangeQueryResult& qres = pres.new_result(i);
 
  #pragma omp for schedule(dynamic)
  for (int64_t ik = 0; ik < nprobe; ik++) {
- scan_list_func (i, ik, qres);
+ scan_list_func(i, ik, qres);
  }
  }
  } else if (parallel_mode == 2) {
- std::vector<RangeQueryResult *> all_qres (nx);
- RangeQueryResult *qres = nullptr;
+ std::vector<RangeQueryResult*> all_qres(nx);
+ RangeQueryResult* qres = nullptr;
 
  #pragma omp for schedule(dynamic)
  for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
  idx_t i = iik / (idx_t)nprobe;
  idx_t ik = iik % (idx_t)nprobe;
  if (qres == nullptr || qres->qno != i) {
- FAISS_ASSERT (!qres || i > qres->qno);
- qres = &pres.new_result (i);
- scanner->set_query (x + i * d);
+ FAISS_ASSERT(!qres || i > qres->qno);
+ qres = &pres.new_result(i);
+ scanner->set_query(x + i * d);
  }
- scan_list_func (i, ik, *qres);
+ scan_list_func(i, ik, *qres);
  }
  } else {
- FAISS_THROW_FMT ("parallel_mode %d not supported\n", parallel_mode);
+ FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
  }
  if (parallel_mode == 0) {
- pres.finalize ();
+ pres.finalize();
  } else {
  #pragma omp barrier
  #pragma omp single
- RangeSearchPartialResult::merge (all_pres, false);
+ RangeSearchPartialResult::merge(all_pres, false);
  #pragma omp barrier
-
  }
  }
 
  if (interrupt) {
  if (!exception_string.empty()) {
- FAISS_THROW_FMT ("search interrupted with: %s",
- exception_string.c_str());
+ FAISS_THROW_FMT(
+ "search interrupted with: %s", exception_string.c_str());
  } else {
- FAISS_THROW_MSG ("computation interrupted");
+ FAISS_THROW_MSG("computation interrupted");
  }
  }
 
@@ -757,27 +796,22 @@ void IndexIVF::range_search_preassigned (
757
796
  }
758
797
  }
759
798
 
760
-
761
- InvertedListScanner *IndexIVF::get_InvertedListScanner (
762
- bool /*store_pairs*/) const
763
- {
799
+ InvertedListScanner* IndexIVF::get_InvertedListScanner(
800
+ bool /*store_pairs*/) const {
764
801
  return nullptr;
765
802
  }
766
803
 
767
- void IndexIVF::reconstruct (idx_t key, float* recons) const
768
- {
769
- idx_t lo = direct_map.get (key);
770
- reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
804
+ void IndexIVF::reconstruct(idx_t key, float* recons) const {
805
+ idx_t lo = direct_map.get(key);
806
+ reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
771
807
  }
772
808
 
773
-
774
- void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
775
- {
776
- FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
809
+ void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
810
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
777
811
 
778
812
  for (idx_t list_no = 0; list_no < nlist; list_no++) {
779
- size_t list_size = invlists->list_size (list_no);
780
- ScopedIds idlist (invlists, list_no);
813
+ size_t list_size = invlists->list_size(list_no);
814
+ ScopedIds idlist(invlists, list_no);
781
815
 
782
816
  for (idx_t offset = 0; offset < list_size; offset++) {
783
817
  idx_t id = idlist[offset];
@@ -786,46 +820,56 @@ void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
             }
 
             float* reconstructed = recons + (id - i0) * d;
-            reconstruct_from_offset (list_no, offset, reconstructed);
+            reconstruct_from_offset(list_no, offset, reconstructed);
         }
     }
 }
 
-
 /* standalone codec interface */
-size_t IndexIVF::sa_code_size () const
-{
+size_t IndexIVF::sa_code_size() const {
     size_t coarse_size = coarse_code_size();
     return code_size + coarse_size;
 }
 
-void IndexIVF::sa_encode (idx_t n, const float *x,
-                          uint8_t *bytes) const
-{
-    FAISS_THROW_IF_NOT (is_trained);
-    std::unique_ptr<int64_t []> idx (new int64_t [n]);
-    quantizer->assign (n, x, idx.get());
-    encode_vectors (n, x, idx.get(), bytes, true);
+void IndexIVF::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
+    FAISS_THROW_IF_NOT(is_trained);
+    std::unique_ptr<int64_t[]> idx(new int64_t[n]);
+    quantizer->assign(n, x, idx.get());
+    encode_vectors(n, x, idx.get(), bytes, true);
 }
 
+void IndexIVF::search_and_reconstruct(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        float* recons) const {
+    FAISS_THROW_IF_NOT(k > 0);
 
-void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
-                                       float *distances, idx_t *labels,
-                                       float *recons) const
-{
-    idx_t * idx = new idx_t [n * nprobe];
-    ScopeDeleter<idx_t> del (idx);
-    float * coarse_dis = new float [n * nprobe];
-    ScopeDeleter<float> del2 (coarse_dis);
+    const size_t nprobe = std::min(nlist, this->nprobe);
+    FAISS_THROW_IF_NOT(nprobe > 0);
 
-    quantizer->search (n, x, nprobe, coarse_dis, idx);
+    idx_t* idx = new idx_t[n * nprobe];
+    ScopeDeleter<idx_t> del(idx);
+    float* coarse_dis = new float[n * nprobe];
+    ScopeDeleter<float> del2(coarse_dis);
 
-    invlists->prefetch_lists (idx, n * nprobe);
+    quantizer->search(n, x, nprobe, coarse_dis, idx);
+
+    invlists->prefetch_lists(idx, n * nprobe);
 
     // search_preassigned() with `store_pairs` enabled to obtain the list_no
     // and offset into `codes` for reconstruction
-    search_preassigned (n, x, k, idx, coarse_dis,
-                        distances, labels, true /* store_pairs */);
+    search_preassigned(
+            n,
+            x,
+            k,
+            idx,
+            coarse_dis,
+            distances,
+            labels,
+            true /* store_pairs */);
     for (idx_t i = 0; i < n; ++i) {
         for (idx_t j = 0; j < k; ++j) {
             idx_t ij = i * k + j;
@@ -835,165 +879,151 @@ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
                 // Fill with NaNs
                 memset(reconstructed, -1, sizeof(*reconstructed) * d);
             } else {
-                int list_no = lo_listno (key);
-                int offset = lo_offset (key);
+                int list_no = lo_listno(key);
+                int offset = lo_offset(key);
 
                 // Update label to the actual id
-                labels[ij] = invlists->get_single_id (list_no, offset);
+                labels[ij] = invlists->get_single_id(list_no, offset);
 
-                reconstruct_from_offset (list_no, offset, reconstructed);
+                reconstruct_from_offset(list_no, offset, reconstructed);
             }
         }
     }
 }
 
 void IndexIVF::reconstruct_from_offset(
-    int64_t /*list_no*/,
-    int64_t /*offset*/,
-    float* /*recons*/) const {
-    FAISS_THROW_MSG ("reconstruct_from_offset not implemented");
+        int64_t /*list_no*/,
+        int64_t /*offset*/,
+        float* /*recons*/) const {
+    FAISS_THROW_MSG("reconstruct_from_offset not implemented");
 }
 
-void IndexIVF::reset ()
-{
-    direct_map.clear ();
-    invlists->reset ();
+void IndexIVF::reset() {
+    direct_map.clear();
+    invlists->reset();
     ntotal = 0;
 }
 
-
-size_t IndexIVF::remove_ids (const IDSelector & sel)
-{
-    size_t nremove = direct_map.remove_ids (sel, invlists);
+size_t IndexIVF::remove_ids(const IDSelector& sel) {
+    size_t nremove = direct_map.remove_ids(sel, invlists);
     ntotal -= nremove;
     return nremove;
 }
 
-
-void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x)
-{
-
+void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
     if (direct_map.type == DirectMap::Hashtable) {
         // just remove then add
         IDSelectorArray sel(n, new_ids);
-        size_t nremove = remove_ids (sel);
-        FAISS_THROW_IF_NOT_MSG (nremove == n,
-                                "did not find all entries to remove");
-        add_with_ids (n, x, new_ids);
+        size_t nremove = remove_ids(sel);
+        FAISS_THROW_IF_NOT_MSG(
+                nremove == n, "did not find all entries to remove");
+        add_with_ids(n, x, new_ids);
         return;
     }
 
-    FAISS_THROW_IF_NOT (direct_map.type == DirectMap::Array);
+    FAISS_THROW_IF_NOT(direct_map.type == DirectMap::Array);
     // here it is more tricky because we don't want to introduce holes
     // in continuous range of ids
 
-    FAISS_THROW_IF_NOT (is_trained);
-    std::vector<idx_t> assign (n);
-    quantizer->assign (n, x, assign.data());
+    FAISS_THROW_IF_NOT(is_trained);
+    std::vector<idx_t> assign(n);
+    quantizer->assign(n, x, assign.data());
 
-    std::vector<uint8_t> flat_codes (n * code_size);
-    encode_vectors (n, x, assign.data(), flat_codes.data());
-
-    direct_map.update_codes (invlists, n, new_ids, assign.data(), flat_codes.data());
+    std::vector<uint8_t> flat_codes(n * code_size);
+    encode_vectors(n, x, assign.data(), flat_codes.data());
 
+    direct_map.update_codes(
+            invlists, n, new_ids, assign.data(), flat_codes.data());
 }
 
-
-
-
-void IndexIVF::train (idx_t n, const float *x)
-{
+void IndexIVF::train(idx_t n, const float* x) {
     if (verbose)
-        printf ("Training level-1 quantizer\n");
+        printf("Training level-1 quantizer\n");
 
-    train_q1 (n, x, verbose, metric_type);
+    train_q1(n, x, verbose, metric_type);
 
     if (verbose)
-        printf ("Training IVF residual\n");
+        printf("Training IVF residual\n");
 
-    train_residual (n, x);
+    train_residual(n, x);
     is_trained = true;
-
 }
 
 void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
-  if (verbose)
-    printf("IndexIVF: no residual training\n");
-  // does nothing by default
+    if (verbose)
+        printf("IndexIVF: no residual training\n");
+    // does nothing by default
 }
 
-
-void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
-{
+void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
     // minimal sanity checks
-    FAISS_THROW_IF_NOT (other.d == d);
-    FAISS_THROW_IF_NOT (other.nlist == nlist);
-    FAISS_THROW_IF_NOT (other.code_size == code_size);
-    FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
-                            "can only merge indexes of the same type");
-    FAISS_THROW_IF_NOT_MSG (this->direct_map.no() && other.direct_map.no(),
-                            "merge direct_map not implemented");
+    FAISS_THROW_IF_NOT(other.d == d);
+    FAISS_THROW_IF_NOT(other.nlist == nlist);
+    FAISS_THROW_IF_NOT(other.code_size == code_size);
+    FAISS_THROW_IF_NOT_MSG(
+            typeid(*this) == typeid(other),
+            "can only merge indexes of the same type");
+    FAISS_THROW_IF_NOT_MSG(
+            this->direct_map.no() && other.direct_map.no(),
+            "merge direct_map not implemented");
 }
 
+void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
+    check_compatible_for_merge(other);
 
-void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
-{
-    check_compatible_for_merge (other);
-
-    invlists->merge_from (other.invlists, add_id);
+    invlists->merge_from(other.invlists, add_id);
 
     ntotal += other.ntotal;
     other.ntotal = 0;
 }
 
-
-void IndexIVF::replace_invlists (InvertedLists *il, bool own)
-{
+void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
     if (own_invlists) {
         delete invlists;
         invlists = nullptr;
     }
     // FAISS_THROW_IF_NOT (ntotal == 0);
     if (il) {
-        FAISS_THROW_IF_NOT (il->nlist == nlist);
-        FAISS_THROW_IF_NOT (
-            il->code_size == code_size ||
-            il->code_size == InvertedLists::INVALID_CODE_SIZE
-        );
+        FAISS_THROW_IF_NOT(il->nlist == nlist);
+        FAISS_THROW_IF_NOT(
+                il->code_size == code_size ||
+                il->code_size == InvertedLists::INVALID_CODE_SIZE);
     }
     invlists = il;
     own_invlists = own;
 }
 
-
-void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
-                               idx_t a1, idx_t a2) const
-{
-
-    FAISS_THROW_IF_NOT (nlist == other.nlist);
-    FAISS_THROW_IF_NOT (code_size == other.code_size);
-    FAISS_THROW_IF_NOT (other.direct_map.no());
-    FAISS_THROW_IF_NOT_FMT (
-        subset_type == 0 || subset_type == 1 || subset_type == 2,
-        "subset type %d not implemented", subset_type);
+void IndexIVF::copy_subset_to(
+        IndexIVF& other,
+        int subset_type,
+        idx_t a1,
+        idx_t a2) const {
+    FAISS_THROW_IF_NOT(nlist == other.nlist);
+    FAISS_THROW_IF_NOT(code_size == other.code_size);
+    FAISS_THROW_IF_NOT(other.direct_map.no());
+    FAISS_THROW_IF_NOT_FMT(
+            subset_type == 0 || subset_type == 1 || subset_type == 2,
+            "subset type %d not implemented",
+            subset_type);
 
     size_t accu_n = 0;
     size_t accu_a1 = 0;
     size_t accu_a2 = 0;
 
-    InvertedLists *oivf = other.invlists;
+    InvertedLists* oivf = other.invlists;
 
     for (idx_t list_no = 0; list_no < nlist; list_no++) {
-        size_t n = invlists->list_size (list_no);
-        ScopedIds ids_in (invlists, list_no);
+        size_t n = invlists->list_size(list_no);
+        ScopedIds ids_in(invlists, list_no);
 
         if (subset_type == 0) {
             for (idx_t i = 0; i < n; i++) {
                 idx_t id = ids_in[i];
                 if (a1 <= id && id < a2) {
-                    oivf->add_entry (list_no,
-                                     invlists->get_single_id (list_no, i),
-                                     ScopedCodes (invlists, list_no, i).get());
+                    oivf->add_entry(
+                            list_no,
+                            invlists->get_single_id(list_no, i),
+                            ScopedCodes(invlists, list_no, i).get());
                     other.ntotal++;
                 }
             }
@@ -1001,9 +1031,10 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
             for (idx_t i = 0; i < n; i++) {
                 idx_t id = ids_in[i];
                 if (id % a1 == a2) {
-                    oivf->add_entry (list_no,
-                                     invlists->get_single_id (list_no, i),
-                                     ScopedCodes (invlists, list_no, i).get());
+                    oivf->add_entry(
+                            list_no,
+                            invlists->get_single_id(list_no, i),
+                            ScopedCodes(invlists, list_no, i).get());
                     other.ntotal++;
                 }
             }
@@ -1016,9 +1047,10 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
             size_t i2 = next_accu_a2 - accu_a2;
 
             for (idx_t i = i1; i < i2; i++) {
-                oivf->add_entry (list_no,
-                                 invlists->get_single_id (list_no, i),
-                                 ScopedCodes (invlists, list_no, i).get());
+                oivf->add_entry(
+                        list_no,
+                        invlists->get_single_id(list_no, i),
+                        ScopedCodes(invlists, list_no, i).get());
             }
 
             other.ntotal += i2 - i1;
@@ -1028,48 +1060,36 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
         accu_n += n;
     }
     FAISS_ASSERT(accu_n == ntotal);
-
 }
 
-
-
-
-IndexIVF::~IndexIVF()
-{
+IndexIVF::~IndexIVF() {
     if (own_invlists) {
         delete invlists;
     }
 }
 
-
-void IndexIVFStats::reset()
-{
-    memset ((void*)this, 0, sizeof (*this));
+void IndexIVFStats::reset() {
+    memset((void*)this, 0, sizeof(*this));
 }
 
-void IndexIVFStats::add (const IndexIVFStats & other)
-{
+void IndexIVFStats::add(const IndexIVFStats& other) {
     nq += other.nq;
     nlist += other.nlist;
     ndis += other.ndis;
     nheap_updates += other.nheap_updates;
     quantization_time += other.quantization_time;
     search_time += other.search_time;
-
 }
 
-
 IndexIVFStats indexIVF_stats;
 
-void InvertedListScanner::scan_codes_range (size_t ,
-                                            const uint8_t *,
-                                            const idx_t *,
-                                            float ,
-                                            RangeQueryResult &) const
-{
-    FAISS_THROW_MSG ("scan_codes_range not implemented");
+void InvertedListScanner::scan_codes_range(
+        size_t,
+        const uint8_t*,
+        const idx_t*,
+        float,
+        RangeQueryResult&) const {
+    FAISS_THROW_MSG("scan_codes_range not implemented");
 }
 
-
-
 } // namespace faiss