faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -11,13 +11,12 @@
11
11
 
12
12
  #include <vector>
13
13
 
14
- #include <faiss/impl/HNSW.h>
15
14
  #include <faiss/IndexFlat.h>
16
15
  #include <faiss/IndexPQ.h>
17
16
  #include <faiss/IndexScalarQuantizer.h>
17
+ #include <faiss/impl/HNSW.h>
18
18
  #include <faiss/utils/utils.h>
19
19
 
20
-
21
20
  namespace faiss {
22
21
 
23
22
  struct IndexHNSW;
@@ -26,9 +25,9 @@ struct ReconstructFromNeighbors {
26
25
  typedef Index::idx_t idx_t;
27
26
  typedef HNSW::storage_idx_t storage_idx_t;
28
27
 
29
- const IndexHNSW & index;
30
- size_t M; // number of neighbors
31
- size_t k; // number of codebook entries
28
+ const IndexHNSW& index;
29
+ size_t M; // number of neighbors
30
+ size_t k; // number of codebook entries
32
31
  size_t nsq; // number of subvectors
33
32
  size_t code_size;
34
33
  int k_reorder; // nb to reorder. -1 = all
@@ -39,35 +38,37 @@ struct ReconstructFromNeighbors {
39
38
  size_t ntotal;
40
39
  size_t d, dsub; // derived values
41
40
 
42
- explicit ReconstructFromNeighbors(const IndexHNSW& index,
43
- size_t k=256, size_t nsq=1);
41
+ explicit ReconstructFromNeighbors(
42
+ const IndexHNSW& index,
43
+ size_t k = 256,
44
+ size_t nsq = 1);
44
45
 
45
46
  /// codes must be added in the correct order and the IndexHNSW
46
47
  /// must be populated and sorted
47
- void add_codes(size_t n, const float *x);
48
+ void add_codes(size_t n, const float* x);
48
49
 
49
- size_t compute_distances(size_t n, const idx_t *shortlist,
50
- const float *query, float *distances) const;
50
+ size_t compute_distances(
51
+ size_t n,
52
+ const idx_t* shortlist,
53
+ const float* query,
54
+ float* distances) const;
51
55
 
52
56
  /// called by add_codes
53
- void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const;
57
+ void estimate_code(const float* x, storage_idx_t i, uint8_t* code) const;
54
58
 
55
59
  /// called by compute_distances
56
- void reconstruct(storage_idx_t i, float *x, float *tmp) const;
60
+ void reconstruct(storage_idx_t i, float* x, float* tmp) const;
57
61
 
58
- void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const;
62
+ void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float* x) const;
59
63
 
60
64
  /// get the M+1 -by-d table for neighbor coordinates for vector i
61
- void get_neighbor_table(storage_idx_t i, float *out) const;
62
-
65
+ void get_neighbor_table(storage_idx_t i, float* out) const;
63
66
  };
64
67
 
65
-
66
68
  /** The HNSW index is a normal random-access index with a HNSW
67
69
  * link structure built on top */
68
70
 
69
71
  struct IndexHNSW : Index {
70
-
71
72
  typedef HNSW::storage_idx_t storage_idx_t;
72
73
 
73
74
  // the link structure
@@ -75,27 +76,31 @@ struct IndexHNSW : Index {
75
76
 
76
77
  // the sequential storage
77
78
  bool own_fields;
78
- Index *storage;
79
+ Index* storage;
79
80
 
80
- ReconstructFromNeighbors *reconstruct_from_neighbors;
81
+ ReconstructFromNeighbors* reconstruct_from_neighbors;
81
82
 
82
- explicit IndexHNSW (int d = 0, int M = 32, MetricType metric = METRIC_L2);
83
- explicit IndexHNSW (Index *storage, int M = 32);
83
+ explicit IndexHNSW(int d = 0, int M = 32, MetricType metric = METRIC_L2);
84
+ explicit IndexHNSW(Index* storage, int M = 32);
84
85
 
85
86
  ~IndexHNSW() override;
86
87
 
87
- void add(idx_t n, const float *x) override;
88
+ void add(idx_t n, const float* x) override;
88
89
 
89
90
  /// Trains the storage if needed
90
91
  void train(idx_t n, const float* x) override;
91
92
 
92
93
  /// entry point for search
93
- void search (idx_t n, const float *x, idx_t k,
94
- float *distances, idx_t *labels) const override;
94
+ void search(
95
+ idx_t n,
96
+ const float* x,
97
+ idx_t k,
98
+ float* distances,
99
+ idx_t* labels) const override;
95
100
 
96
101
  void reconstruct(idx_t key, float* recons) const override;
97
102
 
98
- void reset () override;
103
+ void reset() override;
99
104
 
100
105
  void shrink_level_0_neighbors(int size);
101
106
 
@@ -105,19 +110,25 @@ struct IndexHNSW : Index {
105
110
  * @param search_type 1:perform one search per nprobe, 2: enqueue
106
111
  * all entry points
107
112
  */
108
- void search_level_0(idx_t n, const float *x, idx_t k,
109
- const storage_idx_t *nearest, const float *nearest_d,
110
- float *distances, idx_t *labels, int nprobe = 1,
111
- int search_type = 1) const;
113
+ void search_level_0(
114
+ idx_t n,
115
+ const float* x,
116
+ idx_t k,
117
+ const storage_idx_t* nearest,
118
+ const float* nearest_d,
119
+ float* distances,
120
+ idx_t* labels,
121
+ int nprobe = 1,
122
+ int search_type = 1) const;
112
123
 
113
124
  /// alternative graph building
114
- void init_level_0_from_knngraph(
115
- int k, const float *D, const idx_t *I);
125
+ void init_level_0_from_knngraph(int k, const float* D, const idx_t* I);
116
126
 
117
127
  /// alternative graph building
118
128
  void init_level_0_from_entry_points(
119
- int npt, const storage_idx_t *points,
120
- const storage_idx_t *nearests);
129
+ int npt,
130
+ const storage_idx_t* points,
131
+ const storage_idx_t* nearests);
121
132
 
122
133
  // reorder links from nearest to farthest
123
134
  void reorder_links();
@@ -125,7 +136,6 @@ struct IndexHNSW : Index {
125
136
  void link_singletons();
126
137
  };
127
138
 
128
-
129
139
  /** Flat index topped with a HNSW structure to access elements
130
140
  * more efficiently.
131
141
  */
@@ -149,22 +159,28 @@ struct IndexHNSWPQ : IndexHNSW {
149
159
  */
150
160
  struct IndexHNSWSQ : IndexHNSW {
151
161
  IndexHNSWSQ();
152
- IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M, MetricType metric = METRIC_L2);
162
+ IndexHNSWSQ(
163
+ int d,
164
+ ScalarQuantizer::QuantizerType qtype,
165
+ int M,
166
+ MetricType metric = METRIC_L2);
153
167
  };
154
168
 
155
169
  /** 2-level code structure with fast random access
156
170
  */
157
171
  struct IndexHNSW2Level : IndexHNSW {
158
172
  IndexHNSW2Level();
159
- IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M);
173
+ IndexHNSW2Level(Index* quantizer, size_t nlist, int m_pq, int M);
160
174
 
161
175
  void flip_to_ivf();
162
176
 
163
177
  /// entry point for search
164
- void search (idx_t n, const float *x, idx_t k,
165
- float *distances, idx_t *labels) const override;
166
-
178
+ void search(
179
+ idx_t n,
180
+ const float* x,
181
+ idx_t k,
182
+ float* distances,
183
+ idx_t* labels) const override;
167
184
  };
168
185
 
169
-
170
- } // namespace faiss
186
+ } // namespace faiss
@@ -9,7 +9,6 @@
9
9
 
10
10
  #include <faiss/IndexIVF.h>
11
11
 
12
-
13
12
  #include <omp.h>
14
13
  #include <mutex>
15
14
 
@@ -18,12 +17,12 @@
18
17
  #include <cstdio>
19
18
  #include <memory>
20
19
 
21
- #include <faiss/utils/utils.h>
22
20
  #include <faiss/utils/hamming.h>
21
+ #include <faiss/utils/utils.h>
23
22
 
24
- #include <faiss/impl/FaissAssert.h>
25
23
  #include <faiss/IndexFlat.h>
26
24
  #include <faiss/impl/AuxIndexStructures.h>
25
+ #include <faiss/impl/FaissAssert.h>
27
26
 
28
27
  namespace faiss {
29
28
 
@@ -34,99 +33,97 @@ using ScopedCodes = InvertedLists::ScopedCodes;
34
33
  * Level1Quantizer implementation
35
34
  ******************************************/
36
35
 
37
-
38
- Level1Quantizer::Level1Quantizer (Index * quantizer, size_t nlist):
39
- quantizer (quantizer),
40
- nlist (nlist),
41
- quantizer_trains_alone (0),
42
- own_fields (false),
43
- clustering_index (nullptr)
44
- {
36
+ Level1Quantizer::Level1Quantizer(Index* quantizer, size_t nlist)
37
+ : quantizer(quantizer),
38
+ nlist(nlist),
39
+ quantizer_trains_alone(0),
40
+ own_fields(false),
41
+ clustering_index(nullptr) {
45
42
  // here we set a low # iterations because this is typically used
46
43
  // for large clusterings (nb this is not used for the MultiIndex,
47
44
  // for which quantizer_trains_alone = true)
48
45
  cp.niter = 10;
49
46
  }
50
47
 
51
- Level1Quantizer::Level1Quantizer ():
52
- quantizer (nullptr),
53
- nlist (0),
54
- quantizer_trains_alone (0), own_fields (false),
55
- clustering_index (nullptr)
56
- {}
48
+ Level1Quantizer::Level1Quantizer()
49
+ : quantizer(nullptr),
50
+ nlist(0),
51
+ quantizer_trains_alone(0),
52
+ own_fields(false),
53
+ clustering_index(nullptr) {}
57
54
 
58
- Level1Quantizer::~Level1Quantizer ()
59
- {
60
- if (own_fields) delete quantizer;
55
+ Level1Quantizer::~Level1Quantizer() {
56
+ if (own_fields)
57
+ delete quantizer;
61
58
  }
62
59
 
63
- void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
64
- {
60
+ void Level1Quantizer::train_q1(
61
+ size_t n,
62
+ const float* x,
63
+ bool verbose,
64
+ MetricType metric_type) {
65
65
  size_t d = quantizer->d;
66
66
  if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
67
67
  if (verbose)
68
- printf ("IVF quantizer does not need training.\n");
68
+ printf("IVF quantizer does not need training.\n");
69
69
  } else if (quantizer_trains_alone == 1) {
70
70
  if (verbose)
71
- printf ("IVF quantizer trains alone...\n");
72
- quantizer->train (n, x);
71
+ printf("IVF quantizer trains alone...\n");
72
+ quantizer->train(n, x);
73
73
  quantizer->verbose = verbose;
74
- FAISS_THROW_IF_NOT_MSG (quantizer->ntotal == nlist,
75
- "nlist not consistent with quantizer size");
74
+ FAISS_THROW_IF_NOT_MSG(
75
+ quantizer->ntotal == nlist,
76
+ "nlist not consistent with quantizer size");
76
77
  } else if (quantizer_trains_alone == 0) {
77
78
  if (verbose)
78
- printf ("Training level-1 quantizer on %zd vectors in %zdD\n",
79
- n, d);
79
+ printf("Training level-1 quantizer on %zd vectors in %zdD\n", n, d);
80
80
 
81
- Clustering clus (d, nlist, cp);
81
+ Clustering clus(d, nlist, cp);
82
82
  quantizer->reset();
83
83
  if (clustering_index) {
84
- clus.train (n, x, *clustering_index);
85
- quantizer->add (nlist, clus.centroids.data());
84
+ clus.train(n, x, *clustering_index);
85
+ quantizer->add(nlist, clus.centroids.data());
86
86
  } else {
87
- clus.train (n, x, *quantizer);
87
+ clus.train(n, x, *quantizer);
88
88
  }
89
89
  quantizer->is_trained = true;
90
90
  } else if (quantizer_trains_alone == 2) {
91
91
  if (verbose) {
92
- printf (
93
- "Training L2 quantizer on %zd vectors in %zdD%s\n",
94
- n, d,
95
- clustering_index ? "(user provided index)" : "");
92
+ printf("Training L2 quantizer on %zd vectors in %zdD%s\n",
93
+ n,
94
+ d,
95
+ clustering_index ? "(user provided index)" : "");
96
96
  }
97
97
  // also accept spherical centroids because in that case
98
98
  // L2 and IP are equivalent
99
- FAISS_THROW_IF_NOT (
100
- metric_type == METRIC_L2 ||
101
- (metric_type == METRIC_INNER_PRODUCT && cp.spherical)
102
- );
99
+ FAISS_THROW_IF_NOT(
100
+ metric_type == METRIC_L2 ||
101
+ (metric_type == METRIC_INNER_PRODUCT && cp.spherical));
103
102
 
104
- Clustering clus (d, nlist, cp);
103
+ Clustering clus(d, nlist, cp);
105
104
  if (!clustering_index) {
106
- IndexFlatL2 assigner (d);
105
+ IndexFlatL2 assigner(d);
107
106
  clus.train(n, x, assigner);
108
107
  } else {
109
108
  clus.train(n, x, *clustering_index);
110
109
  }
111
110
  if (verbose)
112
- printf ("Adding centroids to quantizer\n");
113
- quantizer->add (nlist, clus.centroids.data());
111
+ printf("Adding centroids to quantizer\n");
112
+ quantizer->add(nlist, clus.centroids.data());
114
113
  }
115
114
  }
116
115
 
117
- size_t Level1Quantizer::coarse_code_size () const
118
- {
116
+ size_t Level1Quantizer::coarse_code_size() const {
119
117
  size_t nl = nlist - 1;
120
118
  size_t nbyte = 0;
121
119
  while (nl > 0) {
122
- nbyte ++;
120
+ nbyte++;
123
121
  nl >>= 8;
124
122
  }
125
123
  return nbyte;
126
124
  }
127
125
 
128
- void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
129
- {
126
+ void Level1Quantizer::encode_listno(Index::idx_t list_no, uint8_t* code) const {
130
127
  // little endian
131
128
  size_t nl = nlist - 1;
132
129
  while (nl > 0) {
@@ -136,8 +133,7 @@ void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
136
133
  }
137
134
  }
138
135
 
139
- Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
140
- {
136
+ Index::idx_t Level1Quantizer::decode_listno(const uint8_t* code) const {
141
137
  size_t nl = nlist - 1;
142
138
  int64_t list_no = 0;
143
139
  int nbit = 0;
@@ -146,161 +142,184 @@ Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
146
142
  nbit += 8;
147
143
  nl >>= 8;
148
144
  }
149
- FAISS_THROW_IF_NOT (list_no >= 0 && list_no < nlist);
145
+ FAISS_THROW_IF_NOT(list_no >= 0 && list_no < nlist);
150
146
  return list_no;
151
147
  }
152
148
 
153
-
154
-
155
149
  /*****************************************
156
150
  * IndexIVF implementation
157
151
  ******************************************/
158
152
 
159
-
160
- IndexIVF::IndexIVF (Index * quantizer, size_t d,
161
- size_t nlist, size_t code_size,
162
- MetricType metric):
163
- Index (d, metric),
164
- Level1Quantizer (quantizer, nlist),
165
- invlists (new ArrayInvertedLists (nlist, code_size)),
166
- own_invlists (true),
167
- code_size (code_size),
168
- nprobe (1),
169
- max_codes (0),
170
- parallel_mode (0)
171
- {
172
- FAISS_THROW_IF_NOT (d == quantizer->d);
153
+ IndexIVF::IndexIVF(
154
+ Index* quantizer,
155
+ size_t d,
156
+ size_t nlist,
157
+ size_t code_size,
158
+ MetricType metric)
159
+ : Index(d, metric),
160
+ Level1Quantizer(quantizer, nlist),
161
+ invlists(new ArrayInvertedLists(nlist, code_size)),
162
+ own_invlists(true),
163
+ code_size(code_size),
164
+ nprobe(1),
165
+ max_codes(0),
166
+ parallel_mode(0) {
167
+ FAISS_THROW_IF_NOT(d == quantizer->d);
173
168
  is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
174
169
  // Spherical by default if the metric is inner_product
175
170
  if (metric_type == METRIC_INNER_PRODUCT) {
176
171
  cp.spherical = true;
177
172
  }
178
-
179
173
  }
180
174
 
181
- IndexIVF::IndexIVF ():
182
- invlists (nullptr), own_invlists (false),
183
- code_size (0),
184
- nprobe (1), max_codes (0), parallel_mode (0)
185
- {}
175
+ IndexIVF::IndexIVF()
176
+ : invlists(nullptr),
177
+ own_invlists(false),
178
+ code_size(0),
179
+ nprobe(1),
180
+ max_codes(0),
181
+ parallel_mode(0) {}
186
182
 
187
- void IndexIVF::add (idx_t n, const float * x)
188
- {
189
- add_with_ids (n, x, nullptr);
183
+ void IndexIVF::add(idx_t n, const float* x) {
184
+ add_with_ids(n, x, nullptr);
190
185
  }
191
186
 
187
+ void IndexIVF::add_with_ids(idx_t n, const float* x, const idx_t* xids) {
188
+ std::unique_ptr<idx_t[]> coarse_idx(new idx_t[n]);
189
+ quantizer->assign(n, x, coarse_idx.get());
190
+ add_core(n, x, xids, coarse_idx.get());
191
+ }
192
192
 
193
- void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
194
- {
193
+ void IndexIVF::add_core(
194
+ idx_t n,
195
+ const float* x,
196
+ const idx_t* xids,
197
+ const idx_t* coarse_idx) {
195
198
  // do some blocking to avoid excessive allocs
196
199
  idx_t bs = 65536;
197
200
  if (n > bs) {
198
201
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
199
- idx_t i1 = std::min (n, i0 + bs);
202
+ idx_t i1 = std::min(n, i0 + bs);
200
203
  if (verbose) {
201
- printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n", i0, i1);
204
+ printf(" IndexIVF::add_with_ids %" PRId64 ":%" PRId64 "\n",
205
+ i0,
206
+ i1);
202
207
  }
203
- add_with_ids (i1 - i0, x + i0 * d,
204
- xids ? xids + i0 : nullptr);
208
+ add_core(
209
+ i1 - i0,
210
+ x + i0 * d,
211
+ xids ? xids + i0 : nullptr,
212
+ coarse_idx + i0);
205
213
  }
206
214
  return;
207
215
  }
216
+ FAISS_THROW_IF_NOT(coarse_idx);
217
+ FAISS_THROW_IF_NOT(is_trained);
218
+ direct_map.check_can_add(xids);
208
219
 
209
- FAISS_THROW_IF_NOT (is_trained);
210
- direct_map.check_can_add (xids);
211
-
212
- std::unique_ptr<idx_t []> idx(new idx_t[n]);
213
- quantizer->assign (n, x, idx.get());
214
220
  size_t nadd = 0, nminus1 = 0;
215
221
 
216
222
  for (size_t i = 0; i < n; i++) {
217
- if (idx[i] < 0) nminus1++;
223
+ if (coarse_idx[i] < 0)
224
+ nminus1++;
218
225
  }
219
226
 
220
- std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
221
- encode_vectors (n, x, idx.get(), flat_codes.get());
227
+ std::unique_ptr<uint8_t[]> flat_codes(new uint8_t[n * code_size]);
228
+ encode_vectors(n, x, coarse_idx, flat_codes.get());
222
229
 
223
230
  DirectMapAdd dm_adder(direct_map, n, xids);
224
231
 
225
- #pragma omp parallel reduction(+: nadd)
232
+ #pragma omp parallel reduction(+ : nadd)
226
233
  {
227
234
  int nt = omp_get_num_threads();
228
235
  int rank = omp_get_thread_num();
229
236
 
230
237
  // each thread takes care of a subset of lists
231
238
  for (size_t i = 0; i < n; i++) {
232
- idx_t list_no = idx [i];
239
+ idx_t list_no = coarse_idx[i];
233
240
  if (list_no >= 0 && list_no % nt == rank) {
234
241
  idx_t id = xids ? xids[i] : ntotal + i;
235
- size_t ofs = invlists->add_entry (
236
- list_no, id,
237
- flat_codes.get() + i * code_size
238
- );
242
+ size_t ofs = invlists->add_entry(
243
+ list_no, id, flat_codes.get() + i * code_size);
239
244
 
240
- dm_adder.add (i, list_no, ofs);
245
+ dm_adder.add(i, list_no, ofs);
241
246
 
242
247
  nadd++;
243
248
  } else if (rank == 0 && list_no == -1) {
244
- dm_adder.add (i, -1, 0);
249
+ dm_adder.add(i, -1, 0);
245
250
  }
246
251
  }
247
252
  }
248
253
 
249
-
250
254
  if (verbose) {
251
- printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n", nadd, n, nminus1);
255
+ printf(" added %zd / %" PRId64 " vectors (%zd -1s)\n",
256
+ nadd,
257
+ n,
258
+ nminus1);
252
259
  }
253
260
 
254
261
  ntotal += n;
255
262
  }
256
263
 
257
- void IndexIVF::make_direct_map (bool b)
258
- {
264
+ void IndexIVF::make_direct_map(bool b) {
259
265
  if (b) {
260
- direct_map.set_type (DirectMap::Array, invlists, ntotal);
266
+ direct_map.set_type(DirectMap::Array, invlists, ntotal);
261
267
  } else {
262
- direct_map.set_type (DirectMap::NoMap, invlists, ntotal);
268
+ direct_map.set_type(DirectMap::NoMap, invlists, ntotal);
263
269
  }
264
270
  }
265
271
 
266
-
267
-
268
- void IndexIVF::set_direct_map_type (DirectMap::Type type)
269
- {
270
- direct_map.set_type (type, invlists, ntotal);
272
+ void IndexIVF::set_direct_map_type(DirectMap::Type type) {
273
+ direct_map.set_type(type, invlists, ntotal);
271
274
  }
272
275
 
273
276
  /** It is a sad fact of software that a conceptually simple function like this
274
277
  * becomes very complex when you factor in several ways of parallelizing +
275
278
  * interrupt/error handling + collecting stats + min/max collection. The
276
279
  * codepath that is used 95% of time is the one for parallel_mode = 0 */
277
- void IndexIVF::search (idx_t n, const float *x, idx_t k,
278
- float *distances, idx_t *labels) const
279
- {
280
+ void IndexIVF::search(
281
+ idx_t n,
282
+ const float* x,
283
+ idx_t k,
284
+ float* distances,
285
+ idx_t* labels) const {
286
+ FAISS_THROW_IF_NOT(k > 0);
280
287
 
288
+ const size_t nprobe = std::min(nlist, this->nprobe);
289
+ FAISS_THROW_IF_NOT(nprobe > 0);
281
290
 
282
291
  // search function for a subset of queries
283
- auto sub_search_func = [this, k]
284
- (idx_t n, const float *x, float *distances, idx_t *labels,
285
- IndexIVFStats *ivf_stats) {
286
-
292
+ auto sub_search_func = [this, k, nprobe](
293
+ idx_t n,
294
+ const float* x,
295
+ float* distances,
296
+ idx_t* labels,
297
+ IndexIVFStats* ivf_stats) {
287
298
  std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
288
299
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
289
300
 
290
301
  double t0 = getmillisecs();
291
- quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
302
+ quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());
292
303
 
293
304
  double t1 = getmillisecs();
294
- invlists->prefetch_lists (idx.get(), n * nprobe);
295
-
296
- search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
297
- distances, labels, false, nullptr, ivf_stats);
305
+ invlists->prefetch_lists(idx.get(), n * nprobe);
306
+
307
+ search_preassigned(
308
+ n,
309
+ x,
310
+ k,
311
+ idx.get(),
312
+ coarse_dis.get(),
313
+ distances,
314
+ labels,
315
+ false,
316
+ nullptr,
317
+ ivf_stats);
298
318
  double t2 = getmillisecs();
299
319
  ivf_stats->quantization_time += t1 - t0;
300
320
  ivf_stats->search_time += t2 - t0;
301
321
  };
302
322
 
303
-
304
323
  if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
305
324
  int nt = std::min(omp_get_max_threads(), int(n));
306
325
  std::vector<IndexIVFStats> stats(nt);
@@ -308,18 +327,19 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
308
327
  std::string exception_string;
309
328
 
310
329
  #pragma omp parallel for if (nt > 1)
311
- for(idx_t slice = 0; slice < nt; slice++) {
330
+ for (idx_t slice = 0; slice < nt; slice++) {
312
331
  IndexIVFStats local_stats;
313
332
  idx_t i0 = n * slice / nt;
314
333
  idx_t i1 = n * (slice + 1) / nt;
315
334
  if (i1 > i0) {
316
335
  try {
317
336
  sub_search_func(
318
- i1 - i0, x + i0 * d,
319
- distances + i0 * k, labels + i0 * k,
320
- &stats[slice]
321
- );
322
- } catch(const std::exception & e) {
337
+ i1 - i0,
338
+ x + i0 * d,
339
+ distances + i0 * k,
340
+ labels + i0 * k,
341
+ &stats[slice]);
342
+ } catch (const std::exception& e) {
323
343
  std::lock_guard<std::mutex> lock(exception_mutex);
324
344
  exception_string = e.what();
325
345
  }
@@ -327,32 +347,38 @@ void IndexIVF::search (idx_t n, const float *x, idx_t k,
327
347
  }
328
348
 
329
349
  if (!exception_string.empty()) {
330
- FAISS_THROW_MSG (exception_string.c_str());
350
+ FAISS_THROW_MSG(exception_string.c_str());
331
351
  }
332
352
 
333
353
  // collect stats
334
- for(idx_t slice = 0; slice < nt; slice++) {
354
+ for (idx_t slice = 0; slice < nt; slice++) {
335
355
  indexIVF_stats.add(stats[slice]);
336
356
  }
337
357
  } else {
338
- // handle parallelization at level below (or don't run in parallel at all)
358
+ // handle parallelization at level below (or don't run in parallel at
359
+ // all)
339
360
  sub_search_func(n, x, distances, labels, &indexIVF_stats);
340
361
  }
341
-
342
-
343
362
  }
344
363
 
345
-
346
- void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
347
- const idx_t *keys,
348
- const float *coarse_dis ,
349
- float *distances, idx_t *labels,
350
- bool store_pairs,
351
- const IVFSearchParameters *params,
352
- IndexIVFStats *ivf_stats) const
353
- {
354
- long nprobe = params ? params->nprobe : this->nprobe;
355
- long max_codes = params ? params->max_codes : this->max_codes;
364
+ void IndexIVF::search_preassigned(
365
+ idx_t n,
366
+ const float* x,
367
+ idx_t k,
368
+ const idx_t* keys,
369
+ const float* coarse_dis,
370
+ float* distances,
371
+ idx_t* labels,
372
+ bool store_pairs,
373
+ const IVFSearchParameters* params,
374
+ IndexIVFStats* ivf_stats) const {
375
+ FAISS_THROW_IF_NOT(k > 0);
376
+
377
+ idx_t nprobe = params ? params->nprobe : this->nprobe;
378
+ nprobe = std::min((idx_t)nlist, nprobe);
379
+ FAISS_THROW_IF_NOT(nprobe > 0);
380
+
381
+ idx_t max_codes = params ? params->max_codes : this->max_codes;
356
382
 
357
383
  size_t nlistv = 0, ndis = 0, nheap = 0;
358
384
 
@@ -366,15 +392,15 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
366
392
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
367
393
  bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
368
394
 
369
- bool do_parallel = omp_get_max_threads() >= 2 && (
370
- pmode == 0 ? false :
371
- pmode == 3 ? n > 1 :
372
- pmode == 1 ? nprobe > 1 :
373
- nprobe * n > 1);
395
+ bool do_parallel = omp_get_max_threads() >= 2 &&
396
+ (pmode == 0 ? false
397
+ : pmode == 3 ? n > 1
398
+ : pmode == 1 ? nprobe > 1
399
+ : nprobe * n > 1);
374
400
 
375
- #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
401
+ #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis, nheap)
376
402
  {
377
- InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
403
+ InvertedListScanner* scanner = get_InvertedListScanner(store_pairs);
378
404
  ScopeDeleter1<InvertedListScanner> del(scanner);
379
405
 
380
406
  /*****************************************************
@@ -385,49 +411,52 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
385
411
 
386
412
  // initialize + reorder a result heap
387
413
 
388
- auto init_result = [&](float *simi, idx_t *idxi) {
389
- if (!do_heap_init) return;
414
+ auto init_result = [&](float* simi, idx_t* idxi) {
415
+ if (!do_heap_init)
416
+ return;
390
417
  if (metric_type == METRIC_INNER_PRODUCT) {
391
- heap_heapify<HeapForIP> (k, simi, idxi);
418
+ heap_heapify<HeapForIP>(k, simi, idxi);
392
419
  } else {
393
- heap_heapify<HeapForL2> (k, simi, idxi);
420
+ heap_heapify<HeapForL2>(k, simi, idxi);
394
421
  }
395
422
  };
396
423
 
397
- auto add_local_results = [&](
398
- const float * local_dis, const idx_t * local_idx,
399
- float *simi, idx_t *idxi)
400
- {
424
+ auto add_local_results = [&](const float* local_dis,
425
+ const idx_t* local_idx,
426
+ float* simi,
427
+ idx_t* idxi) {
401
428
  if (metric_type == METRIC_INNER_PRODUCT) {
402
- heap_addn<HeapForIP>
403
- (k, simi, idxi, local_dis, local_idx, k);
429
+ heap_addn<HeapForIP>(k, simi, idxi, local_dis, local_idx, k);
404
430
  } else {
405
- heap_addn<HeapForL2>
406
- (k, simi, idxi, local_dis, local_idx, k);
431
+ heap_addn<HeapForL2>(k, simi, idxi, local_dis, local_idx, k);
407
432
  }
408
433
  };
409
434
 
410
- auto reorder_result = [&] (float *simi, idx_t *idxi) {
411
- if (!do_heap_init) return;
435
+ auto reorder_result = [&](float* simi, idx_t* idxi) {
436
+ if (!do_heap_init)
437
+ return;
412
438
  if (metric_type == METRIC_INNER_PRODUCT) {
413
- heap_reorder<HeapForIP> (k, simi, idxi);
439
+ heap_reorder<HeapForIP>(k, simi, idxi);
414
440
  } else {
415
- heap_reorder<HeapForL2> (k, simi, idxi);
441
+ heap_reorder<HeapForL2>(k, simi, idxi);
416
442
  }
417
443
  };
418
444
 
419
445
  // single list scan using the current scanner (with query
420
446
  // set properly) and storing results in simi and idxi
421
- auto scan_one_list = [&] (idx_t key, float coarse_dis_i,
422
- float *simi, idx_t *idxi) {
423
-
447
+ auto scan_one_list = [&](idx_t key,
448
+ float coarse_dis_i,
449
+ float* simi,
450
+ idx_t* idxi) {
424
451
  if (key < 0) {
425
452
  // not enough centroids for multiprobe
426
453
  return (size_t)0;
427
454
  }
428
- FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist,
429
- "Invalid key=%" PRId64 " nlist=%zd\n",
430
- key, nlist);
455
+ FAISS_THROW_IF_NOT_FMT(
456
+ key < (idx_t)nlist,
457
+ "Invalid key=%" PRId64 " nlist=%zd\n",
458
+ key,
459
+ nlist);
431
460
 
432
461
  size_t list_size = invlists->list_size(key);
433
462
 
@@ -436,28 +465,28 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
436
465
  return (size_t)0;
437
466
  }
438
467
 
439
- scanner->set_list (key, coarse_dis_i);
468
+ scanner->set_list(key, coarse_dis_i);
440
469
 
441
470
  nlistv++;
442
471
 
443
472
  try {
444
- InvertedLists::ScopedCodes scodes (invlists, key);
473
+ InvertedLists::ScopedCodes scodes(invlists, key);
445
474
 
446
475
  std::unique_ptr<InvertedLists::ScopedIds> sids;
447
- const Index::idx_t * ids = nullptr;
476
+ const Index::idx_t* ids = nullptr;
448
477
 
449
- if (!store_pairs) {
450
- sids.reset (new InvertedLists::ScopedIds (invlists, key));
478
+ if (!store_pairs) {
479
+ sids.reset(new InvertedLists::ScopedIds(invlists, key));
451
480
  ids = sids->get();
452
481
  }
453
482
 
454
- nheap += scanner->scan_codes (list_size, scodes.get(),
455
- ids, simi, idxi, k);
483
+ nheap += scanner->scan_codes(
484
+ list_size, scodes.get(), ids, simi, idxi, k);
456
485
 
457
- } catch(const std::exception & e) {
486
+ } catch (const std::exception& e) {
458
487
  std::lock_guard<std::mutex> lock(exception_mutex);
459
488
  exception_string =
460
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
489
+ demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
461
490
  interrupt = true;
462
491
  return size_t(0);
463
492
  }
@@ -470,31 +499,28 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
470
499
  ****************************************************/
471
500
 
472
501
  if (pmode == 0 || pmode == 3) {
473
-
474
502
  #pragma omp for
475
503
  for (idx_t i = 0; i < n; i++) {
476
-
477
504
  if (interrupt) {
478
505
  continue;
479
506
  }
480
507
 
481
508
  // loop over queries
482
- scanner->set_query (x + i * d);
483
- float * simi = distances + i * k;
484
- idx_t * idxi = labels + i * k;
509
+ scanner->set_query(x + i * d);
510
+ float* simi = distances + i * k;
511
+ idx_t* idxi = labels + i * k;
485
512
 
486
- init_result (simi, idxi);
513
+ init_result(simi, idxi);
487
514
 
488
- long nscan = 0;
515
+ idx_t nscan = 0;
489
516
 
490
517
  // loop over probes
491
518
  for (size_t ik = 0; ik < nprobe; ik++) {
492
-
493
- nscan += scan_one_list (
494
- keys [i * nprobe + ik],
495
- coarse_dis[i * nprobe + ik],
496
- simi, idxi
497
- );
519
+ nscan += scan_one_list(
520
+ keys[i * nprobe + ik],
521
+ coarse_dis[i * nprobe + ik],
522
+ simi,
523
+ idxi);
498
524
 
499
525
  if (max_codes && nscan >= max_codes) {
500
526
  break;
@@ -502,54 +528,55 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
502
528
  }
503
529
 
504
530
  ndis += nscan;
505
- reorder_result (simi, idxi);
531
+ reorder_result(simi, idxi);
506
532
 
507
- if (InterruptCallback::is_interrupted ()) {
533
+ if (InterruptCallback::is_interrupted()) {
508
534
  interrupt = true;
509
535
  }
510
536
 
511
537
  } // parallel for
512
538
  } else if (pmode == 1) {
513
- std::vector <idx_t> local_idx (k);
514
- std::vector <float> local_dis (k);
539
+ std::vector<idx_t> local_idx(k);
540
+ std::vector<float> local_dis(k);
515
541
 
516
542
  for (size_t i = 0; i < n; i++) {
517
- scanner->set_query (x + i * d);
518
- init_result (local_dis.data(), local_idx.data());
543
+ scanner->set_query(x + i * d);
544
+ init_result(local_dis.data(), local_idx.data());
519
545
 
520
546
  #pragma omp for schedule(dynamic)
521
- for (long ik = 0; ik < nprobe; ik++) {
522
- ndis += scan_one_list
523
- (keys [i * nprobe + ik],
524
- coarse_dis[i * nprobe + ik],
525
- local_dis.data(), local_idx.data());
547
+ for (idx_t ik = 0; ik < nprobe; ik++) {
548
+ ndis += scan_one_list(
549
+ keys[i * nprobe + ik],
550
+ coarse_dis[i * nprobe + ik],
551
+ local_dis.data(),
552
+ local_idx.data());
526
553
 
527
554
  // can't do the test on max_codes
528
555
  }
529
556
  // merge thread-local results
530
557
 
531
- float * simi = distances + i * k;
532
- idx_t * idxi = labels + i * k;
558
+ float* simi = distances + i * k;
559
+ idx_t* idxi = labels + i * k;
533
560
  #pragma omp single
534
- init_result (simi, idxi);
561
+ init_result(simi, idxi);
535
562
 
536
563
  #pragma omp barrier
537
564
  #pragma omp critical
538
565
  {
539
- add_local_results (local_dis.data(), local_idx.data(),
540
- simi, idxi);
566
+ add_local_results(
567
+ local_dis.data(), local_idx.data(), simi, idxi);
541
568
  }
542
569
  #pragma omp barrier
543
570
  #pragma omp single
544
- reorder_result (simi, idxi);
571
+ reorder_result(simi, idxi);
545
572
  }
546
573
  } else if (pmode == 2) {
547
- std::vector <idx_t> local_idx (k);
548
- std::vector <float> local_dis (k);
574
+ std::vector<idx_t> local_idx(k);
575
+ std::vector<float> local_dis(k);
549
576
 
550
577
  #pragma omp single
551
578
  for (int64_t i = 0; i < n; i++) {
552
- init_result (distances + i * k, labels + i * k);
579
+ init_result(distances + i * k, labels + i * k);
553
580
  }
554
581
 
555
582
  #pragma omp for schedule(dynamic)
@@ -557,33 +584,37 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
557
584
  size_t i = ij / nprobe;
558
585
  size_t j = ij % nprobe;
559
586
 
560
- scanner->set_query (x + i * d);
561
- init_result (local_dis.data(), local_idx.data());
562
- ndis += scan_one_list (
563
- keys [ij], coarse_dis[ij],
564
- local_dis.data(), local_idx.data());
587
+ scanner->set_query(x + i * d);
588
+ init_result(local_dis.data(), local_idx.data());
589
+ ndis += scan_one_list(
590
+ keys[ij],
591
+ coarse_dis[ij],
592
+ local_dis.data(),
593
+ local_idx.data());
565
594
  #pragma omp critical
566
595
  {
567
- add_local_results (local_dis.data(), local_idx.data(),
568
- distances + i * k, labels + i * k);
596
+ add_local_results(
597
+ local_dis.data(),
598
+ local_idx.data(),
599
+ distances + i * k,
600
+ labels + i * k);
569
601
  }
570
602
  }
571
603
  #pragma omp single
572
604
  for (int64_t i = 0; i < n; i++) {
573
- reorder_result (distances + i * k, labels + i * k);
605
+ reorder_result(distances + i * k, labels + i * k);
574
606
  }
575
607
  } else {
576
- FAISS_THROW_FMT ("parallel_mode %d not supported\n",
577
- pmode);
608
+ FAISS_THROW_FMT("parallel_mode %d not supported\n", pmode);
578
609
  }
579
610
  } // parallel section
580
611
 
581
612
  if (interrupt) {
582
613
  if (!exception_string.empty()) {
583
- FAISS_THROW_FMT ("search interrupted with: %s",
584
- exception_string.c_str());
614
+ FAISS_THROW_FMT(
615
+ "search interrupted with: %s", exception_string.c_str());
585
616
  } else {
586
- FAISS_THROW_MSG ("computation interrupted");
617
+ FAISS_THROW_MSG("computation interrupted");
587
618
  }
588
619
  }
589
620
 
@@ -595,38 +626,49 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
595
626
  }
596
627
  }
597
628
 
598
-
599
-
600
-
601
- void IndexIVF::range_search (idx_t nx, const float *x, float radius,
602
- RangeSearchResult *result) const
603
- {
604
- std::unique_ptr<idx_t[]> keys (new idx_t[nx * nprobe]);
605
- std::unique_ptr<float []> coarse_dis (new float[nx * nprobe]);
629
+ void IndexIVF::range_search(
630
+ idx_t nx,
631
+ const float* x,
632
+ float radius,
633
+ RangeSearchResult* result) const {
634
+ const size_t nprobe = std::min(nlist, this->nprobe);
635
+ std::unique_ptr<idx_t[]> keys(new idx_t[nx * nprobe]);
636
+ std::unique_ptr<float[]> coarse_dis(new float[nx * nprobe]);
606
637
 
607
638
  double t0 = getmillisecs();
608
- quantizer->search (nx, x, nprobe, coarse_dis.get (), keys.get ());
639
+ quantizer->search(nx, x, nprobe, coarse_dis.get(), keys.get());
609
640
  indexIVF_stats.quantization_time += getmillisecs() - t0;
610
641
 
611
642
  t0 = getmillisecs();
612
- invlists->prefetch_lists (keys.get(), nx * nprobe);
613
-
614
- range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
615
- result, false, nullptr, &indexIVF_stats);
643
+ invlists->prefetch_lists(keys.get(), nx * nprobe);
644
+
645
+ range_search_preassigned(
646
+ nx,
647
+ x,
648
+ radius,
649
+ keys.get(),
650
+ coarse_dis.get(),
651
+ result,
652
+ false,
653
+ nullptr,
654
+ &indexIVF_stats);
616
655
 
617
656
  indexIVF_stats.search_time += getmillisecs() - t0;
618
657
  }
619
658
 
620
- void IndexIVF::range_search_preassigned (
621
- idx_t nx, const float *x, float radius,
622
- const idx_t *keys, const float *coarse_dis,
623
- RangeSearchResult *result,
624
- bool store_pairs,
625
- const IVFSearchParameters *params,
626
- IndexIVFStats *stats) const
627
- {
628
- long nprobe = params ? params->nprobe : this->nprobe;
629
- long max_codes = params ? params->max_codes : this->max_codes;
659
+ void IndexIVF::range_search_preassigned(
660
+ idx_t nx,
661
+ const float* x,
662
+ float radius,
663
+ const idx_t* keys,
664
+ const float* coarse_dis,
665
+ RangeSearchResult* result,
666
+ bool store_pairs,
667
+ const IVFSearchParameters* params,
668
+ IndexIVFStats* stats) const {
669
+ idx_t nprobe = params ? params->nprobe : this->nprobe;
670
+ nprobe = std::min((idx_t)nlist, nprobe);
671
+ idx_t max_codes = params ? params->max_codes : this->max_codes;
630
672
 
631
673
  size_t nlistv = 0, ndis = 0;
632
674
 
@@ -634,119 +676,116 @@ void IndexIVF::range_search_preassigned (
634
676
  std::mutex exception_mutex;
635
677
  std::string exception_string;
636
678
 
637
- std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
679
+ std::vector<RangeSearchPartialResult*> all_pres(omp_get_max_threads());
638
680
 
639
681
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
640
682
  // don't start parallel section if single query
641
- bool do_parallel = omp_get_max_threads() >= 2 && (
642
- pmode == 3 ? false :
643
- pmode == 0 ? nx > 1 :
644
- pmode == 1 ? nprobe > 1 :
645
- nprobe * nx > 1);
683
+ bool do_parallel = omp_get_max_threads() >= 2 &&
684
+ (pmode == 3 ? false
685
+ : pmode == 0 ? nx > 1
686
+ : pmode == 1 ? nprobe > 1
687
+ : nprobe * nx > 1);
646
688
 
647
- #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis)
689
+ #pragma omp parallel if (do_parallel) reduction(+ : nlistv, ndis)
648
690
  {
649
691
  RangeSearchPartialResult pres(result);
650
- std::unique_ptr<InvertedListScanner> scanner
651
- (get_InvertedListScanner(store_pairs));
652
- FAISS_THROW_IF_NOT (scanner.get ());
692
+ std::unique_ptr<InvertedListScanner> scanner(
693
+ get_InvertedListScanner(store_pairs));
694
+ FAISS_THROW_IF_NOT(scanner.get());
653
695
  all_pres[omp_get_thread_num()] = &pres;
654
696
 
655
697
  // prepare the list scanning function
656
698
 
657
- auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres) {
658
-
659
- idx_t key = keys[i * nprobe + ik]; /* select the list */
660
- if (key < 0) return;
661
- FAISS_THROW_IF_NOT_FMT (
662
- key < (idx_t) nlist,
663
- "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
664
- key, ik, nlist);
699
+ auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult& qres) {
700
+ idx_t key = keys[i * nprobe + ik]; /* select the list */
701
+ if (key < 0)
702
+ return;
703
+ FAISS_THROW_IF_NOT_FMT(
704
+ key < (idx_t)nlist,
705
+ "Invalid key=%" PRId64 " at ik=%zd nlist=%zd\n",
706
+ key,
707
+ ik,
708
+ nlist);
665
709
  const size_t list_size = invlists->list_size(key);
666
710
 
667
- if (list_size == 0) return;
711
+ if (list_size == 0)
712
+ return;
668
713
 
669
714
  try {
715
+ InvertedLists::ScopedCodes scodes(invlists, key);
716
+ InvertedLists::ScopedIds ids(invlists, key);
670
717
 
671
- InvertedLists::ScopedCodes scodes (invlists, key);
672
- InvertedLists::ScopedIds ids (invlists, key);
673
-
674
- scanner->set_list (key, coarse_dis[i * nprobe + ik]);
718
+ scanner->set_list(key, coarse_dis[i * nprobe + ik]);
675
719
  nlistv++;
676
720
  ndis += list_size;
677
- scanner->scan_codes_range (list_size, scodes.get(),
678
- ids.get(), radius, qres);
721
+ scanner->scan_codes_range(
722
+ list_size, scodes.get(), ids.get(), radius, qres);
679
723
 
680
- } catch(const std::exception & e) {
724
+ } catch (const std::exception& e) {
681
725
  std::lock_guard<std::mutex> lock(exception_mutex);
682
726
  exception_string =
683
- demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
727
+ demangle_cpp_symbol(typeid(e).name()) + " " + e.what();
684
728
  interrupt = true;
685
729
  }
686
-
687
730
  };
688
731
 
689
732
  if (parallel_mode == 0) {
690
-
691
733
  #pragma omp for
692
734
  for (idx_t i = 0; i < nx; i++) {
693
- scanner->set_query (x + i * d);
735
+ scanner->set_query(x + i * d);
694
736
 
695
- RangeQueryResult & qres = pres.new_result (i);
737
+ RangeQueryResult& qres = pres.new_result(i);
696
738
 
697
739
  for (size_t ik = 0; ik < nprobe; ik++) {
698
- scan_list_func (i, ik, qres);
740
+ scan_list_func(i, ik, qres);
699
741
  }
700
-
701
742
  }
702
743
 
703
744
  } else if (parallel_mode == 1) {
704
-
705
745
  for (size_t i = 0; i < nx; i++) {
706
- scanner->set_query (x + i * d);
746
+ scanner->set_query(x + i * d);
707
747
 
708
- RangeQueryResult & qres = pres.new_result (i);
748
+ RangeQueryResult& qres = pres.new_result(i);
709
749
 
710
750
  #pragma omp for schedule(dynamic)
711
751
  for (int64_t ik = 0; ik < nprobe; ik++) {
712
- scan_list_func (i, ik, qres);
752
+ scan_list_func(i, ik, qres);
713
753
  }
714
754
  }
715
755
  } else if (parallel_mode == 2) {
716
- std::vector<RangeQueryResult *> all_qres (nx);
717
- RangeQueryResult *qres = nullptr;
756
+ std::vector<RangeQueryResult*> all_qres(nx);
757
+ RangeQueryResult* qres = nullptr;
718
758
 
719
759
  #pragma omp for schedule(dynamic)
720
760
  for (idx_t iik = 0; iik < nx * (idx_t)nprobe; iik++) {
721
761
  idx_t i = iik / (idx_t)nprobe;
722
762
  idx_t ik = iik % (idx_t)nprobe;
723
763
  if (qres == nullptr || qres->qno != i) {
724
- FAISS_ASSERT (!qres || i > qres->qno);
725
- qres = &pres.new_result (i);
726
- scanner->set_query (x + i * d);
764
+ FAISS_ASSERT(!qres || i > qres->qno);
765
+ qres = &pres.new_result(i);
766
+ scanner->set_query(x + i * d);
727
767
  }
728
- scan_list_func (i, ik, *qres);
768
+ scan_list_func(i, ik, *qres);
729
769
  }
730
770
  } else {
731
- FAISS_THROW_FMT ("parallel_mode %d not supported\n", parallel_mode);
771
+ FAISS_THROW_FMT("parallel_mode %d not supported\n", parallel_mode);
732
772
  }
733
773
  if (parallel_mode == 0) {
734
- pres.finalize ();
774
+ pres.finalize();
735
775
  } else {
736
776
  #pragma omp barrier
737
777
  #pragma omp single
738
- RangeSearchPartialResult::merge (all_pres, false);
778
+ RangeSearchPartialResult::merge(all_pres, false);
739
779
  #pragma omp barrier
740
-
741
780
  }
742
781
  }
743
782
 
744
783
  if (interrupt) {
745
784
  if (!exception_string.empty()) {
746
- FAISS_THROW_FMT ("search interrupted with: %s",
747
- exception_string.c_str());
785
+ FAISS_THROW_FMT(
786
+ "search interrupted with: %s", exception_string.c_str());
748
787
  } else {
749
- FAISS_THROW_MSG ("computation interrupted");
788
+ FAISS_THROW_MSG("computation interrupted");
750
789
  }
751
790
  }
752
791
 
@@ -757,27 +796,22 @@ void IndexIVF::range_search_preassigned (
757
796
  }
758
797
  }
759
798
 
760
-
761
- InvertedListScanner *IndexIVF::get_InvertedListScanner (
762
- bool /*store_pairs*/) const
763
- {
799
+ InvertedListScanner* IndexIVF::get_InvertedListScanner(
800
+ bool /*store_pairs*/) const {
764
801
  return nullptr;
765
802
  }
766
803
 
767
- void IndexIVF::reconstruct (idx_t key, float* recons) const
768
- {
769
- idx_t lo = direct_map.get (key);
770
- reconstruct_from_offset (lo_listno(lo), lo_offset(lo), recons);
804
+ void IndexIVF::reconstruct(idx_t key, float* recons) const {
805
+ idx_t lo = direct_map.get(key);
806
+ reconstruct_from_offset(lo_listno(lo), lo_offset(lo), recons);
771
807
  }
772
808
 
773
-
774
- void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
775
- {
776
- FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
809
+ void IndexIVF::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
810
+ FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
777
811
 
778
812
  for (idx_t list_no = 0; list_no < nlist; list_no++) {
779
- size_t list_size = invlists->list_size (list_no);
780
- ScopedIds idlist (invlists, list_no);
813
+ size_t list_size = invlists->list_size(list_no);
814
+ ScopedIds idlist(invlists, list_no);
781
815
 
782
816
  for (idx_t offset = 0; offset < list_size; offset++) {
783
817
  idx_t id = idlist[offset];
@@ -786,46 +820,56 @@ void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
786
820
  }
787
821
 
788
822
  float* reconstructed = recons + (id - i0) * d;
789
- reconstruct_from_offset (list_no, offset, reconstructed);
823
+ reconstruct_from_offset(list_no, offset, reconstructed);
790
824
  }
791
825
  }
792
826
  }
793
827
 
794
-
795
828
  /* standalone codec interface */
796
- size_t IndexIVF::sa_code_size () const
797
- {
829
+ size_t IndexIVF::sa_code_size() const {
798
830
  size_t coarse_size = coarse_code_size();
799
831
  return code_size + coarse_size;
800
832
  }
801
833
 
802
- void IndexIVF::sa_encode (idx_t n, const float *x,
803
- uint8_t *bytes) const
804
- {
805
- FAISS_THROW_IF_NOT (is_trained);
806
- std::unique_ptr<int64_t []> idx (new int64_t [n]);
807
- quantizer->assign (n, x, idx.get());
808
- encode_vectors (n, x, idx.get(), bytes, true);
834
+ void IndexIVF::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
835
+ FAISS_THROW_IF_NOT(is_trained);
836
+ std::unique_ptr<int64_t[]> idx(new int64_t[n]);
837
+ quantizer->assign(n, x, idx.get());
838
+ encode_vectors(n, x, idx.get(), bytes, true);
809
839
  }
810
840
 
841
+ void IndexIVF::search_and_reconstruct(
842
+ idx_t n,
843
+ const float* x,
844
+ idx_t k,
845
+ float* distances,
846
+ idx_t* labels,
847
+ float* recons) const {
848
+ FAISS_THROW_IF_NOT(k > 0);
811
849
 
812
- void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
813
- float *distances, idx_t *labels,
814
- float *recons) const
815
- {
816
- idx_t * idx = new idx_t [n * nprobe];
817
- ScopeDeleter<idx_t> del (idx);
818
- float * coarse_dis = new float [n * nprobe];
819
- ScopeDeleter<float> del2 (coarse_dis);
850
+ const size_t nprobe = std::min(nlist, this->nprobe);
851
+ FAISS_THROW_IF_NOT(nprobe > 0);
820
852
 
821
- quantizer->search (n, x, nprobe, coarse_dis, idx);
853
+ idx_t* idx = new idx_t[n * nprobe];
854
+ ScopeDeleter<idx_t> del(idx);
855
+ float* coarse_dis = new float[n * nprobe];
856
+ ScopeDeleter<float> del2(coarse_dis);
822
857
 
823
- invlists->prefetch_lists (idx, n * nprobe);
858
+ quantizer->search(n, x, nprobe, coarse_dis, idx);
859
+
860
+ invlists->prefetch_lists(idx, n * nprobe);
824
861
 
825
862
  // search_preassigned() with `store_pairs` enabled to obtain the list_no
826
863
  // and offset into `codes` for reconstruction
827
- search_preassigned (n, x, k, idx, coarse_dis,
828
- distances, labels, true /* store_pairs */);
864
+ search_preassigned(
865
+ n,
866
+ x,
867
+ k,
868
+ idx,
869
+ coarse_dis,
870
+ distances,
871
+ labels,
872
+ true /* store_pairs */);
829
873
  for (idx_t i = 0; i < n; ++i) {
830
874
  for (idx_t j = 0; j < k; ++j) {
831
875
  idx_t ij = i * k + j;
@@ -835,165 +879,151 @@ void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
835
879
  // Fill with NaNs
836
880
  memset(reconstructed, -1, sizeof(*reconstructed) * d);
837
881
  } else {
838
- int list_no = lo_listno (key);
839
- int offset = lo_offset (key);
882
+ int list_no = lo_listno(key);
883
+ int offset = lo_offset(key);
840
884
 
841
885
  // Update label to the actual id
842
- labels[ij] = invlists->get_single_id (list_no, offset);
886
+ labels[ij] = invlists->get_single_id(list_no, offset);
843
887
 
844
- reconstruct_from_offset (list_no, offset, reconstructed);
888
+ reconstruct_from_offset(list_no, offset, reconstructed);
845
889
  }
846
890
  }
847
891
  }
848
892
  }
849
893
 
850
894
  void IndexIVF::reconstruct_from_offset(
851
- int64_t /*list_no*/,
852
- int64_t /*offset*/,
853
- float* /*recons*/) const {
854
- FAISS_THROW_MSG ("reconstruct_from_offset not implemented");
895
+ int64_t /*list_no*/,
896
+ int64_t /*offset*/,
897
+ float* /*recons*/) const {
898
+ FAISS_THROW_MSG("reconstruct_from_offset not implemented");
855
899
  }
856
900
 
857
- void IndexIVF::reset ()
858
- {
859
- direct_map.clear ();
860
- invlists->reset ();
901
+ void IndexIVF::reset() {
902
+ direct_map.clear();
903
+ invlists->reset();
861
904
  ntotal = 0;
862
905
  }
863
906
 
864
-
865
- size_t IndexIVF::remove_ids (const IDSelector & sel)
866
- {
867
- size_t nremove = direct_map.remove_ids (sel, invlists);
907
+ size_t IndexIVF::remove_ids(const IDSelector& sel) {
908
+ size_t nremove = direct_map.remove_ids(sel, invlists);
868
909
  ntotal -= nremove;
869
910
  return nremove;
870
911
  }
871
912
 
872
-
873
- void IndexIVF::update_vectors (int n, const idx_t *new_ids, const float *x)
874
- {
875
-
913
+ void IndexIVF::update_vectors(int n, const idx_t* new_ids, const float* x) {
876
914
  if (direct_map.type == DirectMap::Hashtable) {
877
915
  // just remove then add
878
916
  IDSelectorArray sel(n, new_ids);
879
- size_t nremove = remove_ids (sel);
880
- FAISS_THROW_IF_NOT_MSG (nremove == n,
881
- "did not find all entries to remove");
882
- add_with_ids (n, x, new_ids);
917
+ size_t nremove = remove_ids(sel);
918
+ FAISS_THROW_IF_NOT_MSG(
919
+ nremove == n, "did not find all entries to remove");
920
+ add_with_ids(n, x, new_ids);
883
921
  return;
884
922
  }
885
923
 
886
- FAISS_THROW_IF_NOT (direct_map.type == DirectMap::Array);
924
+ FAISS_THROW_IF_NOT(direct_map.type == DirectMap::Array);
887
925
  // here it is more tricky because we don't want to introduce holes
888
926
  // in continuous range of ids
889
927
 
890
- FAISS_THROW_IF_NOT (is_trained);
891
- std::vector<idx_t> assign (n);
892
- quantizer->assign (n, x, assign.data());
928
+ FAISS_THROW_IF_NOT(is_trained);
929
+ std::vector<idx_t> assign(n);
930
+ quantizer->assign(n, x, assign.data());
893
931
 
894
- std::vector<uint8_t> flat_codes (n * code_size);
895
- encode_vectors (n, x, assign.data(), flat_codes.data());
896
-
897
- direct_map.update_codes (invlists, n, new_ids, assign.data(), flat_codes.data());
932
+ std::vector<uint8_t> flat_codes(n * code_size);
933
+ encode_vectors(n, x, assign.data(), flat_codes.data());
898
934
 
935
+ direct_map.update_codes(
936
+ invlists, n, new_ids, assign.data(), flat_codes.data());
899
937
  }
900
938
 
901
-
902
-
903
-
904
- void IndexIVF::train (idx_t n, const float *x)
905
- {
939
+ void IndexIVF::train(idx_t n, const float* x) {
906
940
  if (verbose)
907
- printf ("Training level-1 quantizer\n");
941
+ printf("Training level-1 quantizer\n");
908
942
 
909
- train_q1 (n, x, verbose, metric_type);
943
+ train_q1(n, x, verbose, metric_type);
910
944
 
911
945
  if (verbose)
912
- printf ("Training IVF residual\n");
946
+ printf("Training IVF residual\n");
913
947
 
914
- train_residual (n, x);
948
+ train_residual(n, x);
915
949
  is_trained = true;
916
-
917
950
  }
918
951
 
919
952
  void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
920
- if (verbose)
921
- printf("IndexIVF: no residual training\n");
922
- // does nothing by default
953
+ if (verbose)
954
+ printf("IndexIVF: no residual training\n");
955
+ // does nothing by default
923
956
  }
924
957
 
925
-
926
- void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
927
- {
958
+ void IndexIVF::check_compatible_for_merge(const IndexIVF& other) const {
928
959
  // minimal sanity checks
929
- FAISS_THROW_IF_NOT (other.d == d);
930
- FAISS_THROW_IF_NOT (other.nlist == nlist);
931
- FAISS_THROW_IF_NOT (other.code_size == code_size);
932
- FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
933
- "can only merge indexes of the same type");
934
- FAISS_THROW_IF_NOT_MSG (this->direct_map.no() && other.direct_map.no(),
935
- "merge direct_map not implemented");
960
+ FAISS_THROW_IF_NOT(other.d == d);
961
+ FAISS_THROW_IF_NOT(other.nlist == nlist);
962
+ FAISS_THROW_IF_NOT(other.code_size == code_size);
963
+ FAISS_THROW_IF_NOT_MSG(
964
+ typeid(*this) == typeid(other),
965
+ "can only merge indexes of the same type");
966
+ FAISS_THROW_IF_NOT_MSG(
967
+ this->direct_map.no() && other.direct_map.no(),
968
+ "merge direct_map not implemented");
936
969
  }
937
970
 
971
+ void IndexIVF::merge_from(IndexIVF& other, idx_t add_id) {
972
+ check_compatible_for_merge(other);
938
973
 
939
- void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
940
- {
941
- check_compatible_for_merge (other);
942
-
943
- invlists->merge_from (other.invlists, add_id);
974
+ invlists->merge_from(other.invlists, add_id);
944
975
 
945
976
  ntotal += other.ntotal;
946
977
  other.ntotal = 0;
947
978
  }
948
979
 
949
-
950
- void IndexIVF::replace_invlists (InvertedLists *il, bool own)
951
- {
980
+ void IndexIVF::replace_invlists(InvertedLists* il, bool own) {
952
981
  if (own_invlists) {
953
982
  delete invlists;
954
983
  invlists = nullptr;
955
984
  }
956
985
  // FAISS_THROW_IF_NOT (ntotal == 0);
957
986
  if (il) {
958
- FAISS_THROW_IF_NOT (il->nlist == nlist);
959
- FAISS_THROW_IF_NOT (
960
- il->code_size == code_size ||
961
- il->code_size == InvertedLists::INVALID_CODE_SIZE
962
- );
987
+ FAISS_THROW_IF_NOT(il->nlist == nlist);
988
+ FAISS_THROW_IF_NOT(
989
+ il->code_size == code_size ||
990
+ il->code_size == InvertedLists::INVALID_CODE_SIZE);
963
991
  }
964
992
  invlists = il;
965
993
  own_invlists = own;
966
994
  }
967
995
 
968
-
969
- void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
970
- idx_t a1, idx_t a2) const
971
- {
972
-
973
- FAISS_THROW_IF_NOT (nlist == other.nlist);
974
- FAISS_THROW_IF_NOT (code_size == other.code_size);
975
- FAISS_THROW_IF_NOT (other.direct_map.no());
976
- FAISS_THROW_IF_NOT_FMT (
977
- subset_type == 0 || subset_type == 1 || subset_type == 2,
978
- "subset type %d not implemented", subset_type);
996
+ void IndexIVF::copy_subset_to(
997
+ IndexIVF& other,
998
+ int subset_type,
999
+ idx_t a1,
1000
+ idx_t a2) const {
1001
+ FAISS_THROW_IF_NOT(nlist == other.nlist);
1002
+ FAISS_THROW_IF_NOT(code_size == other.code_size);
1003
+ FAISS_THROW_IF_NOT(other.direct_map.no());
1004
+ FAISS_THROW_IF_NOT_FMT(
1005
+ subset_type == 0 || subset_type == 1 || subset_type == 2,
1006
+ "subset type %d not implemented",
1007
+ subset_type);
979
1008
 
980
1009
  size_t accu_n = 0;
981
1010
  size_t accu_a1 = 0;
982
1011
  size_t accu_a2 = 0;
983
1012
 
984
- InvertedLists *oivf = other.invlists;
1013
+ InvertedLists* oivf = other.invlists;
985
1014
 
986
1015
  for (idx_t list_no = 0; list_no < nlist; list_no++) {
987
- size_t n = invlists->list_size (list_no);
988
- ScopedIds ids_in (invlists, list_no);
1016
+ size_t n = invlists->list_size(list_no);
1017
+ ScopedIds ids_in(invlists, list_no);
989
1018
 
990
1019
  if (subset_type == 0) {
991
1020
  for (idx_t i = 0; i < n; i++) {
992
1021
  idx_t id = ids_in[i];
993
1022
  if (a1 <= id && id < a2) {
994
- oivf->add_entry (list_no,
995
- invlists->get_single_id (list_no, i),
996
- ScopedCodes (invlists, list_no, i).get());
1023
+ oivf->add_entry(
1024
+ list_no,
1025
+ invlists->get_single_id(list_no, i),
1026
+ ScopedCodes(invlists, list_no, i).get());
997
1027
  other.ntotal++;
998
1028
  }
999
1029
  }
@@ -1001,9 +1031,10 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
1001
1031
  for (idx_t i = 0; i < n; i++) {
1002
1032
  idx_t id = ids_in[i];
1003
1033
  if (id % a1 == a2) {
1004
- oivf->add_entry (list_no,
1005
- invlists->get_single_id (list_no, i),
1006
- ScopedCodes (invlists, list_no, i).get());
1034
+ oivf->add_entry(
1035
+ list_no,
1036
+ invlists->get_single_id(list_no, i),
1037
+ ScopedCodes(invlists, list_no, i).get());
1007
1038
  other.ntotal++;
1008
1039
  }
1009
1040
  }
@@ -1016,9 +1047,10 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
1016
1047
  size_t i2 = next_accu_a2 - accu_a2;
1017
1048
 
1018
1049
  for (idx_t i = i1; i < i2; i++) {
1019
- oivf->add_entry (list_no,
1020
- invlists->get_single_id (list_no, i),
1021
- ScopedCodes (invlists, list_no, i).get());
1050
+ oivf->add_entry(
1051
+ list_no,
1052
+ invlists->get_single_id(list_no, i),
1053
+ ScopedCodes(invlists, list_no, i).get());
1022
1054
  }
1023
1055
 
1024
1056
  other.ntotal += i2 - i1;
@@ -1028,48 +1060,36 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
1028
1060
  accu_n += n;
1029
1061
  }
1030
1062
  FAISS_ASSERT(accu_n == ntotal);
1031
-
1032
1063
  }
1033
1064
 
1034
-
1035
-
1036
-
1037
- IndexIVF::~IndexIVF()
1038
- {
1065
+ IndexIVF::~IndexIVF() {
1039
1066
  if (own_invlists) {
1040
1067
  delete invlists;
1041
1068
  }
1042
1069
  }
1043
1070
 
1044
-
1045
- void IndexIVFStats::reset()
1046
- {
1047
- memset ((void*)this, 0, sizeof (*this));
1071
+ void IndexIVFStats::reset() {
1072
+ memset((void*)this, 0, sizeof(*this));
1048
1073
  }
1049
1074
 
1050
- void IndexIVFStats::add (const IndexIVFStats & other)
1051
- {
1075
+ void IndexIVFStats::add(const IndexIVFStats& other) {
1052
1076
  nq += other.nq;
1053
1077
  nlist += other.nlist;
1054
1078
  ndis += other.ndis;
1055
1079
  nheap_updates += other.nheap_updates;
1056
1080
  quantization_time += other.quantization_time;
1057
1081
  search_time += other.search_time;
1058
-
1059
1082
  }
1060
1083
 
1061
-
1062
1084
  IndexIVFStats indexIVF_stats;
1063
1085
 
1064
- void InvertedListScanner::scan_codes_range (size_t ,
1065
- const uint8_t *,
1066
- const idx_t *,
1067
- float ,
1068
- RangeQueryResult &) const
1069
- {
1070
- FAISS_THROW_MSG ("scan_codes_range not implemented");
1086
+ void InvertedListScanner::scan_codes_range(
1087
+ size_t,
1088
+ const uint8_t*,
1089
+ const idx_t*,
1090
+ float,
1091
+ RangeQueryResult&) const {
1092
+ FAISS_THROW_MSG("scan_codes_range not implemented");
1071
1093
  }
1072
1094
 
1073
-
1074
-
1075
1095
  } // namespace faiss