faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -10,7 +10,6 @@
10
10
  #ifndef FAISS_INDEX_IVFPQ_H
11
11
  #define FAISS_INDEX_IVFPQ_H
12
12
 
13
-
14
13
  #include <vector>
15
14
 
16
15
  #include <faiss/IndexIVF.h>
@@ -20,32 +19,29 @@
20
19
 
21
20
  namespace faiss {
22
21
 
23
- struct IVFPQSearchParameters: IVFSearchParameters {
24
- size_t scan_table_threshold; ///< use table computation or on-the-fly?
25
- int polysemous_ht; ///< Hamming thresh for polysemous filtering
26
- IVFPQSearchParameters (): scan_table_threshold(0), polysemous_ht(0) {}
27
- ~IVFPQSearchParameters () {}
22
+ struct IVFPQSearchParameters : IVFSearchParameters {
23
+ size_t scan_table_threshold; ///< use table computation or on-the-fly?
24
+ int polysemous_ht; ///< Hamming thresh for polysemous filtering
25
+ IVFPQSearchParameters() : scan_table_threshold(0), polysemous_ht(0) {}
26
+ ~IVFPQSearchParameters() {}
28
27
  };
29
28
 
30
-
31
-
32
29
  FAISS_API extern size_t precomputed_table_max_bytes;
33
30
 
34
-
35
31
  /** Inverted file with Product Quantizer encoding. Each residual
36
32
  * vector is encoded as a product quantizer code.
37
33
  */
38
- struct IndexIVFPQ: IndexIVF {
39
- bool by_residual; ///< Encode residual or plain vector?
34
+ struct IndexIVFPQ : IndexIVF {
35
+ bool by_residual; ///< Encode residual or plain vector?
40
36
 
41
- ProductQuantizer pq; ///< produces the codes
37
+ ProductQuantizer pq; ///< produces the codes
42
38
 
43
- bool do_polysemous_training; ///< reorder PQ centroids after training?
44
- PolysemousTraining *polysemous_training; ///< if NULL, use default
39
+ bool do_polysemous_training; ///< reorder PQ centroids after training?
40
+ PolysemousTraining* polysemous_training; ///< if NULL, use default
45
41
 
46
42
  // search-time parameters
47
- size_t scan_table_threshold; ///< use table computation or on-the-fly?
48
- int polysemous_ht; ///< Hamming thresh for polysemous filtering
43
+ size_t scan_table_threshold; ///< use table computation or on-the-fly?
44
+ int polysemous_ht; ///< Hamming thresh for polysemous filtering
49
45
 
50
46
  /** Precompute table that speed up query preprocessing at some
51
47
  * memory cost (used only for by_residual with L2 metric)
@@ -56,37 +52,47 @@ struct IndexIVFPQ: IndexIVF {
56
52
  /// size nlist * pq.M * pq.ksub
57
53
  AlignedTable<float> precomputed_table;
58
54
 
59
- IndexIVFPQ (
60
- Index * quantizer, size_t d, size_t nlist,
61
- size_t M, size_t nbits_per_idx, MetricType metric = METRIC_L2);
62
-
63
- void add_with_ids(idx_t n, const float* x, const idx_t* xids = nullptr)
64
- override;
65
-
66
- void encode_vectors(idx_t n, const float* x,
67
- const idx_t *list_nos,
68
- uint8_t * codes,
69
- bool include_listnos = false) const override;
70
-
71
- void sa_decode (idx_t n, const uint8_t *bytes,
72
- float *x) const override;
73
-
55
+ IndexIVFPQ(
56
+ Index* quantizer,
57
+ size_t d,
58
+ size_t nlist,
59
+ size_t M,
60
+ size_t nbits_per_idx,
61
+ MetricType metric = METRIC_L2);
62
+
63
+ void encode_vectors(
64
+ idx_t n,
65
+ const float* x,
66
+ const idx_t* list_nos,
67
+ uint8_t* codes,
68
+ bool include_listnos = false) const override;
69
+
70
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
71
+
72
+ void add_core(
73
+ idx_t n,
74
+ const float* x,
75
+ const idx_t* xids,
76
+ const idx_t* precomputed_idx) override;
74
77
 
75
78
  /// same as add_core, also:
76
79
  /// - output 2nd level residuals if residuals_2 != NULL
77
- /// - use precomputed list numbers if precomputed_idx != NULL
78
- void add_core_o (idx_t n, const float *x,
79
- const idx_t *xids, float *residuals_2,
80
- const idx_t *precomputed_idx = nullptr);
80
+ /// - accepts precomputed_idx = nullptr
81
+ void add_core_o(
82
+ idx_t n,
83
+ const float* x,
84
+ const idx_t* xids,
85
+ float* residuals_2,
86
+ const idx_t* precomputed_idx = nullptr);
81
87
 
82
88
  /// trains the product quantizer
83
89
  void train_residual(idx_t n, const float* x) override;
84
90
 
85
91
  /// same as train_residual, also output 2nd level residuals
86
- void train_residual_o (idx_t n, const float *x, float *residuals_2);
92
+ void train_residual_o(idx_t n, const float* x, float* residuals_2);
87
93
 
88
- void reconstruct_from_offset (int64_t list_no, int64_t offset,
89
- float* recons) const override;
94
+ void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
95
+ const override;
90
96
 
91
97
  /** Find exact duplicates in the dataset.
92
98
  *
@@ -99,10 +105,10 @@ struct IndexIVFPQ: IndexIVF {
99
105
  * duplicates (max size ntotal)
100
106
  * @return n number of groups found
101
107
  */
102
- size_t find_duplicates (idx_t *ids, size_t *lims) const;
108
+ size_t find_duplicates(idx_t* ids, size_t* lims) const;
103
109
 
104
110
  // map a vector to a binary code knowning the index
105
- void encode (idx_t key, const float * x, uint8_t * code) const;
111
+ void encode(idx_t key, const float* x, uint8_t* code) const;
106
112
 
107
113
  /** Encode multiple vectors
108
114
  *
@@ -113,22 +119,27 @@ struct IndexIVFPQ: IndexIVF {
113
119
  * @param compute_keys if false, assume keys are precomputed,
114
120
  * otherwise compute them
115
121
  */
116
- void encode_multiple (size_t n, idx_t *keys,
117
- const float * x, uint8_t * codes,
118
- bool compute_keys = false) const;
122
+ void encode_multiple(
123
+ size_t n,
124
+ idx_t* keys,
125
+ const float* x,
126
+ uint8_t* codes,
127
+ bool compute_keys = false) const;
119
128
 
120
129
  /// inverse of encode_multiple
121
- void decode_multiple (size_t n, const idx_t *keys,
122
- const uint8_t * xcodes, float * x) const;
130
+ void decode_multiple(
131
+ size_t n,
132
+ const idx_t* keys,
133
+ const uint8_t* xcodes,
134
+ float* x) const;
123
135
 
124
- InvertedListScanner *get_InvertedListScanner (bool store_pairs)
125
- const override;
136
+ InvertedListScanner* get_InvertedListScanner(
137
+ bool store_pairs) const override;
126
138
 
127
139
  /// build precomputed table
128
- void precompute_table ();
129
-
130
- IndexIVFPQ ();
140
+ void precompute_table();
131
141
 
142
+ IndexIVFPQ();
132
143
  };
133
144
 
134
145
  /** Pre-compute distance tables for IVFPQ with by-residual and METRIC_L2
@@ -136,24 +147,23 @@ struct IndexIVFPQ: IndexIVF {
136
147
  * @param use_precomputed_table (I/O)
137
148
  * =-1: force disable
138
149
  * =0: decide heuristically (default: use tables only if they are
139
- * < precomputed_tables_max_bytes), set use_precomputed_table on output
140
- * =1: tables that work for all quantizers (size 256 * nlist * M)
141
- * =2: specific version for MultiIndexQuantizer (much more compact)
150
+ * < precomputed_tables_max_bytes), set use_precomputed_table on
151
+ * output =1: tables that work for all quantizers (size 256 * nlist * M) =2:
152
+ * specific version for MultiIndexQuantizer (much more compact)
142
153
  * @param precomputed_table precomputed table to intialize
143
154
  */
144
155
 
145
156
  void initialize_IVFPQ_precomputed_table(
146
- int &use_precomputed_table,
147
- const Index *quantizer,
148
- const ProductQuantizer &pq,
149
- AlignedTable<float> & precomputed_table,
150
- bool verbose
151
- );
157
+ int& use_precomputed_table,
158
+ const Index* quantizer,
159
+ const ProductQuantizer& pq,
160
+ AlignedTable<float>& precomputed_table,
161
+ bool verbose);
152
162
 
153
163
  /// statistics are robust to internal threading, but not if
154
164
  /// IndexIVFPQ::search_preassigned is called by multiple threads
155
165
  struct IndexIVFPQStats {
156
- size_t nrefine; ///< nb of refines (IVFPQR)
166
+ size_t nrefine; ///< nb of refines (IVFPQR)
157
167
 
158
168
  size_t n_hamming_pass;
159
169
  ///< nb of passed Hamming distance tests (for polysemous)
@@ -162,17 +172,15 @@ struct IndexIVFPQStats {
162
172
  size_t search_cycles;
163
173
  size_t refine_cycles; ///< only for IVFPQR
164
174
 
165
- IndexIVFPQStats () {reset (); }
166
- void reset ();
175
+ IndexIVFPQStats() {
176
+ reset();
177
+ }
178
+ void reset();
167
179
  };
168
180
 
169
181
  // global var that collects them all
170
182
  FAISS_API extern IndexIVFPQStats indexIVFPQ_stats;
171
183
 
172
-
173
-
174
-
175
184
  } // namespace faiss
176
185
 
177
-
178
186
  #endif
@@ -8,70 +8,68 @@
8
8
  #include <faiss/IndexIVFPQFastScan.h>
9
9
 
10
10
  #include <cassert>
11
+ #include <cinttypes>
11
12
  #include <cstdio>
12
- #include <inttypes.h>
13
13
 
14
14
  #include <omp.h>
15
15
 
16
16
  #include <memory>
17
17
 
18
+ #include <faiss/impl/AuxIndexStructures.h>
18
19
  #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/utils/utils.h>
20
20
  #include <faiss/utils/distances.h>
21
21
  #include <faiss/utils/simdlib.h>
22
- #include <faiss/impl/AuxIndexStructures.h>
22
+ #include <faiss/utils/utils.h>
23
23
 
24
24
  #include <faiss/invlists/BlockInvertedLists.h>
25
25
 
26
+ #include <faiss/impl/pq4_fast_scan.h>
26
27
  #include <faiss/impl/simd_result_handlers.h>
27
28
  #include <faiss/utils/quantize_lut.h>
28
- #include <faiss/impl/pq4_fast_scan.h>
29
29
 
30
30
  namespace faiss {
31
31
 
32
32
  using namespace simd_result_handlers;
33
33
 
34
-
35
34
  inline size_t roundup(size_t a, size_t b) {
36
35
  return (a + b - 1) / b * b;
37
36
  }
38
37
 
39
-
40
- IndexIVFPQFastScan::IndexIVFPQFastScan (
41
- Index * quantizer, size_t d, size_t nlist,
42
- size_t M, size_t nbits_per_idx,
43
- MetricType metric, int bbs):
44
- IndexIVF (quantizer, d, nlist, 0, metric),
45
- pq (d, M, nbits_per_idx),
46
- bbs (bbs)
47
- {
38
+ IndexIVFPQFastScan::IndexIVFPQFastScan(
39
+ Index* quantizer,
40
+ size_t d,
41
+ size_t nlist,
42
+ size_t M,
43
+ size_t nbits_per_idx,
44
+ MetricType metric,
45
+ int bbs)
46
+ : IndexIVF(quantizer, d, nlist, 0, metric),
47
+ pq(d, M, nbits_per_idx),
48
+ bbs(bbs) {
48
49
  FAISS_THROW_IF_NOT(nbits_per_idx == 4);
49
50
  M2 = roundup(pq.M, 2);
50
51
  by_residual = false; // set to false by default because it's much faster
51
52
  is_trained = false;
52
53
  code_size = pq.code_size;
53
54
 
54
- replace_invlists(
55
- new BlockInvertedLists(nlist, bbs, bbs * M2 / 2),
56
- true
57
- );
55
+ replace_invlists(new BlockInvertedLists(nlist, bbs, bbs * M2 / 2), true);
58
56
  }
59
57
 
60
- IndexIVFPQFastScan::IndexIVFPQFastScan ()
61
- {
58
+ IndexIVFPQFastScan::IndexIVFPQFastScan() {
62
59
  by_residual = false;
63
60
  bbs = 0;
64
61
  M2 = 0;
65
62
  }
66
63
 
67
-
68
- IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ & orig, int bbs):
69
- IndexIVF(
70
- orig.quantizer, orig.d, orig.nlist,
71
- orig.pq.code_size, orig.metric_type),
72
- pq(orig.pq),
73
- bbs(bbs)
74
- {
64
+ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
65
+ : IndexIVF(
66
+ orig.quantizer,
67
+ orig.d,
68
+ orig.nlist,
69
+ orig.pq.code_size,
70
+ orig.metric_type),
71
+ pq(orig.pq),
72
+ bbs(bbs) {
75
73
  FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
76
74
 
77
75
  by_residual = orig.by_residual;
@@ -83,69 +81,68 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ & orig, int bbs):
83
81
  M2 = roundup(M, 2);
84
82
 
85
83
  replace_invlists(
86
- new BlockInvertedLists(orig.nlist, bbs, bbs * M2 / 2),
87
- true
88
- );
84
+ new BlockInvertedLists(orig.nlist, bbs, bbs * M2 / 2), true);
89
85
 
90
86
  precomputed_table.resize(orig.precomputed_table.size());
91
87
 
92
88
  if (precomputed_table.nbytes() > 0) {
93
- memcpy(precomputed_table.get(), orig.precomputed_table.data(),
94
- precomputed_table.nbytes()
95
- );
89
+ memcpy(precomputed_table.get(),
90
+ orig.precomputed_table.data(),
91
+ precomputed_table.nbytes());
96
92
  }
97
93
 
98
- for(size_t i = 0; i < nlist; i++) {
94
+ for (size_t i = 0; i < nlist; i++) {
99
95
  size_t nb = orig.invlists->list_size(i);
100
96
  size_t nb2 = roundup(nb, bbs);
101
97
  AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
102
98
  pq4_pack_codes(
103
- InvertedLists::ScopedCodes(orig.invlists, i).get(),
104
- nb, M, nb2, bbs, M2,
105
- tmp.get()
106
- );
99
+ InvertedLists::ScopedCodes(orig.invlists, i).get(),
100
+ nb,
101
+ M,
102
+ nb2,
103
+ bbs,
104
+ M2,
105
+ tmp.get());
107
106
  invlists->add_entries(
108
- i, nb,
109
- InvertedLists::ScopedIds(orig.invlists, i).get(),
110
- tmp.get()
111
- );
107
+ i,
108
+ nb,
109
+ InvertedLists::ScopedIds(orig.invlists, i).get(),
110
+ tmp.get());
112
111
  }
113
112
 
114
113
  orig_invlists = orig.invlists;
115
114
  }
116
115
 
117
-
118
-
119
116
  /*********************************************************
120
117
  * Training
121
118
  *********************************************************/
122
119
 
123
- void IndexIVFPQFastScan::train_residual (idx_t n, const float *x_in)
124
- {
120
+ void IndexIVFPQFastScan::train_residual(idx_t n, const float* x_in) {
121
+ const float* x = fvecs_maybe_subsample(
122
+ d,
123
+ (size_t*)&n,
124
+ pq.cp.max_points_per_centroid * pq.ksub,
125
+ x_in,
126
+ verbose,
127
+ pq.cp.seed);
125
128
 
126
- const float * x = fvecs_maybe_subsample (
127
- d, (size_t*)&n, pq.cp.max_points_per_centroid * pq.ksub,
128
- x_in, verbose, pq.cp.seed);
129
-
130
- std::unique_ptr<float []> del_x;
129
+ std::unique_ptr<float[]> del_x;
131
130
  if (x != x_in) {
132
131
  del_x.reset((float*)x);
133
132
  }
134
133
 
135
- const float *trainset;
134
+ const float* trainset;
136
135
  AlignedTable<float> residuals;
137
136
 
138
137
  if (by_residual) {
139
- if(verbose) printf("computing residuals\n");
138
+ if (verbose)
139
+ printf("computing residuals\n");
140
140
  std::vector<idx_t> assign(n);
141
- quantizer->assign (n, x, assign.data());
141
+ quantizer->assign(n, x, assign.data());
142
142
  residuals.resize(n * d);
143
143
  for (idx_t i = 0; i < n; i++) {
144
- quantizer->compute_residual (
145
- x + i * d,
146
- residuals.data() + i * d,
147
- assign[i]
148
- );
144
+ quantizer->compute_residual(
145
+ x + i * d, residuals.data() + i * d, assign[i]);
149
146
  }
150
147
  trainset = residuals.data();
151
148
  } else {
@@ -153,82 +150,78 @@ void IndexIVFPQFastScan::train_residual (idx_t n, const float *x_in)
153
150
  }
154
151
 
155
152
  if (verbose) {
156
- printf ("training %zdx%zd product quantizer on %zd vectors in %dD\n",
157
- pq.M, pq.ksub, long(n), d);
153
+ printf("training %zdx%zd product quantizer on "
154
+ "%" PRId64 " vectors in %dD\n",
155
+ pq.M,
156
+ pq.ksub,
157
+ n,
158
+ d);
158
159
  }
159
160
  pq.verbose = verbose;
160
- pq.train (n, trainset);
161
+ pq.train(n, trainset);
161
162
 
162
163
  if (by_residual && metric_type == METRIC_L2) {
163
164
  precompute_table();
164
165
  }
165
-
166
166
  }
167
167
 
168
- void IndexIVFPQFastScan::precompute_table ()
169
- {
168
+ void IndexIVFPQFastScan::precompute_table() {
170
169
  initialize_IVFPQ_precomputed_table(
171
- use_precomputed_table,
172
- quantizer, pq, precomputed_table, verbose
173
- );
170
+ use_precomputed_table, quantizer, pq, precomputed_table, verbose);
174
171
  }
175
172
 
176
-
177
173
  /*********************************************************
178
174
  * Code management functions
179
175
  *********************************************************/
180
176
 
181
-
182
-
183
177
  void IndexIVFPQFastScan::encode_vectors(
184
- idx_t n, const float* x, const idx_t *list_nos,
185
- uint8_t * codes, bool include_listnos) const
186
- {
187
-
178
+ idx_t n,
179
+ const float* x,
180
+ const idx_t* list_nos,
181
+ uint8_t* codes,
182
+ bool include_listnos) const {
188
183
  if (by_residual) {
189
- AlignedTable<float> residuals (n * d);
184
+ AlignedTable<float> residuals(n * d);
190
185
  for (size_t i = 0; i < n; i++) {
191
186
  if (list_nos[i] < 0) {
192
- memset (residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
187
+ memset(residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
193
188
  } else {
194
- quantizer->compute_residual (
195
- x + i * d, residuals.data() + i * d, list_nos[i]);
189
+ quantizer->compute_residual(
190
+ x + i * d, residuals.data() + i * d, list_nos[i]);
196
191
  }
197
192
  }
198
- pq.compute_codes (residuals.data(), codes, n);
193
+ pq.compute_codes(residuals.data(), codes, n);
199
194
  } else {
200
- pq.compute_codes (x, codes, n);
195
+ pq.compute_codes(x, codes, n);
201
196
  }
202
197
 
203
198
  if (include_listnos) {
204
199
  size_t coarse_size = coarse_code_size();
205
200
  for (idx_t i = n - 1; i >= 0; i--) {
206
- uint8_t * code = codes + i * (coarse_size + code_size);
207
- memmove (code + coarse_size,
208
- codes + i * code_size, code_size);
209
- encode_listno (list_nos[i], code);
201
+ uint8_t* code = codes + i * (coarse_size + code_size);
202
+ memmove(code + coarse_size, codes + i * code_size, code_size);
203
+ encode_listno(list_nos[i], code);
210
204
  }
211
205
  }
212
206
  }
213
207
 
214
-
215
-
216
- void IndexIVFPQFastScan::add_with_ids (
217
- idx_t n, const float * x, const idx_t *xids) {
218
-
208
+ void IndexIVFPQFastScan::add_with_ids(
209
+ idx_t n,
210
+ const float* x,
211
+ const idx_t* xids) {
219
212
  // copied from IndexIVF::add_with_ids --->
220
213
 
221
214
  // do some blocking to avoid excessive allocs
222
215
  idx_t bs = 65536;
223
216
  if (n > bs) {
224
217
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
225
- idx_t i1 = std::min (n, i0 + bs);
218
+ idx_t i1 = std::min(n, i0 + bs);
226
219
  if (verbose) {
227
220
  printf(" IndexIVFPQFastScan::add_with_ids %zd: %zd",
228
- size_t(i0), size_t(i1));
221
+ size_t(i0),
222
+ size_t(i1));
229
223
  }
230
- add_with_ids (i1 - i0, x + i0 * d,
231
- xids ? xids + i0 : nullptr);
224
+ add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
232
225
  }
233
226
  return;
234
227
  }
@@ -236,37 +229,38 @@ void IndexIVFPQFastScan::add_with_ids (
236
229
 
237
230
  AlignedTable<uint8_t> codes(n * code_size);
238
231
 
239
- FAISS_THROW_IF_NOT (is_trained);
240
- direct_map.check_can_add (xids);
232
+ FAISS_THROW_IF_NOT(is_trained);
233
+ direct_map.check_can_add(xids);
241
234
 
242
- std::unique_ptr<idx_t []> idx(new idx_t[n]);
243
- quantizer->assign (n, x, idx.get());
235
+ std::unique_ptr<idx_t[]> idx(new idx_t[n]);
236
+ quantizer->assign(n, x, idx.get());
244
237
  size_t nadd = 0, nminus1 = 0;
245
238
 
246
239
  for (size_t i = 0; i < n; i++) {
247
- if (idx[i] < 0) nminus1++;
240
+ if (idx[i] < 0)
241
+ nminus1++;
248
242
  }
249
243
 
250
244
  AlignedTable<uint8_t> flat_codes(n * code_size);
251
- encode_vectors (n, x, idx.get(), flat_codes.get());
245
+ encode_vectors(n, x, idx.get(), flat_codes.get());
252
246
 
253
247
  DirectMapAdd dm_adder(direct_map, n, xids);
254
248
 
255
249
  // <---
256
250
 
257
- BlockInvertedLists *bil = dynamic_cast<BlockInvertedLists*>(invlists);
258
- FAISS_THROW_IF_NOT_MSG (bil, "only block inverted lists supported");
251
+ BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
252
+ FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
259
253
 
260
254
  // prepare batches
261
255
  std::vector<idx_t> order(n);
262
- for(idx_t i = 0; i < n ; i++) { order[i] = i; }
256
+ for (idx_t i = 0; i < n; i++) {
257
+ order[i] = i;
258
+ }
263
259
 
264
260
  // TODO should not need stable
265
- std::stable_sort(order.begin(), order.end(),
266
- [&idx](idx_t a, idx_t b) {
267
- return idx[a] < idx[b];
268
- }
269
- );
261
+ std::stable_sort(order.begin(), order.end(), [&idx](idx_t a, idx_t b) {
262
+ return idx[a] < idx[b];
263
+ });
270
264
 
271
265
  // TODO parallelize
272
266
  idx_t i0 = 0;
@@ -274,7 +268,7 @@ void IndexIVFPQFastScan::add_with_ids (
274
268
  idx_t list_no = idx[order[i0]];
275
269
  idx_t i1 = i0 + 1;
276
270
  while (i1 < n && idx[order[i1]] == list_no) {
277
- i1 ++;
271
+ i1++;
278
272
  }
279
273
 
280
274
  if (list_no == -1) {
@@ -288,58 +282,57 @@ void IndexIVFPQFastScan::add_with_ids (
288
282
 
289
283
  bil->resize(list_no, list_size + i1 - i0);
290
284
 
291
- for(idx_t i = i0; i < i1; i++) {
285
+ for (idx_t i = i0; i < i1; i++) {
292
286
  size_t ofs = list_size + i - i0;
293
287
  idx_t id = xids ? xids[order[i]] : ntotal + order[i];
294
- dm_adder.add (order[i], list_no, ofs);
288
+ dm_adder.add(order[i], list_no, ofs);
295
289
  bil->ids[list_no][ofs] = id;
296
- memcpy(
297
- list_codes.data() + (i - i0) * code_size,
298
- flat_codes.data() + order[i] * code_size,
299
- code_size
300
- );
290
+ memcpy(list_codes.data() + (i - i0) * code_size,
291
+ flat_codes.data() + order[i] * code_size,
292
+ code_size);
301
293
  nadd++;
302
294
  }
303
295
  pq4_pack_codes_range(
304
- list_codes.data(), pq.M,
305
- list_size, list_size + i1 - i0,
306
- bbs, M2, bil->codes[list_no].data()
307
- );
296
+ list_codes.data(),
297
+ pq.M,
298
+ list_size,
299
+ list_size + i1 - i0,
300
+ bbs,
301
+ M2,
302
+ bil->codes[list_no].data());
308
303
 
309
304
  i0 = i1;
310
305
  }
311
306
 
312
307
  ntotal += n;
313
-
314
308
  }
315
309
 
316
-
317
-
318
310
  /*********************************************************
319
311
  * search
320
312
  *********************************************************/
321
313
 
322
-
323
314
  namespace {
324
315
 
325
316
  // from impl/ProductQuantizer.cpp
326
317
  template <class C, typename dis_t>
327
318
  void pq_estimators_from_tables_generic(
328
- const ProductQuantizer& pq, size_t nbits,
329
- const uint8_t *codes, size_t ncodes,
330
- const dis_t *dis_table, const int64_t * ids,
319
+ const ProductQuantizer& pq,
320
+ size_t nbits,
321
+ const uint8_t* codes,
322
+ size_t ncodes,
323
+ const dis_t* dis_table,
324
+ const int64_t* ids,
331
325
  float dis0,
332
- size_t k, typename C::T *heap_dis, int64_t *heap_ids)
333
- {
326
+ size_t k,
327
+ typename C::T* heap_dis,
328
+ int64_t* heap_ids) {
334
329
  using accu_t = typename C::T;
335
330
  const size_t M = pq.M;
336
331
  const size_t ksub = pq.ksub;
337
332
  for (size_t j = 0; j < ncodes; ++j) {
338
- PQDecoderGeneric decoder(
339
- codes + j * pq.code_size, nbits
340
- );
333
+ PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
341
334
  accu_t dis = dis0;
342
- const dis_t * dt = dis_table;
335
+ const dis_t* dt = dis_table;
343
336
  for (size_t m = 0; m < M; m++) {
344
337
  uint64_t c = decoder.decode();
345
338
  dis += dt[c];
@@ -356,17 +349,19 @@ void pq_estimators_from_tables_generic(
356
349
  using idx_t = Index::idx_t;
357
350
  using namespace quantize_lut;
358
351
 
359
- void fvec_madd_avx (
360
- size_t n, const float *a,
361
- float bf, const float *b, float *c)
362
- {
352
+ void fvec_madd_avx(
353
+ size_t n,
354
+ const float* a,
355
+ float bf,
356
+ const float* b,
357
+ float* c) {
363
358
  assert(is_aligned_pointer(a));
364
359
  assert(is_aligned_pointer(b));
365
360
  assert(is_aligned_pointer(c));
366
361
  assert(n % 8 == 0);
367
362
  simd8float32 bf8(bf);
368
363
  n /= 8;
369
- for(size_t i = 0; i < n; i++) {
364
+ for (size_t i = 0; i < n; i++) {
370
365
  simd8float32 ai(a);
371
366
  simd8float32 bi(b);
372
367
 
@@ -376,7 +371,6 @@ void fvec_madd_avx (
376
371
  a += 8;
377
372
  b += 8;
378
373
  }
379
-
380
374
  }
381
375
 
382
376
  } // anonymous namespace
@@ -385,23 +379,20 @@ void fvec_madd_avx (
385
379
  * Look-Up Table functions
386
380
  *********************************************************/
387
381
 
388
-
389
382
  void IndexIVFPQFastScan::compute_LUT(
390
- size_t n, const float *x,
391
- const idx_t *coarse_ids, const float *coarse_dis,
392
- AlignedTable<float> & dis_tables,
393
- AlignedTable<float> & biases
394
- ) const
395
- {
396
- const IndexIVFPQFastScan & ivfpq = *this;
383
+ size_t n,
384
+ const float* x,
385
+ const idx_t* coarse_ids,
386
+ const float* coarse_dis,
387
+ AlignedTable<float>& dis_tables,
388
+ AlignedTable<float>& biases) const {
389
+ const IndexIVFPQFastScan& ivfpq = *this;
397
390
  size_t dim12 = pq.ksub * pq.M;
398
391
  size_t d = pq.d;
399
392
  size_t nprobe = ivfpq.nprobe;
400
393
 
401
394
  if (ivfpq.by_residual) {
402
-
403
395
  if (ivfpq.metric_type == METRIC_L2) {
404
-
405
396
  dis_tables.resize(n * nprobe * dim12);
406
397
 
407
398
  if (ivfpq.use_precomputed_table == 1) {
@@ -409,57 +400,54 @@ void IndexIVFPQFastScan::compute_LUT(
409
400
  memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
410
401
 
411
402
  AlignedTable<float> ip_table(n * dim12);
412
- pq.compute_inner_prod_tables (n, x, ip_table.get());
403
+ pq.compute_inner_prod_tables(n, x, ip_table.get());
413
404
 
414
405
  #pragma omp parallel for if (n * nprobe > 8000)
415
- for(idx_t ij = 0; ij < n * nprobe; ij++) {
406
+ for (idx_t ij = 0; ij < n * nprobe; ij++) {
416
407
  idx_t i = ij / nprobe;
417
- float *tab = dis_tables.get() + ij * dim12;
408
+ float* tab = dis_tables.get() + ij * dim12;
418
409
  idx_t cij = coarse_ids[ij];
419
410
 
420
411
  if (cij >= 0) {
421
- fvec_madd_avx (
422
- dim12,
423
- precomputed_table.get() + cij * dim12,
424
- -2, ip_table.get() + i * dim12,
425
- tab
426
- );
412
+ fvec_madd_avx(
413
+ dim12,
414
+ precomputed_table.get() + cij * dim12,
415
+ -2,
416
+ ip_table.get() + i * dim12,
417
+ tab);
427
418
  } else {
428
419
  // fill with NaNs so that they are ignored during
429
420
  // LUT quantization
430
- memset (tab, -1, sizeof(float) * dim12);
421
+ memset(tab, -1, sizeof(float) * dim12);
431
422
  }
432
423
  }
433
424
 
434
425
  } else {
435
-
436
426
  std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
437
427
  biases.resize(n * nprobe);
438
428
  memset(biases.get(), 0, sizeof(float) * n * nprobe);
439
429
 
440
430
  #pragma omp parallel for if (n * nprobe > 8000)
441
- for(idx_t ij = 0; ij < n * nprobe; ij++) {
431
+ for (idx_t ij = 0; ij < n * nprobe; ij++) {
442
432
  idx_t i = ij / nprobe;
443
- float *xij = &xrel[ij * d];
433
+ float* xij = &xrel[ij * d];
444
434
  idx_t cij = coarse_ids[ij];
445
435
 
446
436
  if (cij >= 0) {
447
- ivfpq.quantizer->compute_residual(
448
- x + i * d, xij, cij);
437
+ ivfpq.quantizer->compute_residual(x + i * d, xij, cij);
449
438
  } else {
450
439
  // will fill with NaNs
451
440
  memset(xij, -1, sizeof(float) * d);
452
441
  }
453
442
  }
454
443
 
455
- pq.compute_distance_tables (
444
+ pq.compute_distance_tables(
456
445
  n * nprobe, xrel.get(), dis_tables.get());
457
-
458
446
  }
459
447
 
460
448
  } else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
461
449
  dis_tables.resize(n * dim12);
462
- pq.compute_inner_prod_tables (n, x, dis_tables.get());
450
+ pq.compute_inner_prod_tables(n, x, dis_tables.get());
463
451
  // compute_inner_prod_tables(pq, n, x, dis_tables.get());
464
452
 
465
453
  biases.resize(n * nprobe);
@@ -471,33 +459,29 @@ void IndexIVFPQFastScan::compute_LUT(
471
459
  } else {
472
460
  dis_tables.resize(n * dim12);
473
461
  if (ivfpq.metric_type == METRIC_L2) {
474
- pq.compute_distance_tables (n, x, dis_tables.get());
462
+ pq.compute_distance_tables(n, x, dis_tables.get());
475
463
  } else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
476
- pq.compute_inner_prod_tables (n, x, dis_tables.get());
464
+ pq.compute_inner_prod_tables(n, x, dis_tables.get());
477
465
  } else {
478
466
  FAISS_THROW_FMT("metric %d not supported", ivfpq.metric_type);
479
467
  }
480
468
  }
481
-
482
469
  }
483
470
 
484
471
  void IndexIVFPQFastScan::compute_LUT_uint8(
485
- size_t n, const float *x,
486
- const idx_t *coarse_ids, const float *coarse_dis,
487
- AlignedTable<uint8_t> & dis_tables,
488
- AlignedTable<uint16_t> & biases,
489
- float * normalizers
490
- ) const {
491
- const IndexIVFPQFastScan & ivfpq = *this;
472
+ size_t n,
473
+ const float* x,
474
+ const idx_t* coarse_ids,
475
+ const float* coarse_dis,
476
+ AlignedTable<uint8_t>& dis_tables,
477
+ AlignedTable<uint16_t>& biases,
478
+ float* normalizers) const {
479
+ const IndexIVFPQFastScan& ivfpq = *this;
492
480
  AlignedTable<float> dis_tables_float;
493
481
  AlignedTable<float> biases_float;
494
482
 
495
483
  uint64_t t0 = get_cy();
496
- compute_LUT(
497
- n, x,
498
- coarse_ids, coarse_dis,
499
- dis_tables_float, biases_float
500
- );
484
+ compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
501
485
  IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
502
486
 
503
487
  bool lut_is_3d = ivfpq.by_residual && ivfpq.metric_type == METRIC_L2;
@@ -514,45 +498,52 @@ void IndexIVFPQFastScan::compute_LUT_uint8(
514
498
  uint64_t t1 = get_cy();
515
499
 
516
500
  #pragma omp parallel for if (n > 100)
517
- for(int64_t i = 0; i < n; i++) {
518
- const float *t_in = dis_tables_float.get() + i * dim123;
519
- const float *b_in = nullptr;
520
- uint8_t *t_out = dis_tables.get() + i * dim123_2;
521
- uint16_t *b_out = nullptr;
501
+ for (int64_t i = 0; i < n; i++) {
502
+ const float* t_in = dis_tables_float.get() + i * dim123;
503
+ const float* b_in = nullptr;
504
+ uint8_t* t_out = dis_tables.get() + i * dim123_2;
505
+ uint16_t* b_out = nullptr;
522
506
  if (biases_float.get()) {
523
507
  b_in = biases_float.get() + i * nprobe;
524
508
  b_out = biases.get() + i * nprobe;
525
509
  }
526
510
 
527
511
  quantize_LUT_and_bias(
528
- nprobe, pq.M, pq.ksub, lut_is_3d,
529
- t_in, b_in,
530
- t_out, M2, b_out,
531
- normalizers + 2 * i, normalizers + 2 * i + 1
532
- );
512
+ nprobe,
513
+ pq.M,
514
+ pq.ksub,
515
+ lut_is_3d,
516
+ t_in,
517
+ b_in,
518
+ t_out,
519
+ M2,
520
+ b_out,
521
+ normalizers + 2 * i,
522
+ normalizers + 2 * i + 1);
533
523
  }
534
524
  IVFFastScan_stats.t_round += get_cy() - t1;
535
-
536
525
  }
537
526
 
538
-
539
527
  /*********************************************************
540
528
  * Search functions
541
529
  *********************************************************/
542
530
 
543
- template<bool is_max>
531
+ template <bool is_max>
544
532
  void IndexIVFPQFastScan::search_dispatch_implem(
545
- idx_t n,
546
- const float* x,
547
- idx_t k,
548
- float* distances,
549
- idx_t* labels) const
550
- {
551
- using Cfloat = typename std::conditional<is_max,
552
- CMax<float, int64_t>, CMin<float, int64_t> >::type;
553
-
554
- using C = typename std::conditional<is_max,
555
- CMax<uint16_t, int64_t>, CMin<uint16_t, int64_t> >::type;
533
+ idx_t n,
534
+ const float* x,
535
+ idx_t k,
536
+ float* distances,
537
+ idx_t* labels) const {
538
+ using Cfloat = typename std::conditional<
539
+ is_max,
540
+ CMax<float, int64_t>,
541
+ CMin<float, int64_t>>::type;
542
+
543
+ using C = typename std::conditional<
544
+ is_max,
545
+ CMax<uint16_t, int64_t>,
546
+ CMin<uint16_t, int64_t>>::type;
556
547
 
557
548
  if (n == 0) {
558
549
  return;
@@ -568,7 +559,7 @@ void IndexIVFPQFastScan::search_dispatch_implem(
568
559
  impl = 10;
569
560
  }
570
561
  if (k > 20) {
571
- impl ++;
562
+ impl++;
572
563
  }
573
564
  }
574
565
 
@@ -582,11 +573,25 @@ void IndexIVFPQFastScan::search_dispatch_implem(
582
573
 
583
574
  if (n < 2) {
584
575
  if (impl == 12 || impl == 13) {
585
- search_implem_12<C>
586
- (n, x, k, distances, labels, impl, &ndis, &nlist_visited);
576
+ search_implem_12<C>(
577
+ n,
578
+ x,
579
+ k,
580
+ distances,
581
+ labels,
582
+ impl,
583
+ &ndis,
584
+ &nlist_visited);
587
585
  } else {
588
- search_implem_10<C>
589
- (n, x, k, distances, labels, impl, &ndis, &nlist_visited);
586
+ search_implem_10<C>(
587
+ n,
588
+ x,
589
+ k,
590
+ distances,
591
+ labels,
592
+ impl,
593
+ &ndis,
594
+ &nlist_visited);
590
595
  }
591
596
  } else {
592
597
  // explicitly slice over threads
@@ -595,34 +600,47 @@ void IndexIVFPQFastScan::search_dispatch_implem(
595
600
  nslice = n;
596
601
  } else if (by_residual && metric_type == METRIC_L2) {
597
602
  // make sure we don't make too big LUT tables
598
- size_t lut_size_per_query =
599
- pq.M * pq.ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
603
+ size_t lut_size_per_query = pq.M * pq.ksub * nprobe *
604
+ (sizeof(float) + sizeof(uint8_t));
600
605
 
601
606
  size_t max_lut_size = precomputed_table_max_bytes;
602
607
  // how many queries we can handle within mem budget
603
- size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1));
604
- nslice = roundup(std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads());
608
+ size_t nq_ok =
609
+ std::max(max_lut_size / lut_size_per_query, size_t(1));
610
+ nslice =
611
+ roundup(std::max(size_t(n / nq_ok), size_t(1)),
612
+ omp_get_max_threads());
605
613
  } else {
606
614
  // LUTs unlikely to be a limiting factor
607
615
  nslice = omp_get_max_threads();
608
616
  }
609
617
 
610
- #pragma omp parallel for reduction(+: ndis, nlist_visited)
618
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
611
619
  for (int slice = 0; slice < nslice; slice++) {
612
620
  idx_t i0 = n * slice / nslice;
613
621
  idx_t i1 = n * (slice + 1) / nslice;
614
- float *dis_i = distances + i0 * k;
615
- idx_t *lab_i = labels + i0 * k;
622
+ float* dis_i = distances + i0 * k;
623
+ idx_t* lab_i = labels + i0 * k;
616
624
  if (impl == 12 || impl == 13) {
617
625
  search_implem_12<C>(
618
- i1 - i0, x + i0 * d, k, dis_i, lab_i,
619
- impl, &ndis, &nlist_visited
620
- );
626
+ i1 - i0,
627
+ x + i0 * d,
628
+ k,
629
+ dis_i,
630
+ lab_i,
631
+ impl,
632
+ &ndis,
633
+ &nlist_visited);
621
634
  } else {
622
635
  search_implem_10<C>(
623
- i1 - i0, x + i0 * d, k, dis_i, lab_i,
624
- impl, &ndis, &nlist_visited
625
- );
636
+ i1 - i0,
637
+ x + i0 * d,
638
+ k,
639
+ dis_i,
640
+ lab_i,
641
+ impl,
642
+ &ndis,
643
+ &nlist_visited);
626
644
  }
627
645
  }
628
646
  }
@@ -632,14 +650,16 @@ void IndexIVFPQFastScan::search_dispatch_implem(
632
650
  } else {
633
651
  FAISS_THROW_FMT("implem %d does not exist", implem);
634
652
  }
635
-
636
653
  }
637
654
 
638
-
639
655
  void IndexIVFPQFastScan::search(
640
- idx_t n, const float* x, idx_t k,
641
- float* distances, idx_t* labels) const
642
- {
656
+ idx_t n,
657
+ const float* x,
658
+ idx_t k,
659
+ float* distances,
660
+ idx_t* labels) const {
661
+ FAISS_THROW_IF_NOT(k > 0);
662
+
643
663
  if (metric_type == METRIC_L2) {
644
664
  search_dispatch_implem<true>(n, x, k, distances, labels);
645
665
  } else {
@@ -647,133 +667,150 @@ void IndexIVFPQFastScan::search(
647
667
  }
648
668
  }
649
669
 
650
- template<class C>
670
+ template <class C>
651
671
  void IndexIVFPQFastScan::search_implem_1(
652
- idx_t n, const float* x, idx_t k,
653
- float* distances, idx_t* labels) const
654
- {
672
+ idx_t n,
673
+ const float* x,
674
+ idx_t k,
675
+ float* distances,
676
+ idx_t* labels) const {
655
677
  FAISS_THROW_IF_NOT(orig_invlists);
656
678
 
657
679
  std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
658
680
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
659
681
 
660
- quantizer->search (n, x, nprobe, coarse_dis.get(), coarse_ids.get());
682
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
661
683
 
662
684
  size_t dim12 = pq.ksub * pq.M;
663
685
  AlignedTable<float> dis_tables;
664
686
  AlignedTable<float> biases;
665
687
 
666
- compute_LUT (
667
- n, x,
668
- coarse_ids.get(), coarse_dis.get(),
669
- dis_tables, biases
670
- );
688
+ compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
671
689
 
672
690
  bool single_LUT = !(by_residual && metric_type == METRIC_L2);
673
691
 
674
692
  size_t ndis = 0, nlist_visited = 0;
675
693
 
676
- #pragma omp parallel for reduction(+: ndis, nlist_visited)
677
- for(idx_t i = 0; i < n; i++) {
678
- int64_t *heap_ids = labels + i * k;
679
- float *heap_dis = distances + i * k;
680
- heap_heapify<C> (k, heap_dis, heap_ids);
681
- float *LUT = nullptr;
694
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
695
+ for (idx_t i = 0; i < n; i++) {
696
+ int64_t* heap_ids = labels + i * k;
697
+ float* heap_dis = distances + i * k;
698
+ heap_heapify<C>(k, heap_dis, heap_ids);
699
+ float* LUT = nullptr;
682
700
 
683
701
  if (single_LUT) {
684
702
  LUT = dis_tables.get() + i * dim12;
685
703
  }
686
- for(idx_t j = 0; j < nprobe; j++) {
704
+ for (idx_t j = 0; j < nprobe; j++) {
687
705
  if (!single_LUT) {
688
706
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
689
707
  }
690
708
  idx_t list_no = coarse_ids[i * nprobe + j];
691
- if (list_no < 0) continue;
709
+ if (list_no < 0)
710
+ continue;
692
711
  size_t ls = orig_invlists->list_size(list_no);
693
- if (ls == 0) continue;
712
+ if (ls == 0)
713
+ continue;
694
714
  InvertedLists::ScopedCodes codes(orig_invlists, list_no);
695
715
  InvertedLists::ScopedIds ids(orig_invlists, list_no);
696
716
 
697
717
  float bias = biases.get() ? biases[i * nprobe + j] : 0;
698
718
 
699
719
  pq_estimators_from_tables_generic<C>(
700
- pq, pq.nbits, codes.get(), ls,
701
- LUT, ids.get(), bias,
702
- k, heap_dis, heap_ids
703
- );
704
- nlist_visited ++;
705
- ndis ++;
720
+ pq,
721
+ pq.nbits,
722
+ codes.get(),
723
+ ls,
724
+ LUT,
725
+ ids.get(),
726
+ bias,
727
+ k,
728
+ heap_dis,
729
+ heap_ids);
730
+ nlist_visited++;
731
+ ndis++;
706
732
  }
707
- heap_reorder<C> (k, heap_dis, heap_ids);
733
+ heap_reorder<C>(k, heap_dis, heap_ids);
708
734
  }
709
735
  indexIVF_stats.nq += n;
710
736
  indexIVF_stats.ndis += ndis;
711
737
  indexIVF_stats.nlist += nlist_visited;
712
738
  }
713
739
 
714
- template<class C>
740
+ template <class C>
715
741
  void IndexIVFPQFastScan::search_implem_2(
716
- idx_t n, const float* x, idx_t k,
717
- float* distances, idx_t* labels) const
718
- {
742
+ idx_t n,
743
+ const float* x,
744
+ idx_t k,
745
+ float* distances,
746
+ idx_t* labels) const {
719
747
  FAISS_THROW_IF_NOT(orig_invlists);
720
748
 
721
749
  std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
722
750
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
723
751
 
724
- quantizer->search (n, x, nprobe, coarse_dis.get(), coarse_ids.get());
752
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
725
753
 
726
754
  size_t dim12 = pq.ksub * M2;
727
755
  AlignedTable<uint8_t> dis_tables;
728
756
  AlignedTable<uint16_t> biases;
729
757
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
730
758
 
731
- compute_LUT_uint8 (
732
- n, x,
733
- coarse_ids.get(), coarse_dis.get(),
734
- dis_tables, biases,
735
- normalizers.get()
736
- );
737
-
759
+ compute_LUT_uint8(
760
+ n,
761
+ x,
762
+ coarse_ids.get(),
763
+ coarse_dis.get(),
764
+ dis_tables,
765
+ biases,
766
+ normalizers.get());
738
767
 
739
768
  bool single_LUT = !(by_residual && metric_type == METRIC_L2);
740
769
 
741
770
  size_t ndis = 0, nlist_visited = 0;
742
771
 
743
- #pragma omp parallel for reduction(+: ndis, nlist_visited)
744
- for(idx_t i = 0; i < n; i++) {
772
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
773
+ for (idx_t i = 0; i < n; i++) {
745
774
  std::vector<uint16_t> tmp_dis(k);
746
- int64_t *heap_ids = labels + i * k;
747
- uint16_t *heap_dis = tmp_dis.data();
748
- heap_heapify<C> (k, heap_dis, heap_ids);
749
- const uint8_t *LUT = nullptr;
775
+ int64_t* heap_ids = labels + i * k;
776
+ uint16_t* heap_dis = tmp_dis.data();
777
+ heap_heapify<C>(k, heap_dis, heap_ids);
778
+ const uint8_t* LUT = nullptr;
750
779
 
751
780
  if (single_LUT) {
752
781
  LUT = dis_tables.get() + i * dim12;
753
782
  }
754
- for(idx_t j = 0; j < nprobe; j++) {
783
+ for (idx_t j = 0; j < nprobe; j++) {
755
784
  if (!single_LUT) {
756
785
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
757
786
  }
758
787
  idx_t list_no = coarse_ids[i * nprobe + j];
759
- if (list_no < 0) continue;
788
+ if (list_no < 0)
789
+ continue;
760
790
  size_t ls = orig_invlists->list_size(list_no);
761
- if (ls == 0) continue;
791
+ if (ls == 0)
792
+ continue;
762
793
  InvertedLists::ScopedCodes codes(orig_invlists, list_no);
763
794
  InvertedLists::ScopedIds ids(orig_invlists, list_no);
764
795
 
765
796
  uint16_t bias = biases.get() ? biases[i * nprobe + j] : 0;
766
797
 
767
798
  pq_estimators_from_tables_generic<C>(
768
- pq, pq.nbits, codes.get(), ls,
769
- LUT, ids.get(), bias,
770
- k, heap_dis, heap_ids
771
- );
799
+ pq,
800
+ pq.nbits,
801
+ codes.get(),
802
+ ls,
803
+ LUT,
804
+ ids.get(),
805
+ bias,
806
+ k,
807
+ heap_dis,
808
+ heap_ids);
772
809
 
773
810
  nlist_visited++;
774
811
  ndis += ls;
775
812
  }
776
- heap_reorder<C> (k, heap_dis, heap_ids);
813
+ heap_reorder<C>(k, heap_dis, heap_ids);
777
814
  // convert distances to float
778
815
  {
779
816
  float one_a = 1 / normalizers[2 * i], b = normalizers[2 * i + 1];
@@ -781,7 +818,7 @@ void IndexIVFPQFastScan::search_implem_2(
781
818
  one_a = 1;
782
819
  b = 0;
783
820
  }
784
- float *heap_dis_float = distances + i * k;
821
+ float* heap_dis_float = distances + i * k;
785
822
  for (int j = 0; j < k; j++) {
786
823
  heap_dis_float[j] = b + heap_dis[j] * one_a;
787
824
  }
@@ -792,14 +829,16 @@ void IndexIVFPQFastScan::search_implem_2(
792
829
  indexIVF_stats.nlist += nlist_visited;
793
830
  }
794
831
 
795
-
796
-
797
- template<class C>
832
+ template <class C>
798
833
  void IndexIVFPQFastScan::search_implem_10(
799
- idx_t n, const float* x, idx_t k,
800
- float* distances, idx_t* labels,
801
- int impl, size_t *ndis_out, size_t *nlist_out) const
802
- {
834
+ idx_t n,
835
+ const float* x,
836
+ idx_t k,
837
+ float* distances,
838
+ idx_t* labels,
839
+ int impl,
840
+ size_t* ndis_out,
841
+ size_t* nlist_out) const {
803
842
  memset(distances, -1, sizeof(float) * k * n);
804
843
  memset(labels, -1, sizeof(idx_t) * k * n);
805
844
 
@@ -807,7 +846,6 @@ void IndexIVFPQFastScan::search_implem_10(
807
846
  using ReservoirHC = ReservoirHandler<C, true>;
808
847
  using SingleResultHC = SingleResultHandler<C, true>;
809
848
 
810
-
811
849
  std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
812
850
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
813
851
 
@@ -817,20 +855,23 @@ void IndexIVFPQFastScan::search_implem_10(
817
855
  #define TIC times[ti++] = get_cy()
818
856
  TIC;
819
857
 
820
- quantizer->search (n, x, nprobe, coarse_dis.get(), coarse_ids.get());
858
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
821
859
 
822
860
  TIC;
823
861
 
824
862
  size_t dim12 = pq.ksub * M2;
825
863
  AlignedTable<uint8_t> dis_tables;
826
864
  AlignedTable<uint16_t> biases;
827
- std::unique_ptr<float[]> normalizers (new float[2 * n]);
865
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
828
866
 
829
- compute_LUT_uint8 (
830
- n, x,
831
- coarse_ids.get(), coarse_dis.get(),
832
- dis_tables, biases, normalizers.get()
833
- );
867
+ compute_LUT_uint8(
868
+ n,
869
+ x,
870
+ coarse_ids.get(),
871
+ coarse_dis.get(),
872
+ dis_tables,
873
+ biases,
874
+ normalizers.get());
834
875
 
835
876
  TIC;
836
877
 
@@ -841,15 +882,16 @@ void IndexIVFPQFastScan::search_implem_10(
841
882
 
842
883
  {
843
884
  AlignedTable<uint16_t> tmp_distances(k);
844
- for(idx_t i = 0; i < n; i++) {
845
- const uint8_t *LUT = nullptr;
885
+ for (idx_t i = 0; i < n; i++) {
886
+ const uint8_t* LUT = nullptr;
846
887
  int qmap1[1] = {0};
847
- std::unique_ptr<SIMDResultHandler<C, true> > handler;
888
+ std::unique_ptr<SIMDResultHandler<C, true>> handler;
848
889
 
849
890
  if (k == 1) {
850
891
  handler.reset(new SingleResultHC(1, 0));
851
892
  } else if (impl == 10) {
852
- handler.reset(new HeapHC(1, tmp_distances.get(), labels + i * k, k, 0));
893
+ handler.reset(new HeapHC(
894
+ 1, tmp_distances.get(), labels + i * k, k, 0));
853
895
  } else if (impl == 11) {
854
896
  handler.reset(new ReservoirHC(1, 0, k, 2 * k));
855
897
  } else {
@@ -861,7 +903,7 @@ void IndexIVFPQFastScan::search_implem_10(
861
903
  if (single_LUT) {
862
904
  LUT = dis_tables.get() + i * dim12;
863
905
  }
864
- for(idx_t j = 0; j < nprobe; j++) {
906
+ for (idx_t j = 0; j < nprobe; j++) {
865
907
  size_t ij = i * nprobe + j;
866
908
  if (!single_LUT) {
867
909
  LUT = dis_tables.get() + ij * dim12;
@@ -871,9 +913,11 @@ void IndexIVFPQFastScan::search_implem_10(
871
913
  }
872
914
 
873
915
  idx_t list_no = coarse_ids[ij];
874
- if (list_no < 0) continue;
916
+ if (list_no < 0)
917
+ continue;
875
918
  size_t ls = invlists->list_size(list_no);
876
- if (ls == 0) continue;
919
+ if (ls == 0)
920
+ continue;
877
921
 
878
922
  InvertedLists::ScopedCodes codes(invlists, list_no);
879
923
  InvertedLists::ScopedIds ids(invlists, list_no);
@@ -881,41 +925,40 @@ void IndexIVFPQFastScan::search_implem_10(
881
925
  handler->ntotal = ls;
882
926
  handler->id_map = ids.get();
883
927
 
884
- #define DISPATCH(classHC) \
885
- if(auto *res = dynamic_cast<classHC* > (handler.get())) { \
886
- pq4_accumulate_loop( \
887
- 1, roundup(ls, bbs), bbs, M2, \
888
- codes.get(), LUT, \
889
- *res \
890
- ); \
891
- }
928
+ #define DISPATCH(classHC) \
929
+ if (dynamic_cast<classHC*>(handler.get())) { \
930
+ auto* res = static_cast<classHC*>(handler.get()); \
931
+ pq4_accumulate_loop( \
932
+ 1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res); \
933
+ }
892
934
  DISPATCH(HeapHC)
893
- else DISPATCH(ReservoirHC)
894
- else DISPATCH(SingleResultHC)
935
+ else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
895
936
  #undef DISPATCH
896
937
 
897
- nlist_visited ++;
898
- ndis ++;
938
+ nlist_visited++;
939
+ ndis++;
899
940
  }
900
941
 
901
942
  handler->to_flat_arrays(
902
- distances + i * k, labels + i * k,
903
- skip & 16 ? nullptr : normalizers.get() + i * 2
904
- );
943
+ distances + i * k,
944
+ labels + i * k,
945
+ skip & 16 ? nullptr : normalizers.get() + i * 2);
905
946
  }
906
947
  }
907
948
  *ndis_out = ndis;
908
949
  *nlist_out = nlist;
909
950
  }
910
951
 
911
-
912
-
913
- template<class C>
952
+ template <class C>
914
953
  void IndexIVFPQFastScan::search_implem_12(
915
- idx_t n, const float* x, idx_t k,
916
- float* distances, idx_t* labels,
917
- int impl, size_t *ndis_out, size_t *nlist_out) const
918
- {
954
+ idx_t n,
955
+ const float* x,
956
+ idx_t k,
957
+ float* distances,
958
+ idx_t* labels,
959
+ int impl,
960
+ size_t* ndis_out,
961
+ size_t* nlist_out) const {
919
962
  if (n == 0) { // does not work well with reservoir
920
963
  return;
921
964
  }
@@ -930,53 +973,53 @@ void IndexIVFPQFastScan::search_implem_12(
930
973
  #define TIC times[ti++] = get_cy()
931
974
  TIC;
932
975
 
933
- quantizer->search (n, x, nprobe, coarse_dis.get(), coarse_ids.get());
976
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
934
977
 
935
978
  TIC;
936
979
 
937
980
  size_t dim12 = pq.ksub * M2;
938
981
  AlignedTable<uint8_t> dis_tables;
939
982
  AlignedTable<uint16_t> biases;
940
- std::unique_ptr<float[]> normalizers (new float[2 * n]);
983
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
941
984
 
942
- compute_LUT_uint8 (
943
- n, x,
944
- coarse_ids.get(), coarse_dis.get(),
945
- dis_tables, biases, normalizers.get()
946
- );
985
+ compute_LUT_uint8(
986
+ n,
987
+ x,
988
+ coarse_ids.get(),
989
+ coarse_dis.get(),
990
+ dis_tables,
991
+ biases,
992
+ normalizers.get());
947
993
 
948
994
  TIC;
949
995
 
950
996
  struct QC {
951
- int qno; // sequence number of the query
952
- int list_no; // list to visit
953
- int rank; // this is the rank'th result of the coarse quantizer
997
+ int qno; // sequence number of the query
998
+ int list_no; // list to visit
999
+ int rank; // this is the rank'th result of the coarse quantizer
954
1000
  };
955
1001
  bool single_LUT = !(by_residual && metric_type == METRIC_L2);
956
1002
 
957
1003
  std::vector<QC> qcs;
958
1004
  {
959
1005
  int ij = 0;
960
- for(int i = 0; i < n; i++) {
961
- for(int j = 0; j < nprobe; j++) {
1006
+ for (int i = 0; i < n; i++) {
1007
+ for (int j = 0; j < nprobe; j++) {
962
1008
  if (coarse_ids[ij] >= 0) {
963
1009
  qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
964
1010
  }
965
1011
  ij++;
966
1012
  }
967
1013
  }
968
- std::sort(
969
- qcs.begin(), qcs.end(),
970
- [](const QC &a, const QC & b) {
971
- return a.list_no < b.list_no;
972
- }
973
- );
1014
+ std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
1015
+ return a.list_no < b.list_no;
1016
+ });
974
1017
  }
975
1018
  TIC;
976
1019
 
977
1020
  // prepare the result handlers
978
1021
 
979
- std::unique_ptr<SIMDResultHandler<C, true> > handler;
1022
+ std::unique_ptr<SIMDResultHandler<C, true>> handler;
980
1023
  AlignedTable<uint16_t> tmp_distances;
981
1024
 
982
1025
  using HeapHC = HeapHandler<C, true>;
@@ -1012,7 +1055,7 @@ void IndexIVFPQFastScan::search_implem_12(
1012
1055
  int list_no = qcs[i0].list_no;
1013
1056
  size_t i1 = i0 + 1;
1014
1057
 
1015
- while(i1 < qcs.size() && i1 < i0 + qbs2) {
1058
+ while (i1 < qcs.size() && i1 < i0 + qbs2) {
1016
1059
  if (qcs[i1].list_no != list_no) {
1017
1060
  break;
1018
1061
  }
@@ -1034,8 +1077,8 @@ void IndexIVFPQFastScan::search_implem_12(
1034
1077
  memset(LUT.get(), -1, nc * dim12);
1035
1078
  int qbs = pq4_preferred_qbs(nc);
1036
1079
 
1037
- for(size_t i = i0; i < i1; i++) {
1038
- const QC & qc = qcs[i];
1080
+ for (size_t i = i0; i < i1; i++) {
1081
+ const QC& qc = qcs[i];
1039
1082
  q_map[i - i0] = qc.qno;
1040
1083
  int ij = qc.qno * nprobe + qc.rank;
1041
1084
  lut_entries[i - i0] = single_LUT ? qc.qno : ij;
@@ -1044,9 +1087,7 @@ void IndexIVFPQFastScan::search_implem_12(
1044
1087
  }
1045
1088
  }
1046
1089
  pq4_pack_LUT_qbs_q_map(
1047
- qbs, M2, dis_tables.get(), lut_entries.data(),
1048
- LUT.get()
1049
- );
1090
+ qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
1050
1091
 
1051
1092
  // access the inverted list
1052
1093
 
@@ -1062,20 +1103,17 @@ void IndexIVFPQFastScan::search_implem_12(
1062
1103
  handler->id_map = ids.get();
1063
1104
  uint64_t tt1 = get_cy();
1064
1105
 
1065
- #define DISPATCH(classHC) \
1066
- if(auto *res = dynamic_cast<classHC* > (handler.get())) { \
1067
- pq4_accumulate_loop_qbs( \
1068
- qbs, list_size, M2, \
1069
- codes.get(), LUT.get(), \
1070
- *res \
1071
- ); \
1072
- }
1106
+ #define DISPATCH(classHC) \
1107
+ if (dynamic_cast<classHC*>(handler.get())) { \
1108
+ auto* res = static_cast<classHC*>(handler.get()); \
1109
+ pq4_accumulate_loop_qbs( \
1110
+ qbs, list_size, M2, codes.get(), LUT.get(), *res); \
1111
+ }
1073
1112
  DISPATCH(HeapHC)
1074
- else DISPATCH(ReservoirHC)
1075
- else DISPATCH(SingleResultHC)
1113
+ else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
1076
1114
 
1077
- // prepare for next loop
1078
- i0 = i1;
1115
+ // prepare for next loop
1116
+ i0 = i1;
1079
1117
 
1080
1118
  uint64_t tt2 = get_cy();
1081
1119
  t_copy_pack += tt1 - tt0;
@@ -1085,21 +1123,19 @@ void IndexIVFPQFastScan::search_implem_12(
1085
1123
 
1086
1124
  // labels is in-place for HeapHC
1087
1125
  handler->to_flat_arrays(
1088
- distances, labels,
1089
- skip & 16 ? nullptr : normalizers.get()
1090
- );
1126
+ distances, labels, skip & 16 ? nullptr : normalizers.get());
1091
1127
 
1092
1128
  TIC;
1093
1129
 
1094
1130
  // these stats are not thread-safe
1095
1131
 
1096
- for(int i = 1; i < ti; i++) {
1097
- IVFFastScan_stats.times[i] += times[i] - times[i-1];
1132
+ for (int i = 1; i < ti; i++) {
1133
+ IVFFastScan_stats.times[i] += times[i] - times[i - 1];
1098
1134
  }
1099
1135
  IVFFastScan_stats.t_copy_pack += t_copy_pack;
1100
1136
  IVFFastScan_stats.t_scan += t_scan;
1101
1137
 
1102
- if (auto *rh = dynamic_cast<ReservoirHC*> (handler.get())) {
1138
+ if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
1103
1139
  for (int i = 0; i < 4; i++) {
1104
1140
  IVFFastScan_stats.reservoir_times[i] += rh->times[i];
1105
1141
  }
@@ -1107,10 +1143,8 @@ void IndexIVFPQFastScan::search_implem_12(
1107
1143
 
1108
1144
  *ndis_out = ndis;
1109
1145
  *nlist_out = nlist;
1110
-
1111
1146
  }
1112
1147
 
1113
-
1114
1148
  IVFFastScanStats IVFFastScan_stats;
1115
1149
 
1116
1150
  } // namespace faiss