faiss 0.1.4 → 0.2.1

Files changed (219)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +26 -1
  3. data/README.md +15 -3
  4. data/ext/faiss/ext.cpp +12 -308
  5. data/ext/faiss/extconf.rb +5 -2
  6. data/ext/faiss/index.cpp +189 -0
  7. data/ext/faiss/index_binary.cpp +75 -0
  8. data/ext/faiss/kmeans.cpp +40 -0
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +33 -0
  11. data/ext/faiss/product_quantizer.cpp +53 -0
  12. data/ext/faiss/utils.cpp +13 -0
  13. data/ext/faiss/utils.h +5 -0
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +31 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
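The bulk of this release is a wholesale refresh of the vendored FAISS C++ tree (a clang-format pass over nearly every file, plus new sources such as IndexNSG, IndexNNDescent and the residual/additive quantizers) and a split of the Ruby binding from a single ext.cpp into per-class C++ files. For orientation before the diffs below, here is a hedged usage sketch of the `IndexIVFPQ` class whose header diff follows. Only the constructor signature is taken from the diff; the synthetic data and the `train`/`add`/`search` calls are the standard FAISS `Index` API and are illustrative assumptions, not part of this changeset.

```cpp
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>

#include <random>
#include <vector>

int main() {
    size_t d = 64;      // vector dimension
    size_t nlist = 100; // inverted lists (coarse centroids)
    size_t M = 8;       // PQ sub-quantizers per vector
    size_t nbits = 8;   // bits per PQ sub-code
    size_t nb = 10000;  // database size
    faiss::Index::idx_t nq = 5, k = 4;

    // synthetic training / database vectors
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> u(0.0f, 1.0f);
    std::vector<float> xb(nb * d);
    for (float& v : xb)
        v = u(rng);

    faiss::IndexFlatL2 quantizer(d); // coarse quantizer, owned by caller
    // constructor signature as in the header diff below (METRIC_L2 default)
    faiss::IndexIVFPQ index(&quantizer, d, nlist, M, nbits);

    index.train(nb, xb.data()); // trains coarse centroids + PQ codebooks
    index.add(nb, xb.data());

    index.nprobe = 8; // search-time parameter: lists visited per query
    std::vector<float> distances(nq * k);
    std::vector<faiss::Index::idx_t> labels(nq * k);
    index.search(nq, xb.data(), k, distances.data(), labels.data());
    return 0;
}
```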
data/vendor/faiss/faiss/IndexIVFPQ.h

@@ -10,7 +10,6 @@
 #ifndef FAISS_INDEX_IVFPQ_H
 #define FAISS_INDEX_IVFPQ_H
 
-
 #include <vector>
 
 #include <faiss/IndexIVF.h>
@@ -20,32 +19,29 @@
 
 namespace faiss {
 
-struct IVFPQSearchParameters: IVFSearchParameters {
-    size_t scan_table_threshold;   ///< use table computation or on-the-fly?
-    int polysemous_ht;             ///< Hamming thresh for polysemous filtering
-    IVFPQSearchParameters (): scan_table_threshold(0), polysemous_ht(0) {}
-    ~IVFPQSearchParameters () {}
+struct IVFPQSearchParameters : IVFSearchParameters {
+    size_t scan_table_threshold; ///< use table computation or on-the-fly?
+    int polysemous_ht;           ///< Hamming thresh for polysemous filtering
+    IVFPQSearchParameters() : scan_table_threshold(0), polysemous_ht(0) {}
+    ~IVFPQSearchParameters() {}
 };
 
-
-
 FAISS_API extern size_t precomputed_table_max_bytes;
 
-
 /** Inverted file with Product Quantizer encoding. Each residual
  * vector is encoded as a product quantizer code.
  */
-struct IndexIVFPQ: IndexIVF {
-    bool by_residual;            ///< Encode residual or plain vector?
+struct IndexIVFPQ : IndexIVF {
+    bool by_residual; ///< Encode residual or plain vector?
 
-    ProductQuantizer pq;         ///< produces the codes
+    ProductQuantizer pq; ///< produces the codes
 
-    bool do_polysemous_training; ///< reorder PQ centroids after training?
-    PolysemousTraining *polysemous_training; ///< if NULL, use default
+    bool do_polysemous_training; ///< reorder PQ centroids after training?
+    PolysemousTraining* polysemous_training; ///< if NULL, use default
 
     // search-time parameters
-    size_t scan_table_threshold;   ///< use table computation or on-the-fly?
-    int polysemous_ht;             ///< Hamming thresh for polysemous filtering
+    size_t scan_table_threshold; ///< use table computation or on-the-fly?
+    int polysemous_ht;           ///< Hamming thresh for polysemous filtering
 
     /** Precompute table that speed up query preprocessing at some
      * memory cost (used only for by_residual with L2 metric)
@@ -56,37 +52,47 @@ struct IndexIVFPQ: IndexIVF {
     /// size nlist * pq.M * pq.ksub
     AlignedTable<float> precomputed_table;
 
-    IndexIVFPQ (
-            Index * quantizer, size_t d, size_t nlist,
-            size_t M, size_t nbits_per_idx, MetricType metric = METRIC_L2);
-
-    void add_with_ids(idx_t n, const float* x, const idx_t* xids = nullptr)
-        override;
-
-    void encode_vectors(idx_t n, const float* x,
-                        const idx_t *list_nos,
-                        uint8_t * codes,
-                        bool include_listnos = false) const override;
-
-    void sa_decode (idx_t n, const uint8_t *bytes,
-                    float *x) const override;
-
+    IndexIVFPQ(
+            Index* quantizer,
+            size_t d,
+            size_t nlist,
+            size_t M,
+            size_t nbits_per_idx,
+            MetricType metric = METRIC_L2);
+
+    void encode_vectors(
+            idx_t n,
+            const float* x,
+            const idx_t* list_nos,
+            uint8_t* codes,
+            bool include_listnos = false) const override;
+
+    void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
+
+    void add_core(
+            idx_t n,
+            const float* x,
+            const idx_t* xids,
+            const idx_t* precomputed_idx) override;
 
     /// same as add_core, also:
     /// - output 2nd level residuals if residuals_2 != NULL
-    /// - use precomputed list numbers if precomputed_idx != NULL
-    void add_core_o (idx_t n, const float *x,
-                     const idx_t *xids, float *residuals_2,
-                     const idx_t *precomputed_idx = nullptr);
+    /// - accepts precomputed_idx = nullptr
+    void add_core_o(
+            idx_t n,
+            const float* x,
+            const idx_t* xids,
+            float* residuals_2,
+            const idx_t* precomputed_idx = nullptr);
 
     /// trains the product quantizer
     void train_residual(idx_t n, const float* x) override;
 
     /// same as train_residual, also output 2nd level residuals
-    void train_residual_o (idx_t n, const float *x, float *residuals_2);
+    void train_residual_o(idx_t n, const float* x, float* residuals_2);
 
-    void reconstruct_from_offset (int64_t list_no, int64_t offset,
-                                  float* recons) const override;
+    void reconstruct_from_offset(int64_t list_no, int64_t offset, float* recons)
+            const override;
 
     /** Find exact duplicates in the dataset.
      *
@@ -99,10 +105,10 @@ struct IndexIVFPQ: IndexIVF {
      *                duplicates (max size ntotal)
      * @return n  number of groups found
     */
-    size_t find_duplicates (idx_t *ids, size_t *lims) const;
+    size_t find_duplicates(idx_t* ids, size_t* lims) const;
 
     // map a vector to a binary code knowning the index
-    void encode (idx_t key, const float * x, uint8_t * code) const;
+    void encode(idx_t key, const float* x, uint8_t* code) const;
 
     /** Encode multiple vectors
      *
@@ -113,22 +119,27 @@ struct IndexIVFPQ: IndexIVF {
     * @param compute_keys  if false, assume keys are precomputed,
     *                      otherwise compute them
     */
-    void encode_multiple (size_t n, idx_t *keys,
-                          const float * x, uint8_t * codes,
-                          bool compute_keys = false) const;
+    void encode_multiple(
+            size_t n,
+            idx_t* keys,
+            const float* x,
+            uint8_t* codes,
+            bool compute_keys = false) const;
 
     /// inverse of encode_multiple
-    void decode_multiple (size_t n, const idx_t *keys,
-                          const uint8_t * xcodes, float * x) const;
+    void decode_multiple(
+            size_t n,
+            const idx_t* keys,
+            const uint8_t* xcodes,
+            float* x) const;
 
-    InvertedListScanner *get_InvertedListScanner (bool store_pairs)
-        const override;
+    InvertedListScanner* get_InvertedListScanner(
+            bool store_pairs) const override;
 
     /// build precomputed table
-    void precompute_table ();
-
-    IndexIVFPQ ();
+    void precompute_table();
 
+    IndexIVFPQ();
 };
 
 /** Pre-compute distance tables for IVFPQ with by-residual and METRIC_L2
@@ -136,24 +147,23 @@ struct IndexIVFPQ: IndexIVF {
 * @param use_precomputed_table (I/O)
 *        =-1: force disable
 *        =0: decide heuristically (default: use tables only if they are
-*            < precomputed_tables_max_bytes), set use_precomputed_table on output
-*        =1: tables that work for all quantizers (size 256 * nlist * M)
-*        =2: specific version for MultiIndexQuantizer (much more compact)
+*            < precomputed_tables_max_bytes), set use_precomputed_table on
+* output =1: tables that work for all quantizers (size 256 * nlist * M) =2:
+* specific version for MultiIndexQuantizer (much more compact)
 * @param precomputed_table precomputed table to intialize
 */
 
 void initialize_IVFPQ_precomputed_table(
-    int &use_precomputed_table,
-    const Index *quantizer,
-    const ProductQuantizer &pq,
-    AlignedTable<float> & precomputed_table,
-    bool verbose
-);
+        int& use_precomputed_table,
+        const Index* quantizer,
+        const ProductQuantizer& pq,
+        AlignedTable<float>& precomputed_table,
+        bool verbose);
 
 /// statistics are robust to internal threading, but not if
 /// IndexIVFPQ::search_preassigned is called by multiple threads
 struct IndexIVFPQStats {
-    size_t nrefine;    ///< nb of refines (IVFPQR)
+    size_t nrefine; ///< nb of refines (IVFPQR)
 
     size_t n_hamming_pass;
     ///< nb of passed Hamming distance tests (for polysemous)
@@ -162,17 +172,15 @@ struct IndexIVFPQStats {
     size_t search_cycles;
     size_t refine_cycles; ///< only for IVFPQR
 
-    IndexIVFPQStats () {reset (); }
-    void reset ();
+    IndexIVFPQStats() {
+        reset();
+    }
+    void reset();
 };
 
 // global var that collects them all
 FAISS_API extern IndexIVFPQStats indexIVFPQ_stats;
 
-
-
-
 } // namespace faiss
 
-
 #endif
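Next comes the matching implementation diff for `IndexIVFPQFastScan`, the SIMD fast-scan variant of IVFPQ. Two properties visible in that diff are worth keeping in mind: the constructor enforces 4-bit PQ codes (`FAISS_THROW_IF_NOT(nbits_per_idx == 4)`), and a converting constructor repacks the inverted lists of an existing `IndexIVFPQ` into `bbs`-sized blocks via `pq4_pack_codes`. A hedged sketch of both entry points follows; the default `bbs` of 32 and the surrounding setup are assumptions based on the stock FAISS API, not something this diff establishes.

```cpp
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/IndexIVFPQFastScan.h>

int main() {
    size_t d = 64, nlist = 64, M = 8;
    faiss::IndexFlatL2 quantizer(d); // coarse quantizer, caller-owned

    // Direct construction: nbits_per_idx must be 4, otherwise the
    // constructor throws (see the FAISS_THROW_IF_NOT in the diff below).
    faiss::IndexIVFPQFastScan direct(&quantizer, d, nlist, M, 4);

    // Conversion: repack a 4-bit IndexIVFPQ; each inverted list is padded
    // to a multiple of bbs codes and repacked into block layout.
    faiss::IndexIVFPQ ivfpq(&quantizer, d, nlist, M, 4);
    // ... train ivfpq and add vectors here before converting ...
    faiss::IndexIVFPQFastScan converted(ivfpq, /*bbs=*/32);
    return 0;
}
```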
@@ -8,70 +8,68 @@
8
8
  #include <faiss/IndexIVFPQFastScan.h>
9
9
 
10
10
  #include <cassert>
11
+ #include <cinttypes>
11
12
  #include <cstdio>
12
- #include <inttypes.h>
13
13
 
14
14
  #include <omp.h>
15
15
 
16
16
  #include <memory>
17
17
 
18
+ #include <faiss/impl/AuxIndexStructures.h>
18
19
  #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/utils/utils.h>
20
20
  #include <faiss/utils/distances.h>
21
21
  #include <faiss/utils/simdlib.h>
22
- #include <faiss/impl/AuxIndexStructures.h>
22
+ #include <faiss/utils/utils.h>
23
23
 
24
24
  #include <faiss/invlists/BlockInvertedLists.h>
25
25
 
26
+ #include <faiss/impl/pq4_fast_scan.h>
26
27
  #include <faiss/impl/simd_result_handlers.h>
27
28
  #include <faiss/utils/quantize_lut.h>
28
- #include <faiss/impl/pq4_fast_scan.h>
29
29
 
30
30
  namespace faiss {
31
31
 
32
32
  using namespace simd_result_handlers;
33
33
 
34
-
35
34
  inline size_t roundup(size_t a, size_t b) {
36
35
  return (a + b - 1) / b * b;
37
36
  }
38
37
 
39
-
40
- IndexIVFPQFastScan::IndexIVFPQFastScan (
41
- Index * quantizer, size_t d, size_t nlist,
42
- size_t M, size_t nbits_per_idx,
43
- MetricType metric, int bbs):
44
- IndexIVF (quantizer, d, nlist, 0, metric),
45
- pq (d, M, nbits_per_idx),
46
- bbs (bbs)
47
- {
38
+ IndexIVFPQFastScan::IndexIVFPQFastScan(
39
+ Index* quantizer,
40
+ size_t d,
41
+ size_t nlist,
42
+ size_t M,
43
+ size_t nbits_per_idx,
44
+ MetricType metric,
45
+ int bbs)
46
+ : IndexIVF(quantizer, d, nlist, 0, metric),
47
+ pq(d, M, nbits_per_idx),
48
+ bbs(bbs) {
48
49
  FAISS_THROW_IF_NOT(nbits_per_idx == 4);
49
50
  M2 = roundup(pq.M, 2);
50
51
  by_residual = false; // set to false by default because it's much faster
51
52
  is_trained = false;
52
53
  code_size = pq.code_size;
53
54
 
54
- replace_invlists(
55
- new BlockInvertedLists(nlist, bbs, bbs * M2 / 2),
56
- true
57
- );
55
+ replace_invlists(new BlockInvertedLists(nlist, bbs, bbs * M2 / 2), true);
58
56
  }
59
57
 
60
- IndexIVFPQFastScan::IndexIVFPQFastScan ()
61
- {
58
+ IndexIVFPQFastScan::IndexIVFPQFastScan() {
62
59
  by_residual = false;
63
60
  bbs = 0;
64
61
  M2 = 0;
65
62
  }
66
63
 
67
-
68
- IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ & orig, int bbs):
69
- IndexIVF(
70
- orig.quantizer, orig.d, orig.nlist,
71
- orig.pq.code_size, orig.metric_type),
72
- pq(orig.pq),
73
- bbs(bbs)
74
- {
64
+ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ& orig, int bbs)
65
+ : IndexIVF(
66
+ orig.quantizer,
67
+ orig.d,
68
+ orig.nlist,
69
+ orig.pq.code_size,
70
+ orig.metric_type),
71
+ pq(orig.pq),
72
+ bbs(bbs) {
75
73
  FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
76
74
 
77
75
  by_residual = orig.by_residual;
@@ -83,69 +81,68 @@ IndexIVFPQFastScan::IndexIVFPQFastScan(const IndexIVFPQ & orig, int bbs):
83
81
  M2 = roundup(M, 2);
84
82
 
85
83
  replace_invlists(
86
- new BlockInvertedLists(orig.nlist, bbs, bbs * M2 / 2),
87
- true
88
- );
84
+ new BlockInvertedLists(orig.nlist, bbs, bbs * M2 / 2), true);
89
85
 
90
86
  precomputed_table.resize(orig.precomputed_table.size());
91
87
 
92
88
  if (precomputed_table.nbytes() > 0) {
93
- memcpy(precomputed_table.get(), orig.precomputed_table.data(),
94
- precomputed_table.nbytes()
95
- );
89
+ memcpy(precomputed_table.get(),
90
+ orig.precomputed_table.data(),
91
+ precomputed_table.nbytes());
96
92
  }
97
93
 
98
- for(size_t i = 0; i < nlist; i++) {
94
+ for (size_t i = 0; i < nlist; i++) {
99
95
  size_t nb = orig.invlists->list_size(i);
100
96
  size_t nb2 = roundup(nb, bbs);
101
97
  AlignedTable<uint8_t> tmp(nb2 * M2 / 2);
102
98
  pq4_pack_codes(
103
- InvertedLists::ScopedCodes(orig.invlists, i).get(),
104
- nb, M, nb2, bbs, M2,
105
- tmp.get()
106
- );
99
+ InvertedLists::ScopedCodes(orig.invlists, i).get(),
100
+ nb,
101
+ M,
102
+ nb2,
103
+ bbs,
104
+ M2,
105
+ tmp.get());
107
106
  invlists->add_entries(
108
- i, nb,
109
- InvertedLists::ScopedIds(orig.invlists, i).get(),
110
- tmp.get()
111
- );
107
+ i,
108
+ nb,
109
+ InvertedLists::ScopedIds(orig.invlists, i).get(),
110
+ tmp.get());
112
111
  }
113
112
 
114
113
  orig_invlists = orig.invlists;
115
114
  }
116
115
 
117
-
118
-
119
116
  /*********************************************************
120
117
  * Training
121
118
  *********************************************************/
122
119
 
123
- void IndexIVFPQFastScan::train_residual (idx_t n, const float *x_in)
124
- {
120
+ void IndexIVFPQFastScan::train_residual(idx_t n, const float* x_in) {
121
+ const float* x = fvecs_maybe_subsample(
122
+ d,
123
+ (size_t*)&n,
124
+ pq.cp.max_points_per_centroid * pq.ksub,
125
+ x_in,
126
+ verbose,
127
+ pq.cp.seed);
125
128
 
126
- const float * x = fvecs_maybe_subsample (
127
- d, (size_t*)&n, pq.cp.max_points_per_centroid * pq.ksub,
128
- x_in, verbose, pq.cp.seed);
129
-
130
- std::unique_ptr<float []> del_x;
129
+ std::unique_ptr<float[]> del_x;
131
130
  if (x != x_in) {
132
131
  del_x.reset((float*)x);
133
132
  }
134
133
 
135
- const float *trainset;
134
+ const float* trainset;
136
135
  AlignedTable<float> residuals;
137
136
 
138
137
  if (by_residual) {
139
- if(verbose) printf("computing residuals\n");
138
+ if (verbose)
139
+ printf("computing residuals\n");
140
140
  std::vector<idx_t> assign(n);
141
- quantizer->assign (n, x, assign.data());
141
+ quantizer->assign(n, x, assign.data());
142
142
  residuals.resize(n * d);
143
143
  for (idx_t i = 0; i < n; i++) {
144
- quantizer->compute_residual (
145
- x + i * d,
146
- residuals.data() + i * d,
147
- assign[i]
148
- );
144
+ quantizer->compute_residual(
145
+ x + i * d, residuals.data() + i * d, assign[i]);
149
146
  }
150
147
  trainset = residuals.data();
151
148
  } else {
@@ -153,82 +150,78 @@ void IndexIVFPQFastScan::train_residual (idx_t n, const float *x_in)
153
150
  }
154
151
 
155
152
  if (verbose) {
156
- printf ("training %zdx%zd product quantizer on %zd vectors in %dD\n",
157
- pq.M, pq.ksub, long(n), d);
153
+ printf("training %zdx%zd product quantizer on "
154
+ "%" PRId64 " vectors in %dD\n",
155
+ pq.M,
156
+ pq.ksub,
157
+ n,
158
+ d);
158
159
  }
159
160
  pq.verbose = verbose;
160
- pq.train (n, trainset);
161
+ pq.train(n, trainset);
161
162
 
162
163
  if (by_residual && metric_type == METRIC_L2) {
163
164
  precompute_table();
164
165
  }
165
-
166
166
  }
167
167
 
168
- void IndexIVFPQFastScan::precompute_table ()
169
- {
168
+ void IndexIVFPQFastScan::precompute_table() {
170
169
  initialize_IVFPQ_precomputed_table(
171
- use_precomputed_table,
172
- quantizer, pq, precomputed_table, verbose
173
- );
170
+ use_precomputed_table, quantizer, pq, precomputed_table, verbose);
174
171
  }
175
172
 
176
-
177
173
  /*********************************************************
178
174
  * Code management functions
179
175
  *********************************************************/
180
176
 
181
-
182
-
183
177
  void IndexIVFPQFastScan::encode_vectors(
184
- idx_t n, const float* x, const idx_t *list_nos,
185
- uint8_t * codes, bool include_listnos) const
186
- {
187
-
178
+ idx_t n,
179
+ const float* x,
180
+ const idx_t* list_nos,
181
+ uint8_t* codes,
182
+ bool include_listnos) const {
188
183
  if (by_residual) {
189
- AlignedTable<float> residuals (n * d);
184
+ AlignedTable<float> residuals(n * d);
190
185
  for (size_t i = 0; i < n; i++) {
191
186
  if (list_nos[i] < 0) {
192
- memset (residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
187
+ memset(residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
193
188
  } else {
194
- quantizer->compute_residual (
195
- x + i * d, residuals.data() + i * d, list_nos[i]);
189
+ quantizer->compute_residual(
190
+ x + i * d, residuals.data() + i * d, list_nos[i]);
196
191
  }
197
192
  }
198
- pq.compute_codes (residuals.data(), codes, n);
193
+ pq.compute_codes(residuals.data(), codes, n);
199
194
  } else {
200
- pq.compute_codes (x, codes, n);
195
+ pq.compute_codes(x, codes, n);
201
196
  }
202
197
 
203
198
  if (include_listnos) {
204
199
  size_t coarse_size = coarse_code_size();
205
200
  for (idx_t i = n - 1; i >= 0; i--) {
206
- uint8_t * code = codes + i * (coarse_size + code_size);
207
- memmove (code + coarse_size,
208
- codes + i * code_size, code_size);
209
- encode_listno (list_nos[i], code);
201
+ uint8_t* code = codes + i * (coarse_size + code_size);
202
+ memmove(code + coarse_size, codes + i * code_size, code_size);
203
+ encode_listno(list_nos[i], code);
210
204
  }
211
205
  }
212
206
  }
213
207
 
214
-
215
-
216
- void IndexIVFPQFastScan::add_with_ids (
217
- idx_t n, const float * x, const idx_t *xids) {
218
-
208
+ void IndexIVFPQFastScan::add_with_ids(
209
+ idx_t n,
210
+ const float* x,
211
+ const idx_t* xids) {
219
212
  // copied from IndexIVF::add_with_ids --->
220
213
 
221
214
  // do some blocking to avoid excessive allocs
222
215
  idx_t bs = 65536;
223
216
  if (n > bs) {
224
217
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
225
- idx_t i1 = std::min (n, i0 + bs);
218
+ idx_t i1 = std::min(n, i0 + bs);
226
219
  if (verbose) {
227
220
  printf(" IndexIVFPQFastScan::add_with_ids %zd: %zd",
228
- size_t(i0), size_t(i1));
221
+ size_t(i0),
222
+ size_t(i1));
229
223
  }
230
- add_with_ids (i1 - i0, x + i0 * d,
231
- xids ? xids + i0 : nullptr);
224
+ add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
232
225
  }
233
226
  return;
234
227
  }
@@ -236,37 +229,38 @@ void IndexIVFPQFastScan::add_with_ids (
236
229
 
237
230
  AlignedTable<uint8_t> codes(n * code_size);
238
231
 
239
- FAISS_THROW_IF_NOT (is_trained);
240
- direct_map.check_can_add (xids);
232
+ FAISS_THROW_IF_NOT(is_trained);
233
+ direct_map.check_can_add(xids);
241
234
 
242
- std::unique_ptr<idx_t []> idx(new idx_t[n]);
243
- quantizer->assign (n, x, idx.get());
235
+ std::unique_ptr<idx_t[]> idx(new idx_t[n]);
236
+ quantizer->assign(n, x, idx.get());
244
237
  size_t nadd = 0, nminus1 = 0;
245
238
 
246
239
  for (size_t i = 0; i < n; i++) {
247
- if (idx[i] < 0) nminus1++;
240
+ if (idx[i] < 0)
241
+ nminus1++;
248
242
  }
249
243
 
250
244
  AlignedTable<uint8_t> flat_codes(n * code_size);
251
- encode_vectors (n, x, idx.get(), flat_codes.get());
245
+ encode_vectors(n, x, idx.get(), flat_codes.get());
252
246
 
253
247
  DirectMapAdd dm_adder(direct_map, n, xids);
254
248
 
255
249
  // <---
256
250
 
257
- BlockInvertedLists *bil = dynamic_cast<BlockInvertedLists*>(invlists);
258
- FAISS_THROW_IF_NOT_MSG (bil, "only block inverted lists supported");
251
+ BlockInvertedLists* bil = dynamic_cast<BlockInvertedLists*>(invlists);
252
+ FAISS_THROW_IF_NOT_MSG(bil, "only block inverted lists supported");
259
253
 
260
254
  // prepare batches
261
255
  std::vector<idx_t> order(n);
262
- for(idx_t i = 0; i < n ; i++) { order[i] = i; }
256
+ for (idx_t i = 0; i < n; i++) {
257
+ order[i] = i;
258
+ }
263
259
 
264
260
  // TODO should not need stable
265
- std::stable_sort(order.begin(), order.end(),
266
- [&idx](idx_t a, idx_t b) {
267
- return idx[a] < idx[b];
268
- }
269
- );
261
+ std::stable_sort(order.begin(), order.end(), [&idx](idx_t a, idx_t b) {
262
+ return idx[a] < idx[b];
263
+ });
270
264
 
271
265
  // TODO parallelize
272
266
  idx_t i0 = 0;
@@ -274,7 +268,7 @@ void IndexIVFPQFastScan::add_with_ids (
274
268
  idx_t list_no = idx[order[i0]];
275
269
  idx_t i1 = i0 + 1;
276
270
  while (i1 < n && idx[order[i1]] == list_no) {
277
- i1 ++;
271
+ i1++;
278
272
  }
279
273
 
280
274
  if (list_no == -1) {
@@ -288,58 +282,57 @@ void IndexIVFPQFastScan::add_with_ids (
288
282
 
289
283
  bil->resize(list_no, list_size + i1 - i0);
290
284
 
291
- for(idx_t i = i0; i < i1; i++) {
285
+ for (idx_t i = i0; i < i1; i++) {
292
286
  size_t ofs = list_size + i - i0;
293
287
  idx_t id = xids ? xids[order[i]] : ntotal + order[i];
294
- dm_adder.add (order[i], list_no, ofs);
288
+ dm_adder.add(order[i], list_no, ofs);
295
289
  bil->ids[list_no][ofs] = id;
296
- memcpy(
297
- list_codes.data() + (i - i0) * code_size,
298
- flat_codes.data() + order[i] * code_size,
299
- code_size
300
- );
290
+ memcpy(list_codes.data() + (i - i0) * code_size,
291
+ flat_codes.data() + order[i] * code_size,
292
+ code_size);
301
293
  nadd++;
302
294
  }
303
295
  pq4_pack_codes_range(
304
- list_codes.data(), pq.M,
305
- list_size, list_size + i1 - i0,
306
- bbs, M2, bil->codes[list_no].data()
307
- );
296
+ list_codes.data(),
297
+ pq.M,
298
+ list_size,
299
+ list_size + i1 - i0,
300
+ bbs,
301
+ M2,
302
+ bil->codes[list_no].data());
308
303
 
309
304
  i0 = i1;
310
305
  }
311
306
 
312
307
  ntotal += n;
313
-
314
308
  }
315
309
 
316
-
317
-
318
310
  /*********************************************************
319
311
  * search
320
312
  *********************************************************/
321
313
 
322
-
323
314
  namespace {
324
315
 
325
316
  // from impl/ProductQuantizer.cpp
326
317
  template <class C, typename dis_t>
327
318
  void pq_estimators_from_tables_generic(
328
- const ProductQuantizer& pq, size_t nbits,
329
- const uint8_t *codes, size_t ncodes,
330
- const dis_t *dis_table, const int64_t * ids,
319
+ const ProductQuantizer& pq,
320
+ size_t nbits,
321
+ const uint8_t* codes,
322
+ size_t ncodes,
323
+ const dis_t* dis_table,
324
+ const int64_t* ids,
331
325
  float dis0,
332
- size_t k, typename C::T *heap_dis, int64_t *heap_ids)
333
- {
326
+ size_t k,
327
+ typename C::T* heap_dis,
328
+ int64_t* heap_ids) {
334
329
  using accu_t = typename C::T;
335
330
  const size_t M = pq.M;
336
331
  const size_t ksub = pq.ksub;
337
332
  for (size_t j = 0; j < ncodes; ++j) {
338
- PQDecoderGeneric decoder(
339
- codes + j * pq.code_size, nbits
340
- );
333
+ PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
341
334
  accu_t dis = dis0;
342
- const dis_t * dt = dis_table;
335
+ const dis_t* dt = dis_table;
343
336
  for (size_t m = 0; m < M; m++) {
344
337
  uint64_t c = decoder.decode();
345
338
  dis += dt[c];
@@ -356,17 +349,19 @@ void pq_estimators_from_tables_generic(
356
349
  using idx_t = Index::idx_t;
357
350
  using namespace quantize_lut;
358
351
 
359
- void fvec_madd_avx (
360
- size_t n, const float *a,
361
- float bf, const float *b, float *c)
362
- {
352
+ void fvec_madd_avx(
353
+ size_t n,
354
+ const float* a,
355
+ float bf,
356
+ const float* b,
357
+ float* c) {
363
358
  assert(is_aligned_pointer(a));
364
359
  assert(is_aligned_pointer(b));
365
360
  assert(is_aligned_pointer(c));
366
361
  assert(n % 8 == 0);
367
362
  simd8float32 bf8(bf);
368
363
  n /= 8;
369
- for(size_t i = 0; i < n; i++) {
364
+ for (size_t i = 0; i < n; i++) {
370
365
  simd8float32 ai(a);
371
366
  simd8float32 bi(b);
372
367
 
@@ -376,7 +371,6 @@ void fvec_madd_avx (
376
371
  a += 8;
377
372
  b += 8;
378
373
  }
379
-
380
374
  }
381
375
 
382
376
  } // anonymous namespace
@@ -385,23 +379,20 @@ void fvec_madd_avx (
385
379
  * Look-Up Table functions
386
380
  *********************************************************/
387
381
 
388
-
389
382
  void IndexIVFPQFastScan::compute_LUT(
390
- size_t n, const float *x,
391
- const idx_t *coarse_ids, const float *coarse_dis,
392
- AlignedTable<float> & dis_tables,
393
- AlignedTable<float> & biases
394
- ) const
395
- {
396
- const IndexIVFPQFastScan & ivfpq = *this;
383
+ size_t n,
384
+ const float* x,
385
+ const idx_t* coarse_ids,
386
+ const float* coarse_dis,
387
+ AlignedTable<float>& dis_tables,
388
+ AlignedTable<float>& biases) const {
389
+ const IndexIVFPQFastScan& ivfpq = *this;
397
390
  size_t dim12 = pq.ksub * pq.M;
398
391
  size_t d = pq.d;
399
392
  size_t nprobe = ivfpq.nprobe;
400
393
 
401
394
  if (ivfpq.by_residual) {
402
-
403
395
  if (ivfpq.metric_type == METRIC_L2) {
404
-
405
396
  dis_tables.resize(n * nprobe * dim12);
406
397
 
407
398
  if (ivfpq.use_precomputed_table == 1) {
@@ -409,57 +400,54 @@ void IndexIVFPQFastScan::compute_LUT(
409
400
  memcpy(biases.get(), coarse_dis, sizeof(float) * n * nprobe);
410
401
 
411
402
  AlignedTable<float> ip_table(n * dim12);
412
- pq.compute_inner_prod_tables (n, x, ip_table.get());
403
+ pq.compute_inner_prod_tables(n, x, ip_table.get());
413
404
 
414
405
  #pragma omp parallel for if (n * nprobe > 8000)
415
- for(idx_t ij = 0; ij < n * nprobe; ij++) {
406
+ for (idx_t ij = 0; ij < n * nprobe; ij++) {
416
407
  idx_t i = ij / nprobe;
417
- float *tab = dis_tables.get() + ij * dim12;
408
+ float* tab = dis_tables.get() + ij * dim12;
418
409
  idx_t cij = coarse_ids[ij];
419
410
 
420
411
  if (cij >= 0) {
421
- fvec_madd_avx (
422
- dim12,
423
- precomputed_table.get() + cij * dim12,
424
- -2, ip_table.get() + i * dim12,
425
- tab
426
- );
412
+ fvec_madd_avx(
413
+ dim12,
414
+ precomputed_table.get() + cij * dim12,
415
+ -2,
416
+ ip_table.get() + i * dim12,
417
+ tab);
427
418
  } else {
428
419
  // fill with NaNs so that they are ignored during
429
420
  // LUT quantization
430
- memset (tab, -1, sizeof(float) * dim12);
421
+ memset(tab, -1, sizeof(float) * dim12);
431
422
  }
432
423
  }
433
424
 
434
425
  } else {
435
-
436
426
  std::unique_ptr<float[]> xrel(new float[n * nprobe * d]);
437
427
  biases.resize(n * nprobe);
438
428
  memset(biases.get(), 0, sizeof(float) * n * nprobe);
439
429
 
440
430
  #pragma omp parallel for if (n * nprobe > 8000)
441
- for(idx_t ij = 0; ij < n * nprobe; ij++) {
431
+ for (idx_t ij = 0; ij < n * nprobe; ij++) {
442
432
  idx_t i = ij / nprobe;
443
- float *xij = &xrel[ij * d];
433
+ float* xij = &xrel[ij * d];
444
434
  idx_t cij = coarse_ids[ij];
445
435
 
446
436
  if (cij >= 0) {
447
- ivfpq.quantizer->compute_residual(
448
- x + i * d, xij, cij);
437
+ ivfpq.quantizer->compute_residual(x + i * d, xij, cij);
449
438
  } else {
450
439
  // will fill with NaNs
451
440
  memset(xij, -1, sizeof(float) * d);
452
441
  }
453
442
  }
454
443
 
455
- pq.compute_distance_tables (
444
+ pq.compute_distance_tables(
456
445
  n * nprobe, xrel.get(), dis_tables.get());
457
-
458
446
  }
459
447
 
460
448
  } else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
461
449
  dis_tables.resize(n * dim12);
462
- pq.compute_inner_prod_tables (n, x, dis_tables.get());
450
+ pq.compute_inner_prod_tables(n, x, dis_tables.get());
463
451
  // compute_inner_prod_tables(pq, n, x, dis_tables.get());
464
452
 
465
453
  biases.resize(n * nprobe);
@@ -471,33 +459,29 @@ void IndexIVFPQFastScan::compute_LUT(
471
459
  } else {
472
460
  dis_tables.resize(n * dim12);
473
461
  if (ivfpq.metric_type == METRIC_L2) {
474
- pq.compute_distance_tables (n, x, dis_tables.get());
462
+ pq.compute_distance_tables(n, x, dis_tables.get());
475
463
  } else if (ivfpq.metric_type == METRIC_INNER_PRODUCT) {
476
- pq.compute_inner_prod_tables (n, x, dis_tables.get());
464
+ pq.compute_inner_prod_tables(n, x, dis_tables.get());
477
465
  } else {
478
466
  FAISS_THROW_FMT("metric %d not supported", ivfpq.metric_type);
479
467
  }
480
468
  }
481
-
482
469
  }
483
470
 
484
471
  void IndexIVFPQFastScan::compute_LUT_uint8(
485
- size_t n, const float *x,
486
- const idx_t *coarse_ids, const float *coarse_dis,
487
- AlignedTable<uint8_t> & dis_tables,
488
- AlignedTable<uint16_t> & biases,
489
- float * normalizers
490
- ) const {
491
- const IndexIVFPQFastScan & ivfpq = *this;
472
+ size_t n,
473
+ const float* x,
474
+ const idx_t* coarse_ids,
475
+ const float* coarse_dis,
476
+ AlignedTable<uint8_t>& dis_tables,
477
+ AlignedTable<uint16_t>& biases,
478
+ float* normalizers) const {
479
+ const IndexIVFPQFastScan& ivfpq = *this;
492
480
  AlignedTable<float> dis_tables_float;
493
481
  AlignedTable<float> biases_float;
494
482
 
495
483
  uint64_t t0 = get_cy();
496
- compute_LUT(
497
- n, x,
498
- coarse_ids, coarse_dis,
499
- dis_tables_float, biases_float
500
- );
484
+ compute_LUT(n, x, coarse_ids, coarse_dis, dis_tables_float, biases_float);
501
485
  IVFFastScan_stats.t_compute_distance_tables += get_cy() - t0;
502
486
 
503
487
  bool lut_is_3d = ivfpq.by_residual && ivfpq.metric_type == METRIC_L2;
@@ -514,45 +498,52 @@ void IndexIVFPQFastScan::compute_LUT_uint8(
514
498
  uint64_t t1 = get_cy();
515
499
 
516
500
  #pragma omp parallel for if (n > 100)
517
- for(int64_t i = 0; i < n; i++) {
518
- const float *t_in = dis_tables_float.get() + i * dim123;
519
- const float *b_in = nullptr;
520
- uint8_t *t_out = dis_tables.get() + i * dim123_2;
521
- uint16_t *b_out = nullptr;
501
+ for (int64_t i = 0; i < n; i++) {
502
+ const float* t_in = dis_tables_float.get() + i * dim123;
503
+ const float* b_in = nullptr;
504
+ uint8_t* t_out = dis_tables.get() + i * dim123_2;
505
+ uint16_t* b_out = nullptr;
522
506
  if (biases_float.get()) {
523
507
  b_in = biases_float.get() + i * nprobe;
524
508
  b_out = biases.get() + i * nprobe;
525
509
  }
526
510
 
527
511
  quantize_LUT_and_bias(
528
- nprobe, pq.M, pq.ksub, lut_is_3d,
529
- t_in, b_in,
530
- t_out, M2, b_out,
531
- normalizers + 2 * i, normalizers + 2 * i + 1
532
- );
512
+ nprobe,
513
+ pq.M,
514
+ pq.ksub,
515
+ lut_is_3d,
516
+ t_in,
517
+ b_in,
518
+ t_out,
519
+ M2,
520
+ b_out,
521
+ normalizers + 2 * i,
522
+ normalizers + 2 * i + 1);
533
523
  }
534
524
  IVFFastScan_stats.t_round += get_cy() - t1;
535
-
536
525
  }
537
526
 
538
-
539
527
  /*********************************************************
540
528
  * Search functions
541
529
  *********************************************************/
542
530
 
543
- template<bool is_max>
531
+ template <bool is_max>
544
532
  void IndexIVFPQFastScan::search_dispatch_implem(
545
- idx_t n,
546
- const float* x,
547
- idx_t k,
548
- float* distances,
549
- idx_t* labels) const
550
- {
551
- using Cfloat = typename std::conditional<is_max,
552
- CMax<float, int64_t>, CMin<float, int64_t> >::type;
553
-
554
- using C = typename std::conditional<is_max,
555
- CMax<uint16_t, int64_t>, CMin<uint16_t, int64_t> >::type;
533
+ idx_t n,
534
+ const float* x,
535
+ idx_t k,
536
+ float* distances,
537
+ idx_t* labels) const {
538
+ using Cfloat = typename std::conditional<
539
+ is_max,
540
+ CMax<float, int64_t>,
541
+ CMin<float, int64_t>>::type;
542
+
543
+ using C = typename std::conditional<
544
+ is_max,
545
+ CMax<uint16_t, int64_t>,
546
+ CMin<uint16_t, int64_t>>::type;
556
547
 
557
548
  if (n == 0) {
558
549
  return;
@@ -568,7 +559,7 @@ void IndexIVFPQFastScan::search_dispatch_implem(
568
559
  impl = 10;
569
560
  }
570
561
  if (k > 20) {
571
- impl ++;
562
+ impl++;
572
563
  }
573
564
  }
574
565
 
@@ -582,11 +573,25 @@ void IndexIVFPQFastScan::search_dispatch_implem(
582
573
 
583
574
  if (n < 2) {
584
575
  if (impl == 12 || impl == 13) {
585
- search_implem_12<C>
586
- (n, x, k, distances, labels, impl, &ndis, &nlist_visited);
576
+ search_implem_12<C>(
577
+ n,
578
+ x,
579
+ k,
580
+ distances,
581
+ labels,
582
+ impl,
583
+ &ndis,
584
+ &nlist_visited);
587
585
  } else {
588
- search_implem_10<C>
589
- (n, x, k, distances, labels, impl, &ndis, &nlist_visited);
586
+ search_implem_10<C>(
587
+ n,
588
+ x,
589
+ k,
590
+ distances,
591
+ labels,
592
+ impl,
593
+ &ndis,
594
+ &nlist_visited);
590
595
  }
591
596
  } else {
592
597
  // explicitly slice over threads
@@ -595,34 +600,47 @@ void IndexIVFPQFastScan::search_dispatch_implem(
595
600
  nslice = n;
596
601
  } else if (by_residual && metric_type == METRIC_L2) {
597
602
  // make sure we don't make too big LUT tables
598
- size_t lut_size_per_query =
599
- pq.M * pq.ksub * nprobe * (sizeof(float) + sizeof(uint8_t));
603
+ size_t lut_size_per_query = pq.M * pq.ksub * nprobe *
604
+ (sizeof(float) + sizeof(uint8_t));
600
605
 
601
606
  size_t max_lut_size = precomputed_table_max_bytes;
602
607
  // how many queries we can handle within mem budget
603
- size_t nq_ok = std::max(max_lut_size / lut_size_per_query, size_t(1));
604
- nslice = roundup(std::max(size_t(n / nq_ok), size_t(1)), omp_get_max_threads());
608
+ size_t nq_ok =
609
+ std::max(max_lut_size / lut_size_per_query, size_t(1));
610
+ nslice =
611
+ roundup(std::max(size_t(n / nq_ok), size_t(1)),
612
+ omp_get_max_threads());
605
613
  } else {
606
614
  // LUTs unlikely to be a limiting factor
607
615
  nslice = omp_get_max_threads();
608
616
  }
609
617
 
610
- #pragma omp parallel for reduction(+: ndis, nlist_visited)
618
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
611
619
  for (int slice = 0; slice < nslice; slice++) {
612
620
  idx_t i0 = n * slice / nslice;
613
621
  idx_t i1 = n * (slice + 1) / nslice;
614
- float *dis_i = distances + i0 * k;
615
- idx_t *lab_i = labels + i0 * k;
622
+ float* dis_i = distances + i0 * k;
623
+ idx_t* lab_i = labels + i0 * k;
616
624
  if (impl == 12 || impl == 13) {
617
625
  search_implem_12<C>(
618
- i1 - i0, x + i0 * d, k, dis_i, lab_i,
619
- impl, &ndis, &nlist_visited
620
- );
626
+ i1 - i0,
627
+ x + i0 * d,
628
+ k,
629
+ dis_i,
630
+ lab_i,
631
+ impl,
632
+ &ndis,
633
+ &nlist_visited);
621
634
  } else {
622
635
  search_implem_10<C>(
623
- i1 - i0, x + i0 * d, k, dis_i, lab_i,
624
- impl, &ndis, &nlist_visited
625
- );
636
+ i1 - i0,
637
+ x + i0 * d,
638
+ k,
639
+ dis_i,
640
+ lab_i,
641
+ impl,
642
+ &ndis,
643
+ &nlist_visited);
626
644
  }
627
645
  }
628
646
  }
@@ -632,14 +650,16 @@ void IndexIVFPQFastScan::search_dispatch_implem(
632
650
  } else {
633
651
  FAISS_THROW_FMT("implem %d does not exist", implem);
634
652
  }
635
-
636
653
  }
637
654
 
638
-
639
655
  void IndexIVFPQFastScan::search(
640
- idx_t n, const float* x, idx_t k,
641
- float* distances, idx_t* labels) const
642
- {
656
+ idx_t n,
657
+ const float* x,
658
+ idx_t k,
659
+ float* distances,
660
+ idx_t* labels) const {
661
+ FAISS_THROW_IF_NOT(k > 0);
662
+
643
663
  if (metric_type == METRIC_L2) {
644
664
  search_dispatch_implem<true>(n, x, k, distances, labels);
645
665
  } else {
@@ -647,133 +667,150 @@ void IndexIVFPQFastScan::search(
647
667
  }
648
668
  }
649
669
 
650
- template<class C>
670
+ template <class C>
651
671
  void IndexIVFPQFastScan::search_implem_1(
652
- idx_t n, const float* x, idx_t k,
653
- float* distances, idx_t* labels) const
654
- {
672
+ idx_t n,
673
+ const float* x,
674
+ idx_t k,
675
+ float* distances,
676
+ idx_t* labels) const {
655
677
  FAISS_THROW_IF_NOT(orig_invlists);
656
678
 
657
679
  std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
658
680
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
659
681
 
660
- quantizer->search (n, x, nprobe, coarse_dis.get(), coarse_ids.get());
682
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
661
683
 
662
684
  size_t dim12 = pq.ksub * pq.M;
663
685
  AlignedTable<float> dis_tables;
664
686
  AlignedTable<float> biases;
665
687
 
666
- compute_LUT (
667
- n, x,
668
- coarse_ids.get(), coarse_dis.get(),
669
- dis_tables, biases
670
- );
688
+ compute_LUT(n, x, coarse_ids.get(), coarse_dis.get(), dis_tables, biases);
671
689
 
672
690
  bool single_LUT = !(by_residual && metric_type == METRIC_L2);
673
691
 
674
692
  size_t ndis = 0, nlist_visited = 0;
675
693
 
676
- #pragma omp parallel for reduction(+: ndis, nlist_visited)
677
- for(idx_t i = 0; i < n; i++) {
678
- int64_t *heap_ids = labels + i * k;
679
- float *heap_dis = distances + i * k;
680
- heap_heapify<C> (k, heap_dis, heap_ids);
681
- float *LUT = nullptr;
694
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
695
+ for (idx_t i = 0; i < n; i++) {
696
+ int64_t* heap_ids = labels + i * k;
697
+ float* heap_dis = distances + i * k;
698
+ heap_heapify<C>(k, heap_dis, heap_ids);
699
+ float* LUT = nullptr;
682
700
 
683
701
  if (single_LUT) {
684
702
  LUT = dis_tables.get() + i * dim12;
685
703
  }
686
- for(idx_t j = 0; j < nprobe; j++) {
704
+ for (idx_t j = 0; j < nprobe; j++) {
687
705
  if (!single_LUT) {
688
706
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
689
707
  }
690
708
  idx_t list_no = coarse_ids[i * nprobe + j];
691
- if (list_no < 0) continue;
709
+ if (list_no < 0)
710
+ continue;
692
711
  size_t ls = orig_invlists->list_size(list_no);
693
- if (ls == 0) continue;
712
+ if (ls == 0)
713
+ continue;
694
714
  InvertedLists::ScopedCodes codes(orig_invlists, list_no);
695
715
  InvertedLists::ScopedIds ids(orig_invlists, list_no);
696
716
 
697
717
  float bias = biases.get() ? biases[i * nprobe + j] : 0;
698
718
 
699
719
  pq_estimators_from_tables_generic<C>(
700
- pq, pq.nbits, codes.get(), ls,
701
- LUT, ids.get(), bias,
702
- k, heap_dis, heap_ids
703
- );
704
- nlist_visited ++;
705
- ndis ++;
720
+ pq,
721
+ pq.nbits,
722
+ codes.get(),
723
+ ls,
724
+ LUT,
725
+ ids.get(),
726
+ bias,
727
+ k,
728
+ heap_dis,
729
+ heap_ids);
730
+ nlist_visited++;
731
+ ndis++;
706
732
  }
707
- heap_reorder<C> (k, heap_dis, heap_ids);
733
+ heap_reorder<C>(k, heap_dis, heap_ids);
708
734
  }
709
735
  indexIVF_stats.nq += n;
710
736
  indexIVF_stats.ndis += ndis;
711
737
  indexIVF_stats.nlist += nlist_visited;
712
738
  }
713
739
 
714
- template<class C>
740
+ template <class C>
715
741
  void IndexIVFPQFastScan::search_implem_2(
716
- idx_t n, const float* x, idx_t k,
717
- float* distances, idx_t* labels) const
718
- {
742
+ idx_t n,
743
+ const float* x,
744
+ idx_t k,
745
+ float* distances,
746
+ idx_t* labels) const {
719
747
  FAISS_THROW_IF_NOT(orig_invlists);
720
748
 
721
749
  std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
722
750
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
723
751
 
724
- quantizer->search (n, x, nprobe, coarse_dis.get(), coarse_ids.get());
752
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
725
753
 
726
754
  size_t dim12 = pq.ksub * M2;
727
755
  AlignedTable<uint8_t> dis_tables;
728
756
  AlignedTable<uint16_t> biases;
729
757
  std::unique_ptr<float[]> normalizers(new float[2 * n]);
730
758
 
731
- compute_LUT_uint8 (
732
- n, x,
733
- coarse_ids.get(), coarse_dis.get(),
734
- dis_tables, biases,
735
- normalizers.get()
736
- );
737
-
759
+ compute_LUT_uint8(
760
+ n,
761
+ x,
762
+ coarse_ids.get(),
763
+ coarse_dis.get(),
764
+ dis_tables,
765
+ biases,
766
+ normalizers.get());
738
767
 
739
768
  bool single_LUT = !(by_residual && metric_type == METRIC_L2);
740
769
 
741
770
  size_t ndis = 0, nlist_visited = 0;
742
771
 
743
- #pragma omp parallel for reduction(+: ndis, nlist_visited)
744
- for(idx_t i = 0; i < n; i++) {
772
+ #pragma omp parallel for reduction(+ : ndis, nlist_visited)
773
+ for (idx_t i = 0; i < n; i++) {
745
774
  std::vector<uint16_t> tmp_dis(k);
746
- int64_t *heap_ids = labels + i * k;
747
- uint16_t *heap_dis = tmp_dis.data();
748
- heap_heapify<C> (k, heap_dis, heap_ids);
749
- const uint8_t *LUT = nullptr;
775
+ int64_t* heap_ids = labels + i * k;
776
+ uint16_t* heap_dis = tmp_dis.data();
777
+ heap_heapify<C>(k, heap_dis, heap_ids);
778
+ const uint8_t* LUT = nullptr;
750
779
 
751
780
  if (single_LUT) {
752
781
  LUT = dis_tables.get() + i * dim12;
753
782
  }
754
- for(idx_t j = 0; j < nprobe; j++) {
783
+ for (idx_t j = 0; j < nprobe; j++) {
755
784
  if (!single_LUT) {
756
785
  LUT = dis_tables.get() + (i * nprobe + j) * dim12;
757
786
  }
758
787
  idx_t list_no = coarse_ids[i * nprobe + j];
759
- if (list_no < 0) continue;
788
+ if (list_no < 0)
789
+ continue;
760
790
  size_t ls = orig_invlists->list_size(list_no);
761
- if (ls == 0) continue;
791
+ if (ls == 0)
792
+ continue;
762
793
  InvertedLists::ScopedCodes codes(orig_invlists, list_no);
763
794
  InvertedLists::ScopedIds ids(orig_invlists, list_no);
764
795
 
765
796
  uint16_t bias = biases.get() ? biases[i * nprobe + j] : 0;
766
797
 
767
798
  pq_estimators_from_tables_generic<C>(
768
- pq, pq.nbits, codes.get(), ls,
769
- LUT, ids.get(), bias,
770
- k, heap_dis, heap_ids
771
- );
799
+ pq,
800
+ pq.nbits,
801
+ codes.get(),
802
+ ls,
803
+ LUT,
804
+ ids.get(),
805
+ bias,
806
+ k,
807
+ heap_dis,
808
+ heap_ids);
772
809
 
773
810
  nlist_visited++;
774
811
  ndis += ls;
775
812
  }
776
- heap_reorder<C> (k, heap_dis, heap_ids);
813
+ heap_reorder<C>(k, heap_dis, heap_ids);
777
814
  // convert distances to float
778
815
  {
779
816
  float one_a = 1 / normalizers[2 * i], b = normalizers[2 * i + 1];
@@ -781,7 +818,7 @@ void IndexIVFPQFastScan::search_implem_2(
781
818
  one_a = 1;
782
819
  b = 0;
783
820
  }
784
- float *heap_dis_float = distances + i * k;
821
+ float* heap_dis_float = distances + i * k;
785
822
  for (int j = 0; j < k; j++) {
786
823
  heap_dis_float[j] = b + heap_dis[j] * one_a;
787
824
  }
@@ -792,14 +829,16 @@ void IndexIVFPQFastScan::search_implem_2(
792
829
  indexIVF_stats.nlist += nlist_visited;
793
830
  }
794
831
 
795
-
796
-
797
- template<class C>
832
+ template <class C>
798
833
  void IndexIVFPQFastScan::search_implem_10(
799
- idx_t n, const float* x, idx_t k,
800
- float* distances, idx_t* labels,
801
- int impl, size_t *ndis_out, size_t *nlist_out) const
802
- {
834
+ idx_t n,
835
+ const float* x,
836
+ idx_t k,
837
+ float* distances,
838
+ idx_t* labels,
839
+ int impl,
840
+ size_t* ndis_out,
841
+ size_t* nlist_out) const {
803
842
  memset(distances, -1, sizeof(float) * k * n);
804
843
  memset(labels, -1, sizeof(idx_t) * k * n);
805
844
 
@@ -807,7 +846,6 @@ void IndexIVFPQFastScan::search_implem_10(
807
846
  using ReservoirHC = ReservoirHandler<C, true>;
808
847
  using SingleResultHC = SingleResultHandler<C, true>;
809
848
 
810
-
811
849
  std::unique_ptr<idx_t[]> coarse_ids(new idx_t[n * nprobe]);
812
850
  std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
813
851
 
@@ -817,20 +855,23 @@ void IndexIVFPQFastScan::search_implem_10(
817
855
  #define TIC times[ti++] = get_cy()
818
856
  TIC;
819
857
 
820
- quantizer->search (n, x, nprobe, coarse_dis.get(), coarse_ids.get());
858
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
821
859
 
822
860
  TIC;
823
861
 
824
862
  size_t dim12 = pq.ksub * M2;
825
863
  AlignedTable<uint8_t> dis_tables;
826
864
  AlignedTable<uint16_t> biases;
827
- std::unique_ptr<float[]> normalizers (new float[2 * n]);
865
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
828
866
 
829
- compute_LUT_uint8 (
830
- n, x,
831
- coarse_ids.get(), coarse_dis.get(),
832
- dis_tables, biases, normalizers.get()
833
- );
867
+ compute_LUT_uint8(
868
+ n,
869
+ x,
870
+ coarse_ids.get(),
871
+ coarse_dis.get(),
872
+ dis_tables,
873
+ biases,
874
+ normalizers.get());
834
875
 
835
876
  TIC;
836
877
 
@@ -841,15 +882,16 @@ void IndexIVFPQFastScan::search_implem_10(
841
882
 
842
883
  {
843
884
  AlignedTable<uint16_t> tmp_distances(k);
844
- for(idx_t i = 0; i < n; i++) {
845
- const uint8_t *LUT = nullptr;
885
+ for (idx_t i = 0; i < n; i++) {
886
+ const uint8_t* LUT = nullptr;
846
887
  int qmap1[1] = {0};
847
- std::unique_ptr<SIMDResultHandler<C, true> > handler;
888
+ std::unique_ptr<SIMDResultHandler<C, true>> handler;
848
889
 
849
890
  if (k == 1) {
850
891
  handler.reset(new SingleResultHC(1, 0));
851
892
  } else if (impl == 10) {
852
- handler.reset(new HeapHC(1, tmp_distances.get(), labels + i * k, k, 0));
893
+ handler.reset(new HeapHC(
894
+ 1, tmp_distances.get(), labels + i * k, k, 0));
853
895
  } else if (impl == 11) {
854
896
  handler.reset(new ReservoirHC(1, 0, k, 2 * k));
855
897
  } else {
@@ -861,7 +903,7 @@ void IndexIVFPQFastScan::search_implem_10(
861
903
  if (single_LUT) {
862
904
  LUT = dis_tables.get() + i * dim12;
863
905
  }
864
- for(idx_t j = 0; j < nprobe; j++) {
906
+ for (idx_t j = 0; j < nprobe; j++) {
865
907
  size_t ij = i * nprobe + j;
866
908
  if (!single_LUT) {
867
909
  LUT = dis_tables.get() + ij * dim12;
@@ -871,9 +913,11 @@ void IndexIVFPQFastScan::search_implem_10(
  }
 
  idx_t list_no = coarse_ids[ij];
- if (list_no < 0) continue;
+ if (list_no < 0)
+ continue;
  size_t ls = invlists->list_size(list_no);
- if (ls == 0) continue;
+ if (ls == 0)
+ continue;
 
  InvertedLists::ScopedCodes codes(invlists, list_no);
  InvertedLists::ScopedIds ids(invlists, list_no);
@@ -881,41 +925,40 @@ void IndexIVFPQFastScan::search_implem_10(
  handler->ntotal = ls;
  handler->id_map = ids.get();
 
- #define DISPATCH(classHC) \
- if(auto *res = dynamic_cast<classHC* > (handler.get())) { \
- pq4_accumulate_loop( \
- 1, roundup(ls, bbs), bbs, M2, \
- codes.get(), LUT, \
- *res \
- ); \
- }
+ #define DISPATCH(classHC) \
+ if (dynamic_cast<classHC*>(handler.get())) { \
+ auto* res = static_cast<classHC*>(handler.get()); \
+ pq4_accumulate_loop( \
+ 1, roundup(ls, bbs), bbs, M2, codes.get(), LUT, *res); \
+ }
  DISPATCH(HeapHC)
- else DISPATCH(ReservoirHC)
- else DISPATCH(SingleResultHC)
+ else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
  #undef DISPATCH
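
The DISPATCH rewrite above is behavior-preserving: the old macro bound `res` inside `if (auto *res = dynamic_cast...)`, while the new one probes with `dynamic_cast` and then casts unchecked with `static_cast`, a layout clang-format can wrap while the `else DISPATCH(...)` chaining still parses. The same pattern in isolation, with hypothetical `Base`, `HandlerA`, `HandlerB`, and `accumulate` standing in for the SIMD result handlers and the scan:

    #include <memory>

    struct Base { virtual ~Base() = default; };
    struct HandlerA : Base {};
    struct HandlerB : Base {};

    template <class T>
    void accumulate(T&) { /* type-specific scan would go here */ }

    void dispatch(std::unique_ptr<Base>& handler) {
        // Probe each concrete type in turn; first match wins.
    #define DISPATCH(T)                                 \
        if (dynamic_cast<T*>(handler.get())) {          \
            auto* res = static_cast<T*>(handler.get()); \
            accumulate(*res);                           \
        }
        DISPATCH(HandlerA)
        else DISPATCH(HandlerB)
    #undef DISPATCH
    }
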
 
- nlist_visited ++;
- ndis ++;
+ nlist_visited++;
+ ndis++;
  }
 
  handler->to_flat_arrays(
- distances + i * k, labels + i * k,
- skip & 16 ? nullptr : normalizers.get() + i * 2
- );
+ distances + i * k,
+ labels + i * k,
+ skip & 16 ? nullptr : normalizers.get() + i * 2);
  }
  }
  *ndis_out = ndis;
  *nlist_out = nlist;
  }
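
Taken together, search_implem_10 runs query by query: choose a result handler, scan each non-empty probed list with pq4_accumulate_loop, then have to_flat_arrays convert the uint16 scores to float (bit 16 of `skip` passes nullptr and suppresses that conversion). The handler choice, condensed from the loop body with the constructor arguments exactly as they appear above:

    // k == 1     -> SingleResultHC: track only the single best hit
    // impl == 10 -> HeapHC: k-element heap over the uint16 scores
    // impl == 11 -> ReservoirHC: reservoir with capacity argument 2 * k
    if (k == 1) {
        handler.reset(new SingleResultHC(1, 0));
    } else if (impl == 10) {
        handler.reset(new HeapHC(
                1, tmp_distances.get(), labels + i * k, k, 0));
    } else if (impl == 11) {
        handler.reset(new ReservoirHC(1, 0, k, 2 * k));
    }
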
 
-
-
- template<class C>
+ template <class C>
  void IndexIVFPQFastScan::search_implem_12(
- idx_t n, const float* x, idx_t k,
- float* distances, idx_t* labels,
- int impl, size_t *ndis_out, size_t *nlist_out) const
- {
+ idx_t n,
+ const float* x,
+ idx_t k,
+ float* distances,
+ idx_t* labels,
+ int impl,
+ size_t* ndis_out,
+ size_t* nlist_out) const {
  if (n == 0) { // does not work well with reservoir
  return;
  }
@@ -930,53 +973,53 @@ void IndexIVFPQFastScan::search_implem_12(
  #define TIC times[ti++] = get_cy()
  TIC;
 
- quantizer->search (n, x, nprobe, coarse_dis.get(), coarse_ids.get());
+ quantizer->search(n, x, nprobe, coarse_dis.get(), coarse_ids.get());
 
  TIC;
 
  size_t dim12 = pq.ksub * M2;
  AlignedTable<uint8_t> dis_tables;
  AlignedTable<uint16_t> biases;
- std::unique_ptr<float[]> normalizers (new float[2 * n]);
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
 
- compute_LUT_uint8 (
- n, x,
- coarse_ids.get(), coarse_dis.get(),
- dis_tables, biases, normalizers.get()
- );
+ compute_LUT_uint8(
+ n,
+ x,
+ coarse_ids.get(),
+ coarse_dis.get(),
+ dis_tables,
+ biases,
+ normalizers.get());
 
  TIC;
 
  struct QC {
- int qno; // sequence number of the query
- int list_no; // list to visit
- int rank; // this is the rank'th result of the coarse quantizer
+ int qno; // sequence number of the query
+ int list_no; // list to visit
+ int rank; // this is the rank'th result of the coarse quantizer
  };
  bool single_LUT = !(by_residual && metric_type == METRIC_L2);
 
  std::vector<QC> qcs;
  {
  int ij = 0;
- for(int i = 0; i < n; i++) {
- for(int j = 0; j < nprobe; j++) {
+ for (int i = 0; i < n; i++) {
+ for (int j = 0; j < nprobe; j++) {
  if (coarse_ids[ij] >= 0) {
  qcs.push_back(QC{i, int(coarse_ids[ij]), int(j)});
  }
  ij++;
  }
  }
- std::sort(
- qcs.begin(), qcs.end(),
- [](const QC &a, const QC & b) {
- return a.list_no < b.list_no;
- }
- );
+ std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
+ return a.list_no < b.list_no;
+ });
  }
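
Where implem_10 iterates per query, search_implem_12 flattens every (query, probed list) pair into a QC record and sorts by list_no, so each inverted list is fetched once and scanned for all queries that probe it; the main loop below additionally caps a batch at qbs2 entries. A self-contained sketch of the grouping, where `scan_grouped` and `visit` are illustrative names:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    struct QC {
        int qno;     // sequence number of the query
        int list_no; // inverted list to visit
        int rank;    // rank'th coarse-quantizer result for that query
    };

    // Sort (query, list) pairs by list so each list is read once;
    // visit(list_no, first, count) stands in for pack-LUT-and-scan.
    template <class Visit>
    void scan_grouped(std::vector<QC>& qcs, Visit visit) {
        std::sort(qcs.begin(), qcs.end(), [](const QC& a, const QC& b) {
            return a.list_no < b.list_no;
        });
        size_t i0 = 0;
        while (i0 < qcs.size()) {
            size_t i1 = i0 + 1;
            while (i1 < qcs.size() && qcs[i1].list_no == qcs[i0].list_no) {
                i1++;
            }
            visit(qcs[i0].list_no, &qcs[i0], i1 - i0);
            i0 = i1;
        }
    }
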
  TIC;
 
  // prepare the result handlers
 
- std::unique_ptr<SIMDResultHandler<C, true> > handler;
+ std::unique_ptr<SIMDResultHandler<C, true>> handler;
  AlignedTable<uint16_t> tmp_distances;
 
  using HeapHC = HeapHandler<C, true>;
@@ -1012,7 +1055,7 @@ void IndexIVFPQFastScan::search_implem_12(
  int list_no = qcs[i0].list_no;
  size_t i1 = i0 + 1;
 
- while(i1 < qcs.size() && i1 < i0 + qbs2) {
+ while (i1 < qcs.size() && i1 < i0 + qbs2) {
  if (qcs[i1].list_no != list_no) {
  break;
  }
@@ -1034,8 +1077,8 @@ void IndexIVFPQFastScan::search_implem_12(
  memset(LUT.get(), -1, nc * dim12);
  int qbs = pq4_preferred_qbs(nc);
 
- for(size_t i = i0; i < i1; i++) {
- const QC & qc = qcs[i];
+ for (size_t i = i0; i < i1; i++) {
+ const QC& qc = qcs[i];
  q_map[i - i0] = qc.qno;
  int ij = qc.qno * nprobe + qc.rank;
  lut_entries[i - i0] = single_LUT ? qc.qno : ij;
@@ -1044,9 +1087,7 @@ void IndexIVFPQFastScan::search_implem_12(
  }
  }
  pq4_pack_LUT_qbs_q_map(
- qbs, M2, dis_tables.get(), lut_entries.data(),
- LUT.get()
- );
+ qbs, M2, dis_tables.get(), lut_entries.data(), LUT.get());
 
  // access the inverted list
 
@@ -1062,20 +1103,17 @@ void IndexIVFPQFastScan::search_implem_12(
  handler->id_map = ids.get();
  uint64_t tt1 = get_cy();
 
- #define DISPATCH(classHC) \
- if(auto *res = dynamic_cast<classHC* > (handler.get())) { \
- pq4_accumulate_loop_qbs( \
- qbs, list_size, M2, \
- codes.get(), LUT.get(), \
- *res \
- ); \
- }
+ #define DISPATCH(classHC) \
+ if (dynamic_cast<classHC*>(handler.get())) { \
+ auto* res = static_cast<classHC*>(handler.get()); \
+ pq4_accumulate_loop_qbs( \
+ qbs, list_size, M2, codes.get(), LUT.get(), *res); \
+ }
  DISPATCH(HeapHC)
- else DISPATCH(ReservoirHC)
- else DISPATCH(SingleResultHC)
+ else DISPATCH(ReservoirHC) else DISPATCH(SingleResultHC)
 
- // prepare for next loop
- i0 = i1;
+ // prepare for next loop
+ i0 = i1;
 
  uint64_t tt2 = get_cy();
  t_copy_pack += tt1 - tt0;
@@ -1085,21 +1123,19 @@ void IndexIVFPQFastScan::search_implem_12(
 
  // labels is in-place for HeapHC
  handler->to_flat_arrays(
- distances, labels,
- skip & 16 ? nullptr : normalizers.get()
- );
+ distances, labels, skip & 16 ? nullptr : normalizers.get());
 
  TIC;
 
  // these stats are not thread-safe
 
- for(int i = 1; i < ti; i++) {
- IVFFastScan_stats.times[i] += times[i] - times[i-1];
+ for (int i = 1; i < ti; i++) {
+ IVFFastScan_stats.times[i] += times[i] - times[i - 1];
  }
  IVFFastScan_stats.t_copy_pack += t_copy_pack;
  IVFFastScan_stats.t_scan += t_scan;
 
- if (auto *rh = dynamic_cast<ReservoirHC*> (handler.get())) {
+ if (auto* rh = dynamic_cast<ReservoirHC*>(handler.get())) {
  for (int i = 0; i < 4; i++) {
  IVFFastScan_stats.reservoir_times[i] += rh->times[i];
  }
@@ -1107,10 +1143,8 @@ void IndexIVFPQFastScan::search_implem_12(
 
  *ndis_out = ndis;
  *nlist_out = nlist;
-
  }
 
-
  IVFFastScanStats IVFFastScan_stats;
 
  } // namespace faiss
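
For orientation, a hedged end-to-end sketch of driving IndexIVFPQFastScan from C++ against the API vendored in this release; all sizes are placeholders, and the random fill only makes training well-defined (FastScan requires 4-bit PQ codes, and bbs is the SIMD block size):

    #include <faiss/IndexFlat.h>
    #include <faiss/IndexIVFPQFastScan.h>
    #include <cstdlib>
    #include <vector>

    int main() {
        int d = 64;                      // vector dimension
        faiss::IndexFlatL2 quantizer(d); // coarse quantizer
        // nlist = 64 inverted lists; M = 32 sub-quantizers at 4 bits
        // each; bbs = 32 is the SIMD block size (a multiple of 32).
        faiss::IndexIVFPQFastScan index(
                &quantizer, d, 64, 32, 4, faiss::METRIC_L2, 32);

        size_t nb = 10000;
        std::vector<float> xb(nb * d);
        for (auto& v : xb) {
            v = rand() / float(RAND_MAX); // placeholder data
        }
        index.train(nb, xb.data());
        index.add(nb, xb.data());

        index.nprobe = 8; // lists probed per query
        faiss::Index::idx_t k = 10;
        std::vector<float> D(k);
        std::vector<faiss::Index::idx_t> I(k);
        index.search(1, xb.data(), k, D.data(), I.data());
        return 0;
    }
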