faiss 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,166 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <memory>
11
+
12
+ #include <faiss/IndexIVFPQ.h>
13
+ #include <faiss/impl/ProductQuantizer.h>
14
+ #include <faiss/utils/AlignedTable.h>
15
+
16
+ namespace faiss {
17
+
18
+
19
+ /** Fast scan version of IVFPQ. Works for 4-bit PQ for now.
20
+ *
21
+ * The codes in the inverted lists are not stored sequentially but
22
+ * grouped in blocks of size bbs. This makes it possible to very quickly
23
+ * compute distances with SIMD instructions.
24
+ *
25
+ * Implementations (implem):
26
+ * 0: auto-select implementation (default)
27
+ * 1: orig's search, re-implemented
28
+ * 2: orig's search, re-ordered by invlist
29
+ * 10: optimized int16 search, collect results in heap, no qbs
30
+ * 11: idem, collect results in reservoir
31
+ * 12: optimized int16 search, collect results in heap, uses qbs
32
+ * 13: idem, collect results in reservoir
33
+ */
34
+
35
+ struct IndexIVFPQFastScan: IndexIVF {
36
+ // NOTE(review): by_residual, bbs and M2 have no in-class initializer; presumably the constructors set them — confirm
37
+ bool by_residual; ///< Encode residual or plain vector?
38
+ ProductQuantizer pq; ///< produces the codes
39
+
40
+ // size of the kernel
41
+ int bbs; // set at build time
42
+
43
+ // M rounded up to a multiple of 2
44
+ size_t M2;
45
+
46
+ /// precomputed tables management
47
+ int use_precomputed_table = 0;
48
+ /// if use_precomputed_table is set: size (nlist, pq.M, pq.ksub)
49
+ AlignedTable<float> precomputed_table;
50
+
51
+ // search-time implementation (0 = auto-select)
52
+ int implem = 0;
53
+ // skip some parts of the computation (for timing)
54
+ int skip = 0;
55
+
56
+ // batching factors at search time (0 = default)
57
+ int qbs = 0;
58
+ size_t qbs2 = 0;
59
+ /// main constructor; metric defaults to METRIC_L2 and bbs to 32
60
+ IndexIVFPQFastScan (
61
+ Index * quantizer, size_t d, size_t nlist,
62
+ size_t M, size_t nbits_per_idx,
63
+ MetricType metric = METRIC_L2, int bbs = 32);
64
+
65
+ IndexIVFPQFastScan ();
66
+
67
+ // built from an IndexIVFPQ
68
+ explicit IndexIVFPQFastScan(const IndexIVFPQ & orig, int bbs = 32);
69
+
70
+ /// orig's inverted lists (for debugging)
71
+ InvertedLists * orig_invlists = nullptr;
72
+
73
+ void train_residual (idx_t n, const float *x) override;
74
+
75
+ /// build precomputed table, possibly updating use_precomputed_table
76
+ void precompute_table ();
77
+
78
+ /// same as the regular IVFPQ encoder. The codes are not reorganized by
79
+ /// blocks at that point
80
+ void encode_vectors(
81
+ idx_t n, const float* x,
82
+ const idx_t *list_nos, uint8_t * codes,
83
+ bool include_listno = false) const override;
84
+
85
+ void add_with_ids (
86
+ idx_t n, const float * x, const idx_t *xids) override;
87
+
88
+ void search(
89
+ idx_t n, const float* x, idx_t k,
90
+ float* distances, idx_t* labels) const override;
91
+
92
+ // prepare look-up tables
93
+ // float variant; dis_tables and biases are non-const references, presumably outputs — confirm
94
+ void compute_LUT(
95
+ size_t n, const float *x,
96
+ const idx_t *coarse_ids, const float *coarse_dis,
97
+ AlignedTable<float> & dis_tables,
98
+ AlignedTable<float> & biases
99
+ ) const;
100
+ // quantized variant: uint8 tables, uint16 biases and a float normalizers array
101
+ void compute_LUT_uint8(
102
+ size_t n, const float *x,
103
+ const idx_t *coarse_ids, const float *coarse_dis,
104
+ AlignedTable<uint8_t> & dis_tables,
105
+ AlignedTable<uint16_t> & biases,
106
+ float * normalizers
107
+ ) const;
108
+
109
+ // internal search funcs
110
+ // NOTE(review): is_max presumably selects the comparator direction per metric — confirm in the .cpp
111
+ template<bool is_max>
112
+ void search_dispatch_implem(
113
+ idx_t n, const float* x, idx_t k,
114
+ float* distances, idx_t* labels) const;
115
+
116
+ template<class C>
117
+ void search_implem_1(
118
+ idx_t n, const float* x, idx_t k,
119
+ float* distances, idx_t* labels) const;
120
+
121
+ template<class C>
122
+ void search_implem_2(
123
+ idx_t n, const float* x, idx_t k,
124
+ float* distances, idx_t* labels) const;
125
+
126
+ // implem 10 and 12 are not multithreaded internally, so
127
+ // export search stats
128
+ template<class C>
129
+ void search_implem_10(
130
+ idx_t n, const float* x, idx_t k,
131
+ float* distances, idx_t* labels,
132
+ int impl, size_t *ndis_out, size_t *nlist_out) const; // ndis_out / nlist_out: presumably receive the exported stats
133
+
134
+ template<class C>
135
+ void search_implem_12(
136
+ idx_t n, const float* x, idx_t k,
137
+ float* distances, idx_t* labels,
138
+ int impl, size_t *ndis_out, size_t *nlist_out) const; // same stats outputs as search_implem_10
139
+
140
+
141
+
142
+ };
143
+
144
/** Timing counters for the IVF fast-scan search code.
 *
 * Every field is a plain accumulator that starts at zero. The Mcy_*
 * accessors divide a counter by 1e6 (NOTE(review): presumably cycles ->
 * mega-cycles, going by the "Mcy" name — confirm against the code that
 * fills the counters).
 */
struct IVFFastScanStats {
    uint64_t times[10];
    uint64_t t_compute_distance_tables, t_round;
    uint64_t t_copy_pack, t_scan, t_to_flat;
    uint64_t reservoir_times[4];

    /// times[i] scaled down by 1e6; i is not range-checked
    double Mcy_at(int i) const {
        return times[i] / (1000*1000.0);
    }

    /// reservoir_times[i] scaled down by 1e6; i is not range-checked
    double Mcy_reservoir_at(int i) const {
        return reservoir_times[i] / (1000*1000.0);
    }

    IVFFastScanStats() {reset();}

    /// Zero every counter. Written field by field so the header does not
    /// depend on memset / <cstring> being pulled in by another include.
    void reset() {
        for (uint64_t & t : times) {
            t = 0;
        }
        t_compute_distance_tables = 0;
        t_round = 0;
        t_copy_pack = 0;
        t_scan = 0;
        t_to_flat = 0;
        for (uint64_t & t : reservoir_times) {
            t = 0;
        }
    }
};
162
+
163
+ FAISS_API extern IVFFastScanStats IVFFastScan_stats; // declaration only (extern); the object itself is defined outside this header
164
+
165
+
166
+ } // namespace faiss
@@ -97,13 +97,13 @@ void IndexIVFPQR::add_core (idx_t n, const float *x, const idx_t *xids,
97
97
  #define TOC get_cycles () - t0
98
98
 
99
99
 
100
- void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
101
- const idx_t *idx,
102
- const float *L1_dis,
103
- float *distances, idx_t *labels,
104
- bool store_pairs,
105
- const IVFSearchParameters *params
106
- ) const
100
+ void IndexIVFPQR::search_preassigned (
101
+ idx_t n, const float *x, idx_t k,
102
+ const idx_t *idx, const float *L1_dis,
103
+ float *distances, idx_t *labels,
104
+ bool store_pairs,
105
+ const IVFSearchParameters *params, IndexIVFStats *stats
106
+ ) const
107
107
  {
108
108
  uint64_t t0;
109
109
  TIC;
@@ -172,9 +172,8 @@ void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
172
172
  float dis = fvec_L2sqr (residual_1, residual_2, d);
173
173
 
174
174
  if (dis < heap_sim[0]) {
175
- maxheap_pop (k, heap_sim, heap_ids);
176
175
  idx_t id_or_pair = store_pairs ? sl : id;
177
- maxheap_push (k, heap_sim, heap_ids, dis, id_or_pair);
176
+ maxheap_replace_top (k, heap_sim, heap_ids, dis, id_or_pair);
178
177
  }
179
178
  n_refine ++;
180
179
  }
@@ -55,7 +55,8 @@ struct IndexIVFPQR: IndexIVFPQ {
55
55
  const float *centroid_dis,
56
56
  float *distances, idx_t *labels,
57
57
  bool store_pairs,
58
- const IVFSearchParameters *params=nullptr
58
+ const IVFSearchParameters *params=nullptr,
59
+ IndexIVFStats *stats=nullptr
59
60
  ) const override;
60
61
 
61
62
  IndexIVFPQR();
@@ -269,9 +269,8 @@ struct IVFScanner: InvertedListScanner {
269
269
  float dis = hc.hamming (codes);
270
270
 
271
271
  if (dis < simi [0]) {
272
- maxheap_pop (k, simi, idxi);
273
272
  int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
274
- maxheap_push (k, simi, idxi, dis, id);
273
+ maxheap_replace_top (k, simi, idxi, dis, id);
275
274
  nup++;
276
275
  }
277
276
  codes += code_size;
@@ -129,9 +129,10 @@ void IndexPQ::reconstruct (idx_t key, float * recons) const
129
129
 
130
130
  namespace {
131
131
 
132
-
133
- struct PQDis: DistanceComputer {
132
+ template<class PQDecoder>
133
+ struct PQDistanceComputer: DistanceComputer {
134
134
  size_t d;
135
+ MetricType metric;
135
136
  Index::idx_t nb;
136
137
  const uint8_t *codes;
137
138
  size_t code_size;
@@ -144,10 +145,11 @@ struct PQDis: DistanceComputer {
144
145
  {
145
146
  const uint8_t *code = codes + i * code_size;
146
147
  const float *dt = precomputed_table.data();
148
+ PQDecoder decoder(code, pq.nbits);
147
149
  float accu = 0;
148
150
  for (int j = 0; j < pq.M; j++) {
149
- accu += dt[*code++];
150
- dt += 256;
151
+ accu += dt[decoder.decode()];
152
+ dt += 1 << decoder.nbits;
151
153
  }
152
154
  ndis++;
153
155
  return accu;
@@ -155,33 +157,43 @@ struct PQDis: DistanceComputer {
155
157
 
156
158
  float symmetric_dis(idx_t i, idx_t j) override
157
159
  {
160
+ FAISS_THROW_IF_NOT(sdc);
158
161
  const float * sdci = sdc;
159
162
  float accu = 0;
160
- const uint8_t *codei = codes + i * code_size;
161
- const uint8_t *codej = codes + j * code_size;
163
+ PQDecoder codei (codes + i * code_size, pq.nbits);
164
+ PQDecoder codej (codes + j * code_size, pq.nbits);
162
165
 
163
166
  for (int l = 0; l < pq.M; l++) {
164
- accu += sdci[(*codei++) + (*codej++) * 256];
165
- sdci += 256 * 256;
167
+ accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
168
+ sdci += uint64_t(1) << (2 * codei.nbits);
166
169
  }
170
+ ndis++;
167
171
  return accu;
168
172
  }
169
173
 
170
- explicit PQDis(const IndexPQ& storage, const float* /*q*/ = nullptr)
171
- : pq(storage.pq) {
174
+ explicit PQDistanceComputer(const IndexPQ& storage)
175
+ : pq(storage.pq) {
172
176
  precomputed_table.resize(pq.M * pq.ksub);
173
177
  nb = storage.ntotal;
174
178
  d = storage.d;
179
+ metric = storage.metric_type;
175
180
  codes = storage.codes.data();
176
181
  code_size = pq.code_size;
177
- FAISS_ASSERT(pq.ksub == 256);
178
- FAISS_ASSERT(pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M);
179
- sdc = pq.sdc_table.data();
182
+ if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
183
+ sdc = pq.sdc_table.data();
184
+ } else {
185
+ sdc = nullptr;
186
+ }
180
187
  ndis = 0;
181
188
  }
182
189
 
183
190
  void set_query(const float *x) override {
184
- pq.compute_distance_table(x, precomputed_table.data());
191
+ if (metric == METRIC_L2) {
192
+ pq.compute_distance_table(x, precomputed_table.data());
193
+ } else {
194
+ pq.compute_inner_prod_table(x, precomputed_table.data());
195
+ }
196
+
185
197
  }
186
198
  };
187
199
 
@@ -190,8 +202,13 @@ struct PQDis: DistanceComputer {
190
202
 
191
203
 
192
204
  DistanceComputer * IndexPQ::get_distance_computer() const {
193
- FAISS_THROW_IF_NOT(pq.nbits == 8);
194
- return new PQDis(*this);
205
+ if (pq.nbits == 8) {
206
+ return new PQDistanceComputer<PQDecoder8>(*this);
207
+ } else if (pq.nbits == 16) {
208
+ return new PQDistanceComputer<PQDecoder16>(*this);
209
+ } else {
210
+ return new PQDistanceComputer<PQDecoderGeneric>(*this);
211
+ }
195
212
  }
196
213
 
197
214
 
@@ -329,8 +346,7 @@ static size_t polysemous_inner_loop (
329
346
  }
330
347
 
331
348
  if (dis < heap_dis[0]) {
332
- maxheap_pop (k, heap_dis, heap_ids);
333
- maxheap_push (k, heap_dis, heap_ids, dis, bi);
349
+ maxheap_replace_top (k, heap_dis, heap_ids, dis, bi);
334
350
  }
335
351
  }
336
352
  b_code += code_size;
@@ -0,0 +1,536 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #include <faiss/IndexPQFastScan.h>
10
+
11
+ #include <cassert>
12
+ #include <memory>
13
+ #include <limits.h>
14
+
15
+ #include <omp.h>
16
+
17
+
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/utils/utils.h>
20
+ #include <faiss/utils/random.h>
21
+
22
+ #include <faiss/impl/simd_result_handlers.h>
23
+ #include <faiss/utils/quantize_lut.h>
24
+ #include <faiss/impl/pq4_fast_scan.h>
25
+
26
+
27
+ namespace faiss {
28
+
29
+ using namespace simd_result_handlers;
30
+
31
// Round a up to the nearest multiple of b (a is returned unchanged when it already is one).
// b is used as a divisor, so it must be non-zero; a + b - 1 is not checked for overflow.
+ inline size_t roundup(size_t a, size_t b) {
32
+ return (a + b - 1) / b * b;
33
+ }
34
+
35
+ IndexPQFastScan::IndexPQFastScan(
36
+ int d, size_t M, size_t nbits,
37
+ MetricType metric,
38
+ int bbs):
39
+ Index(d, metric), pq(d, M, nbits),
40
+ bbs(bbs), ntotal2(0), M2(roundup(M, 2))
41
+ {
42
+ FAISS_THROW_IF_NOT(nbits == 4);
43
+ is_trained = false;
44
+ }
45
+
46
+ IndexPQFastScan::IndexPQFastScan(): // default constructor: only zeroes bbs, ntotal2 and M2
47
+ bbs(0), ntotal2(0), M2(0)
48
+ {}
49
+
50
+ IndexPQFastScan::IndexPQFastScan(const IndexPQ & orig, int bbs):
51
+ Index(orig.d, orig.metric_type),
52
+ pq(orig.pq),
53
+ bbs(bbs)
54
+ {
55
+ FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
56
+ ntotal = orig.ntotal;
57
+ is_trained = orig.is_trained;
58
+ orig_codes = orig.codes.data();
59
+
60
+ qbs = 0; // means use default
61
+
62
+ // pack the codes
63
+
64
+ size_t M = pq.M;
65
+
66
+ FAISS_THROW_IF_NOT(bbs % 32 == 0);
67
+ M2 = roundup(M, 2);
68
+ ntotal2 = roundup(ntotal, bbs);
69
+
70
+ codes.resize(ntotal2 * M2 / 2);
71
+
72
+ // printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
73
+ pq4_pack_codes(
74
+ orig.codes.data(),
75
+ ntotal, M,
76
+ ntotal2, bbs, M2,
77
+ codes.get()
78
+ );
79
+ }
80
+
81
+ void IndexPQFastScan::train (idx_t n, const float *x)
82
+ {
83
+ if (is_trained) {
84
+ return;
85
+ }
86
+ pq.train(n, x);
87
+ is_trained = true;
88
+ }
89
+
90
+
91
+ void IndexPQFastScan::add (idx_t n, const float *x) {
92
+ FAISS_THROW_IF_NOT (is_trained);
93
+ AlignedTable<uint8_t> tmp_codes(n * pq.code_size);
94
+ pq.compute_codes (x, tmp_codes.get(), n);
95
+ ntotal2 = roundup(ntotal + n, bbs);
96
+ size_t new_size = ntotal2 * M2 / 2;
97
+ size_t old_size = codes.size();
98
+ if (new_size > old_size) {
99
+ codes.resize(new_size);
100
+ memset(codes.get() + old_size, 0, new_size - old_size);
101
+ }
102
+ pq4_pack_codes_range(
103
+ tmp_codes.get(), pq.M, ntotal, ntotal + n,
104
+ bbs, M2, codes.get()
105
+ );
106
+ ntotal += n;
107
+ }
108
+
109
+ void IndexPQFastScan::reset()
110
+ {
111
+ codes.resize(0);
112
+ ntotal = 0;
113
+ }
114
+
115
+
116
+
117
+ namespace {
118
+
119
+ // from impl/ProductQuantizer.cpp
120
+ template <class C, typename dis_t>
121
+ void pq_estimators_from_tables_generic(
122
+ const ProductQuantizer& pq, size_t nbits,
123
+ const uint8_t *codes, size_t ncodes,
124
+ const dis_t *dis_table, size_t k,
125
+ typename C::T *heap_dis, int64_t *heap_ids)
126
+ {
127
+ using accu_t = typename C::T;
128
+ const size_t M = pq.M;
129
+ const size_t ksub = pq.ksub;
130
+ for (size_t j = 0; j < ncodes; ++j) {
131
+ PQDecoderGeneric decoder(
132
+ codes + j * pq.code_size, nbits
133
+ );
134
+ accu_t dis = 0;
135
+ const dis_t * __restrict dt = dis_table;
136
+ for (size_t m = 0; m < M; m++) {
137
+ uint64_t c = decoder.decode();
138
+ dis += dt[c];
139
+ dt += ksub;
140
+ }
141
+
142
+ if (C::cmp(heap_dis[0], dis)) {
143
+ heap_pop<C>(k, heap_dis, heap_ids);
144
+ heap_push<C>(k, heap_dis, heap_ids, dis, j);
145
+ }
146
+ }
147
+ }
148
+
149
+
150
+ } // anonymous namespace
151
+
152
+
153
+ using namespace quantize_lut;
154
+
155
+ void IndexPQFastScan::compute_quantized_LUT(
156
+ idx_t n, const float* x,
157
+ uint8_t *lut, float *normalizers) const
158
+ {
159
+ size_t dim12 = pq.ksub * pq.M;
160
+ std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
161
+ if (metric_type == METRIC_L2) {
162
+ pq.compute_distance_tables (n, x, dis_tables.get());
163
+ } else {
164
+ pq.compute_inner_prod_tables (n, x, dis_tables.get());
165
+ }
166
+
167
+ for(uint64_t i = 0; i < n; i++) {
168
+ round_uint8_per_column(
169
+ dis_tables.get() + i * dim12, pq.M, pq.ksub,
170
+ &normalizers[2 * i], &normalizers[2 * i + 1]
171
+ );
172
+ }
173
+
174
+ for(uint64_t i = 0; i < n; i++) {
175
+ const float *t_in = dis_tables.get() + i * dim12;
176
+ uint8_t *t_out = lut + i * M2 * pq.ksub;
177
+
178
+ for(int j = 0; j < dim12; j++) {
179
+ t_out[j] = int(t_in[j]);
180
+ }
181
+ memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub);
182
+ }
183
+ }
184
+
185
+
186
+
187
+ /******************************************************************************
188
+ * Search driver routine
189
+ ******************************************************************************/
190
+
191
+
192
+ void IndexPQFastScan::search(
193
+ idx_t n, const float* x, idx_t k,
194
+ float* distances, idx_t* labels) const
195
+ {
196
+ if (metric_type == METRIC_L2) {
197
+ search_dispatch_implem<true>(n, x, k, distances, labels);
198
+ } else {
199
+ search_dispatch_implem<false>(n, x, k, distances, labels);
200
+ }
201
+ }
202
+
203
+
204
+ template<bool is_max>
205
+ void IndexPQFastScan::search_dispatch_implem(
206
+ idx_t n,
207
+ const float* x,
208
+ idx_t k,
209
+ float* distances,
210
+ idx_t* labels) const
211
+ {
212
+ using Cfloat = typename std::conditional<is_max,
213
+ CMax<float, int64_t>, CMin<float, int64_t> >::type;
214
+
215
+ using C = typename std::conditional<is_max,
216
+ CMax<uint16_t, int>, CMin<uint16_t, int> >::type;
217
+
218
+ if (n == 0) {
219
+ return;
220
+ }
221
+
222
+ // actual implementation used
223
+ int impl = implem;
224
+
225
+ if (impl == 0) {
226
+ if (bbs == 32) {
227
+ impl = 12;
228
+ } else {
229
+ impl = 14;
230
+ }
231
+ if (k > 20) {
232
+ impl ++;
233
+ }
234
+ }
235
+
236
+
237
+ if (implem == 1) {
238
+ FAISS_THROW_IF_NOT(orig_codes);
239
+ FAISS_THROW_IF_NOT(is_max);
240
+ float_maxheap_array_t res = {
241
+ size_t(n), size_t(k), labels, distances };
242
+ pq.search (x, n, orig_codes, ntotal, &res, true);
243
+ } else if (implem == 2 || implem == 3 || implem == 4) {
244
+ FAISS_THROW_IF_NOT(orig_codes);
245
+
246
+ size_t dim12 = pq.ksub * pq.M;
247
+ std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
248
+ if (is_max) {
249
+ pq.compute_distance_tables (n, x, dis_tables.get());
250
+ } else {
251
+ pq.compute_inner_prod_tables (n, x, dis_tables.get());
252
+ }
253
+
254
+ std::vector<float> normalizers(n * 2);
255
+
256
+ if (implem == 2) {
257
+ // default float
258
+ } else if (implem == 3 || implem == 4) {
259
+ for(uint64_t i = 0; i < n; i++) {
260
+ round_uint8_per_column(
261
+ dis_tables.get() + i * dim12, pq.M,
262
+ pq.ksub,
263
+ &normalizers[2 * i], &normalizers[2 * i + 1]
264
+ );
265
+ }
266
+ }
267
+
268
+ for (int64_t i = 0; i < n; i++) {
269
+ int64_t *heap_ids = labels + i * k;
270
+ float *heap_dis = distances + i * k;
271
+
272
+ heap_heapify<Cfloat> (k, heap_dis, heap_ids);
273
+
274
+ pq_estimators_from_tables_generic<Cfloat>(
275
+ pq, pq.nbits, orig_codes, ntotal,
276
+ dis_tables.get() + i * dim12,
277
+ k, heap_dis, heap_ids
278
+ );
279
+
280
+ heap_reorder<Cfloat> (k, heap_dis, heap_ids);
281
+
282
+ if (implem == 4) {
283
+ float a = normalizers[2 * i];
284
+ float b = normalizers[2 * i + 1];
285
+
286
+ for(int j = 0; j < k; j++) {
287
+ heap_dis[j] = heap_dis[j] / a + b;
288
+ }
289
+ }
290
+ }
291
+ } else if (impl >= 12 && impl <= 15) {
292
+ FAISS_THROW_IF_NOT(ntotal < INT_MAX);
293
+ int nt = std::min(omp_get_max_threads(), int(n));
294
+ if (nt < 2) {
295
+ if (impl == 12 || impl == 13) {
296
+ search_implem_12<C>(n, x, k, distances, labels, impl);
297
+ } else {
298
+ search_implem_14<C>(n, x, k, distances, labels, impl);
299
+ }
300
+ } else {
301
+ // explicitly slice over threads
302
+ #pragma omp parallel for num_threads(nt)
303
+ for (int slice = 0; slice < nt; slice++) {
304
+ idx_t i0 = n * slice / nt;
305
+ idx_t i1 = n * (slice + 1) / nt;
306
+ float *dis_i = distances + i0 * k;
307
+ idx_t *lab_i = labels + i0 * k;
308
+ if (impl == 12 || impl == 13) {
309
+ search_implem_12<C>(
310
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
311
+ } else {
312
+ search_implem_14<C>(
313
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
314
+ }
315
+ }
316
+ }
317
+ } else {
318
+ FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
319
+ }
320
+
321
+ }
322
+
323
+ template<class C>
324
+ void IndexPQFastScan::search_implem_12(
325
+ idx_t n, const float* x, idx_t k,
326
+ float* distances, idx_t* labels,
327
+ int impl) const
328
+ {
329
+
330
+ FAISS_THROW_IF_NOT(bbs == 32);
331
+
332
+ // handle qbs2 blocking by recursive call
333
+ int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs);
334
+ if (n > qbs2) {
335
+ for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
336
+ int64_t i1 = std::min(i0 + qbs2, n);
337
+ search_implem_12<C>(
338
+ i1 - i0, x + d * i0, k,
339
+ distances + i0 * k, labels + i0 * k, impl
340
+ );
341
+ }
342
+ return;
343
+ }
344
+
345
+ size_t dim12 = pq.ksub * M2;
346
+ AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
347
+ std::unique_ptr<float []> normalizers(new float[2 * n]);
348
+
349
+ if (skip & 1) {
350
+ quantized_dis_tables.clear();
351
+ } else {
352
+ compute_quantized_LUT(
353
+ n, x, quantized_dis_tables.get(), normalizers.get()
354
+ );
355
+ }
356
+
357
+ AlignedTable<uint8_t> LUT(n * dim12);
358
+
359
+ // block sizes are encoded in qbs, 4 bits at a time
360
+
361
+ // caution: we override an object field
362
+ int qbs = this->qbs;
363
+
364
+ if (n != pq4_qbs_to_nq(qbs)) {
365
+ qbs = pq4_preferred_qbs(n);
366
+ }
367
+
368
+ int LUT_nq = pq4_pack_LUT_qbs(
369
+ qbs, M2, quantized_dis_tables.get(), LUT.get()
370
+ );
371
+ FAISS_THROW_IF_NOT(LUT_nq == n);
372
+
373
+ if (k == 1) {
374
+ SingleResultHandler<C> handler(n, ntotal);
375
+ if (skip & 4) {
376
+ // pass
377
+ } else {
378
+ handler.disable = bool(skip & 2);
379
+ pq4_accumulate_loop_qbs(
380
+ qbs, ntotal2, M2,
381
+ codes.get(), LUT.get(),
382
+ handler
383
+ );
384
+ }
385
+
386
+ handler.to_flat_arrays(distances, labels, normalizers.get());
387
+
388
+ } else if (impl == 12) {
389
+
390
+ std::vector<uint16_t> tmp_dis(n * k);
391
+ std::vector<int32_t> tmp_ids(n * k);
392
+
393
+ if (skip & 4) {
394
+ // skip
395
+ } else {
396
+ HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
397
+ handler.disable = bool(skip & 2);
398
+
399
+ pq4_accumulate_loop_qbs(
400
+ qbs, ntotal2, M2,
401
+ codes.get(), LUT.get(),
402
+ handler
403
+ );
404
+
405
+ if (!(skip & 8)) {
406
+ handler.to_flat_arrays(distances, labels, normalizers.get());
407
+ }
408
+ }
409
+
410
+
411
+ } else { // impl == 13
412
+
413
+ ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
414
+ handler.disable = bool(skip & 2);
415
+
416
+ if (skip & 4) {
417
+ // skip
418
+ } else {
419
+ pq4_accumulate_loop_qbs(
420
+ qbs, ntotal2, M2,
421
+ codes.get(), LUT.get(),
422
+ handler
423
+ );
424
+ }
425
+
426
+ if (!(skip & 8)) {
427
+ handler.to_flat_arrays(distances, labels, normalizers.get());
428
+ }
429
+
430
+ FastScan_stats.t0 += handler.times[0];
431
+ FastScan_stats.t1 += handler.times[1];
432
+ FastScan_stats.t2 += handler.times[2];
433
+ FastScan_stats.t3 += handler.times[3];
434
+
435
+ }
436
+ }
437
+
438
+ FastScanStats FastScan_stats; // definition of the global stats object that the impl == 13 branch of search_implem_12 adds its handler timings to
439
+
440
+ template<class C>
441
+ void IndexPQFastScan::search_implem_14(
442
+ idx_t n, const float* x, idx_t k,
443
+ float* distances, idx_t* labels, int impl) const
444
+ {
445
+
446
+ FAISS_THROW_IF_NOT(bbs % 32 == 0);
447
+
448
+ int qbs2 = qbs == 0 ? 4 : qbs;
449
+
450
+ // handle qbs2 blocking by recursive call
451
+ if (n > qbs2) {
452
+ for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
453
+ int64_t i1 = std::min(i0 + qbs2, n);
454
+ search_implem_14<C>(
455
+ i1 - i0, x + d * i0, k,
456
+ distances + i0 * k, labels + i0 * k, impl
457
+ );
458
+ }
459
+ return;
460
+ }
461
+
462
+ size_t dim12 = pq.ksub * M2;
463
+ AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
464
+ std::unique_ptr<float []> normalizers(new float[2 * n]);
465
+
466
+ if (skip & 1) {
467
+ quantized_dis_tables.clear();
468
+ } else {
469
+ compute_quantized_LUT(
470
+ n, x, quantized_dis_tables.get(), normalizers.get()
471
+ );
472
+ }
473
+
474
+ AlignedTable<uint8_t> LUT(n * dim12);
475
+ pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
476
+
477
+ if (k == 1) {
478
+ SingleResultHandler<C> handler(n, ntotal);
479
+ if (skip & 4) {
480
+ // pass
481
+ } else {
482
+ handler.disable = bool(skip & 2);
483
+ pq4_accumulate_loop (
484
+ n, ntotal2, bbs, M2,
485
+ codes.get(), LUT.get(),
486
+ handler
487
+ );
488
+ }
489
+ handler.to_flat_arrays(distances, labels, normalizers.get());
490
+
491
+ } else if (impl == 14) {
492
+
493
+ std::vector<uint16_t> tmp_dis(n * k);
494
+ std::vector<int32_t> tmp_ids(n * k);
495
+
496
+ if (skip & 4) {
497
+ // skip
498
+ } else if (k > 1) {
499
+ HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
500
+ handler.disable = bool(skip & 2);
501
+
502
+ pq4_accumulate_loop (
503
+ n, ntotal2, bbs, M2,
504
+ codes.get(), LUT.get(),
505
+ handler
506
+ );
507
+
508
+ if (!(skip & 8)) {
509
+ handler.to_flat_arrays(distances, labels, normalizers.get());
510
+ }
511
+ }
512
+
513
+
514
+ } else { // impl == 15
515
+
516
+ ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
517
+ handler.disable = bool(skip & 2);
518
+
519
+ if (skip & 4) {
520
+ // skip
521
+ } else {
522
+ pq4_accumulate_loop (
523
+ n, ntotal2, bbs, M2,
524
+ codes.get(), LUT.get(),
525
+ handler
526
+ );
527
+ }
528
+
529
+ if (!(skip & 8)) {
530
+ handler.to_flat_arrays(distances, labels, normalizers.get());
531
+ }
532
+ }
533
+ }
534
+
535
+
536
+ } // namespace faiss