faiss 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -0,0 +1,166 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <memory>
11
+
12
+ #include <faiss/IndexIVFPQ.h>
13
+ #include <faiss/impl/ProductQuantizer.h>
14
+ #include <faiss/utils/AlignedTable.h>
15
+
16
+ namespace faiss {
17
+
18
+
19
+ /** Fast scan version of IVFPQ. Works for 4-bit PQ for now.
20
+ *
21
+ * The codes in the inverted lists are not stored sequentially but
22
+ * grouped in blocks of size bbs. This makes it possible to very quickly
23
+ * compute distances with SIMD instructions.
24
+ *
25
+ * Implementations (implem):
26
+ * 0: auto-select implementation (default)
27
+ * 1: orig's search, re-implemented
28
+ * 2: orig's search, re-ordered by invlist
29
+ * 10: optimized int16 search, collect results in heap, no qbs
30
+ * 11: idem, collect results in reservoir
31
+ * 12: optimized int16 search, collect results in heap, uses qbs
32
+ * 13: idem, collect results in reservoir
33
+ */
34
+
35
+ struct IndexIVFPQFastScan: IndexIVF {
36
+
37
+ bool by_residual; ///< Encode residual or plain vector?
38
+ ProductQuantizer pq; ///< produces the codes
39
+
40
+ // block size of the kernel: codes are grouped in blocks of bbs vectors
41
+ int bbs; // set at build time
42
+
43
+ // M rounded up to a multiple of 2
44
+ size_t M2;
45
+
46
+ /// precomputed tables management
47
+ int use_precomputed_table = 0;
48
+ /// if use_precomputed_table is set: table of size (nlist, pq.M, pq.ksub)
49
+ AlignedTable<float> precomputed_table;
50
+
51
+ // search-time implementation (0 = auto-select)
52
+ int implem = 0;
53
+ // skip some parts of the computation (for timing)
54
+ int skip = 0;
55
+
56
+ // batching factors at search time (0 = default)
57
+ int qbs = 0;
58
+ size_t qbs2 = 0;
59
+
60
+ IndexIVFPQFastScan (
61
+ Index * quantizer, size_t d, size_t nlist,
62
+ size_t M, size_t nbits_per_idx,
63
+ MetricType metric = METRIC_L2, int bbs = 32);
64
+
65
+ IndexIVFPQFastScan ();
66
+
67
+ // built from an IndexIVFPQ
68
+ explicit IndexIVFPQFastScan(const IndexIVFPQ & orig, int bbs = 32);
69
+
70
+ /// orig's inverted lists (for debugging)
71
+ InvertedLists * orig_invlists = nullptr;
72
+
73
+ void train_residual (idx_t n, const float *x) override;
74
+
75
+ /// build precomputed table, possibly updating use_precomputed_table
76
+ void precompute_table ();
77
+
78
+ /// same as the regular IVFPQ encoder. The codes are not reorganized by
79
+ /// blocks at that point
80
+ void encode_vectors(
81
+ idx_t n, const float* x,
82
+ const idx_t *list_nos, uint8_t * codes,
83
+ bool include_listno = false) const override;
84
+
85
+ void add_with_ids (
86
+ idx_t n, const float * x, const idx_t *xids) override;
87
+
88
+ void search(
89
+ idx_t n, const float* x, idx_t k,
90
+ float* distances, idx_t* labels) const override;
91
+
92
+ // prepare look-up tables
93
+ // (one variant fills float tables and biases, the other uint8 tables, uint16 biases and float normalizers)
94
+ void compute_LUT(
95
+ size_t n, const float *x,
96
+ const idx_t *coarse_ids, const float *coarse_dis,
97
+ AlignedTable<float> & dis_tables,
98
+ AlignedTable<float> & biases
99
+ ) const;
100
+
101
+ void compute_LUT_uint8(
102
+ size_t n, const float *x,
103
+ const idx_t *coarse_ids, const float *coarse_dis,
104
+ AlignedTable<uint8_t> & dis_tables,
105
+ AlignedTable<uint16_t> & biases,
106
+ float * normalizers
107
+ ) const;
108
+
109
+ // internal search funcs
110
+ // (the numeric suffix presumably matches the implem values listed in the struct comment; TODO confirm)
111
+ template<bool is_max>
112
+ void search_dispatch_implem(
113
+ idx_t n, const float* x, idx_t k,
114
+ float* distances, idx_t* labels) const;
115
+
116
+ template<class C>
117
+ void search_implem_1(
118
+ idx_t n, const float* x, idx_t k,
119
+ float* distances, idx_t* labels) const;
120
+
121
+ template<class C>
122
+ void search_implem_2(
123
+ idx_t n, const float* x, idx_t k,
124
+ float* distances, idx_t* labels) const;
125
+
126
+ // implem 10 and 12 are not multithreaded internally, so
127
+ // export search stats (through the ndis_out / nlist_out output pointers)
128
+ template<class C>
129
+ void search_implem_10(
130
+ idx_t n, const float* x, idx_t k,
131
+ float* distances, idx_t* labels,
132
+ int impl, size_t *ndis_out, size_t *nlist_out) const;
133
+
134
+ template<class C>
135
+ void search_implem_12(
136
+ idx_t n, const float* x, idx_t k,
137
+ float* distances, idx_t* labels,
138
+ int impl, size_t *ndis_out, size_t *nlist_out) const;
139
+
140
+
141
+
142
+ };
143
+ /// Counters for the IVF fast-scan search code (presumably CPU cycles, given the Mcy_* accessors; TODO confirm).
144
+ struct IVFFastScanStats {
145
+ uint64_t times[10];
146
+ uint64_t t_compute_distance_tables, t_round;
147
+ uint64_t t_copy_pack, t_scan, t_to_flat;
148
+ uint64_t reservoir_times[4];
149
+
150
+ double Mcy_at(int i) { // times[i] divided by 1e6; i is not range-checked
151
+ return times[i] / (1000*1000.0);
152
+ }
153
+
154
+ double Mcy_reservoir_at(int i) { // reservoir_times[i] divided by 1e6; i is not range-checked
155
+ return reservoir_times[i] / (1000*1000.0);
156
+ }
157
+ IVFFastScanStats() {reset();} // all counters start at zero
158
+ void reset() { // NOTE(review): memset needs <cstring>/<string.h>, which this header does not include directly
159
+ memset(this, 0, sizeof(*this));
160
+ }
161
+ };
162
+
163
+ FAISS_API extern IVFFastScanStats IVFFastScan_stats;
164
+
165
+
166
+ } // namespace faiss
@@ -97,13 +97,13 @@ void IndexIVFPQR::add_core (idx_t n, const float *x, const idx_t *xids,
97
97
  #define TOC get_cycles () - t0
98
98
 
99
99
 
100
- void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
101
- const idx_t *idx,
102
- const float *L1_dis,
103
- float *distances, idx_t *labels,
104
- bool store_pairs,
105
- const IVFSearchParameters *params
106
- ) const
100
+ void IndexIVFPQR::search_preassigned (
101
+ idx_t n, const float *x, idx_t k,
102
+ const idx_t *idx, const float *L1_dis,
103
+ float *distances, idx_t *labels,
104
+ bool store_pairs,
105
+ const IVFSearchParameters *params, IndexIVFStats *stats
106
+ ) const
107
107
  {
108
108
  uint64_t t0;
109
109
  TIC;
@@ -172,9 +172,8 @@ void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
172
172
  float dis = fvec_L2sqr (residual_1, residual_2, d);
173
173
 
174
174
  if (dis < heap_sim[0]) {
175
- maxheap_pop (k, heap_sim, heap_ids);
176
175
  idx_t id_or_pair = store_pairs ? sl : id;
177
- maxheap_push (k, heap_sim, heap_ids, dis, id_or_pair);
176
+ maxheap_replace_top (k, heap_sim, heap_ids, dis, id_or_pair);
178
177
  }
179
178
  n_refine ++;
180
179
  }
@@ -55,7 +55,8 @@ struct IndexIVFPQR: IndexIVFPQ {
55
55
  const float *centroid_dis,
56
56
  float *distances, idx_t *labels,
57
57
  bool store_pairs,
58
- const IVFSearchParameters *params=nullptr
58
+ const IVFSearchParameters *params=nullptr,
59
+ IndexIVFStats *stats=nullptr
59
60
  ) const override;
60
61
 
61
62
  IndexIVFPQR();
@@ -269,9 +269,8 @@ struct IVFScanner: InvertedListScanner {
269
269
  float dis = hc.hamming (codes);
270
270
 
271
271
  if (dis < simi [0]) {
272
- maxheap_pop (k, simi, idxi);
273
272
  int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
274
- maxheap_push (k, simi, idxi, dis, id);
273
+ maxheap_replace_top (k, simi, idxi, dis, id);
275
274
  nup++;
276
275
  }
277
276
  codes += code_size;
@@ -129,9 +129,10 @@ void IndexPQ::reconstruct (idx_t key, float * recons) const
129
129
 
130
130
  namespace {
131
131
 
132
-
133
- struct PQDis: DistanceComputer {
132
+ template<class PQDecoder>
133
+ struct PQDistanceComputer: DistanceComputer {
134
134
  size_t d;
135
+ MetricType metric;
135
136
  Index::idx_t nb;
136
137
  const uint8_t *codes;
137
138
  size_t code_size;
@@ -144,10 +145,11 @@ struct PQDis: DistanceComputer {
144
145
  {
145
146
  const uint8_t *code = codes + i * code_size;
146
147
  const float *dt = precomputed_table.data();
148
+ PQDecoder decoder(code, pq.nbits);
147
149
  float accu = 0;
148
150
  for (int j = 0; j < pq.M; j++) {
149
- accu += dt[*code++];
150
- dt += 256;
151
+ accu += dt[decoder.decode()];
152
+ dt += 1 << decoder.nbits;
151
153
  }
152
154
  ndis++;
153
155
  return accu;
@@ -155,33 +157,43 @@ struct PQDis: DistanceComputer {
155
157
 
156
158
  float symmetric_dis(idx_t i, idx_t j) override
157
159
  {
160
+ FAISS_THROW_IF_NOT(sdc);
158
161
  const float * sdci = sdc;
159
162
  float accu = 0;
160
- const uint8_t *codei = codes + i * code_size;
161
- const uint8_t *codej = codes + j * code_size;
163
+ PQDecoder codei (codes + i * code_size, pq.nbits);
164
+ PQDecoder codej (codes + j * code_size, pq.nbits);
162
165
 
163
166
  for (int l = 0; l < pq.M; l++) {
164
- accu += sdci[(*codei++) + (*codej++) * 256];
165
- sdci += 256 * 256;
167
+ accu += sdci[codei.decode() + (codej.decode() << codei.nbits)];
168
+ sdci += uint64_t(1) << (2 * codei.nbits);
166
169
  }
170
+ ndis++;
167
171
  return accu;
168
172
  }
169
173
 
170
- explicit PQDis(const IndexPQ& storage, const float* /*q*/ = nullptr)
171
- : pq(storage.pq) {
174
+ explicit PQDistanceComputer(const IndexPQ& storage)
175
+ : pq(storage.pq) {
172
176
  precomputed_table.resize(pq.M * pq.ksub);
173
177
  nb = storage.ntotal;
174
178
  d = storage.d;
179
+ metric = storage.metric_type;
175
180
  codes = storage.codes.data();
176
181
  code_size = pq.code_size;
177
- FAISS_ASSERT(pq.ksub == 256);
178
- FAISS_ASSERT(pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M);
179
- sdc = pq.sdc_table.data();
182
+ if (pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M) {
183
+ sdc = pq.sdc_table.data();
184
+ } else {
185
+ sdc = nullptr;
186
+ }
180
187
  ndis = 0;
181
188
  }
182
189
 
183
190
  void set_query(const float *x) override {
184
- pq.compute_distance_table(x, precomputed_table.data());
191
+ if (metric == METRIC_L2) {
192
+ pq.compute_distance_table(x, precomputed_table.data());
193
+ } else {
194
+ pq.compute_inner_prod_table(x, precomputed_table.data());
195
+ }
196
+
185
197
  }
186
198
  };
187
199
 
@@ -190,8 +202,13 @@ struct PQDis: DistanceComputer {
190
202
 
191
203
 
192
204
  DistanceComputer * IndexPQ::get_distance_computer() const {
193
- FAISS_THROW_IF_NOT(pq.nbits == 8);
194
- return new PQDis(*this);
205
+ if (pq.nbits == 8) {
206
+ return new PQDistanceComputer<PQDecoder8>(*this);
207
+ } else if (pq.nbits == 16) {
208
+ return new PQDistanceComputer<PQDecoder16>(*this);
209
+ } else {
210
+ return new PQDistanceComputer<PQDecoderGeneric>(*this);
211
+ }
195
212
  }
196
213
 
197
214
 
@@ -329,8 +346,7 @@ static size_t polysemous_inner_loop (
329
346
  }
330
347
 
331
348
  if (dis < heap_dis[0]) {
332
- maxheap_pop (k, heap_dis, heap_ids);
333
- maxheap_push (k, heap_dis, heap_ids, dis, bi);
349
+ maxheap_replace_top (k, heap_dis, heap_ids, dis, bi);
334
350
  }
335
351
  }
336
352
  b_code += code_size;
@@ -0,0 +1,536 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ #include <faiss/IndexPQFastScan.h>
10
+
11
+ #include <cassert>
12
+ #include <memory>
13
+ #include <limits.h>
14
+
15
+ #include <omp.h>
16
+
17
+
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/utils/utils.h>
20
+ #include <faiss/utils/random.h>
21
+
22
+ #include <faiss/impl/simd_result_handlers.h>
23
+ #include <faiss/utils/quantize_lut.h>
24
+ #include <faiss/impl/pq4_fast_scan.h>
25
+
26
+
27
+ namespace faiss {
28
+
29
+ using namespace simd_result_handlers;
30
+ /// Round a up to the nearest multiple of b; b must be non-zero because the expression divides by b.
31
+ inline size_t roundup(size_t a, size_t b) {
32
+ return (a + b - 1) / b * b;
33
+ }
34
+ /// Create an empty, untrained index; only nbits == 4 passes the FAISS_THROW_IF_NOT check.
35
+ IndexPQFastScan::IndexPQFastScan(
36
+ int d, size_t M, size_t nbits,
37
+ MetricType metric,
38
+ int bbs):
39
+ Index(d, metric), pq(d, M, nbits),
40
+ bbs(bbs), ntotal2(0), M2(roundup(M, 2)) // M2 = M rounded up to an even number
41
+ {
42
+ FAISS_THROW_IF_NOT(nbits == 4);
43
+ is_trained = false;
44
+ }
45
+ /// Default constructor: bbs, ntotal2 and M2 start at 0; nothing else is set here.
46
+ IndexPQFastScan::IndexPQFastScan():
47
+ bbs(0), ntotal2(0), M2(0)
48
+ {}
49
+ /// Build a fast-scan index from an existing IndexPQ (orig.pq.nbits must be 4, bbs a multiple of 32) by re-packing its codes.
50
+ IndexPQFastScan::IndexPQFastScan(const IndexPQ & orig, int bbs):
51
+ Index(orig.d, orig.metric_type),
52
+ pq(orig.pq),
53
+ bbs(bbs)
54
+ {
55
+ FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
56
+ ntotal = orig.ntotal;
57
+ is_trained = orig.is_trained;
58
+ orig_codes = orig.codes.data(); // raw pointer into orig's code storage, not a copy
59
+
60
+ qbs = 0; // means use default
61
+
62
+ // pack the codes
63
+
64
+ size_t M = pq.M;
65
+
66
+ FAISS_THROW_IF_NOT(bbs % 32 == 0);
67
+ M2 = roundup(M, 2);
68
+ ntotal2 = roundup(ntotal, bbs);
69
+ // ntotal2 * M2 codes of 4 bits each; presumably two per byte, hence the division by 2
70
+ codes.resize(ntotal2 * M2 / 2);
71
+
72
+ // printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
73
+ pq4_pack_codes(
74
+ orig.codes.data(),
75
+ ntotal, M,
76
+ ntotal2, bbs, M2,
77
+ codes.get()
78
+ );
79
+ }
80
+ /// Train the product quantizer on the n vectors in x; does nothing if the index is already trained.
81
+ void IndexPQFastScan::train (idx_t n, const float *x)
82
+ {
83
+ if (is_trained) {
84
+ return;
85
+ }
86
+ pq.train(n, x);
87
+ is_trained = true;
88
+ }
89
+
90
+ /// Encode n vectors with pq and pack them into the code storage after the existing ntotal entries (index must be trained).
91
+ void IndexPQFastScan::add (idx_t n, const float *x) {
92
+ FAISS_THROW_IF_NOT (is_trained);
93
+ AlignedTable<uint8_t> tmp_codes(n * pq.code_size); // temporary buffer for the pq.compute_codes output
94
+ pq.compute_codes (x, tmp_codes.get(), n);
95
+ ntotal2 = roundup(ntotal + n, bbs); // new vector count rounded up to a multiple of bbs
96
+ size_t new_size = ntotal2 * M2 / 2;
97
+ size_t old_size = codes.size();
98
+ if (new_size > old_size) {
99
+ codes.resize(new_size);
100
+ memset(codes.get() + old_size, 0, new_size - old_size); // zero the newly added tail
101
+ }
102
+ pq4_pack_codes_range(
103
+ tmp_codes.get(), pq.M, ntotal, ntotal + n,
104
+ bbs, M2, codes.get()
105
+ );
106
+ ntotal += n;
107
+ }
108
+ /// Drop all stored codes. NOTE(review): ntotal2 is not reset here; add() recomputes it, but the search_implem_* routines pass ntotal2 to the scan kernels, so confirm a search right after reset() is safe.
109
+ void IndexPQFastScan::reset()
110
+ {
111
+ codes.resize(0);
112
+ ntotal = 0;
113
+ }
114
+
115
+
116
+
117
+ namespace {
118
+ // Scan ncodes PQ codes: for each code, sum one dis_table entry per sub-quantizer and keep the result in the k-entry heap if C::cmp accepts it.
119
+ // from impl/ProductQuantizer.cpp
120
+ template <class C, typename dis_t>
121
+ void pq_estimators_from_tables_generic(
122
+ const ProductQuantizer& pq, size_t nbits,
123
+ const uint8_t *codes, size_t ncodes,
124
+ const dis_t *dis_table, size_t k,
125
+ typename C::T *heap_dis, int64_t *heap_ids)
126
+ {
127
+ using accu_t = typename C::T;
128
+ const size_t M = pq.M;
129
+ const size_t ksub = pq.ksub;
130
+ for (size_t j = 0; j < ncodes; ++j) {
131
+ PQDecoderGeneric decoder(
132
+ codes + j * pq.code_size, nbits
133
+ );
134
+ accu_t dis = 0;
135
+ const dis_t * __restrict dt = dis_table;
136
+ for (size_t m = 0; m < M; m++) {
137
+ uint64_t c = decoder.decode();
138
+ dis += dt[c];
139
+ dt += ksub; // advance to the next sub-quantizer's ksub-entry table
140
+ }
141
+
142
+ if (C::cmp(heap_dis[0], dis)) { // compare against the heap top; on success replace it, storing j as the id
143
+ heap_pop<C>(k, heap_dis, heap_ids);
144
+ heap_push<C>(k, heap_dis, heap_ids, dis, j);
145
+ }
146
+ }
147
+ }
148
+
149
+
150
+ } // anonymous namespace
151
+
152
+
153
+ using namespace quantize_lut;
154
+ /// Compute float distance (L2) or inner-product tables for n queries, round them, and write them as uint8 into lut; two normalizers are stored per query.
155
+ void IndexPQFastScan::compute_quantized_LUT(
156
+ idx_t n, const float* x,
157
+ uint8_t *lut, float *normalizers) const
158
+ {
159
+ size_t dim12 = pq.ksub * pq.M; // float table entries per query
160
+ std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
161
+ if (metric_type == METRIC_L2) {
162
+ pq.compute_distance_tables (n, x, dis_tables.get());
163
+ } else {
164
+ pq.compute_inner_prod_tables (n, x, dis_tables.get());
165
+ }
166
+ // presumably rescales each query's table in place to fit uint8 and returns the scale/offset in normalizers; TODO confirm
167
+ for(uint64_t i = 0; i < n; i++) {
168
+ round_uint8_per_column(
169
+ dis_tables.get() + i * dim12, pq.M, pq.ksub,
170
+ &normalizers[2 * i], &normalizers[2 * i + 1]
171
+ );
172
+ }
173
+ // copy to the output with a per-query stride of M2 * ksub bytes
174
+ for(uint64_t i = 0; i < n; i++) {
175
+ const float *t_in = dis_tables.get() + i * dim12;
176
+ uint8_t *t_out = lut + i * M2 * pq.ksub;
177
+
178
+ for(int j = 0; j < dim12; j++) { // NOTE(review): int j is compared against size_t dim12
179
+ t_out[j] = int(t_in[j]);
180
+ }
181
+ memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub); // zero the tables of the M2 - M padding sub-quantizers
182
+ }
183
+ }
184
+
185
+
186
+
187
+ /******************************************************************************
188
+ * Search driver routine
189
+ ******************************************************************************/
190
+
191
+ /// Search entry point: METRIC_L2 uses the is_max=true instantiation, any other metric the is_max=false one.
192
+ void IndexPQFastScan::search(
193
+ idx_t n, const float* x, idx_t k,
194
+ float* distances, idx_t* labels) const
195
+ {
196
+ if (metric_type == METRIC_L2) {
197
+ search_dispatch_implem<true>(n, x, k, distances, labels);
198
+ } else {
199
+ search_dispatch_implem<false>(n, x, k, distances, labels);
200
+ }
201
+ }
202
+
203
+ /// Select the search implementation (from implem, or automatically when implem == 0) and run it for n queries.
204
+ template<bool is_max>
205
+ void IndexPQFastScan::search_dispatch_implem(
206
+ idx_t n,
207
+ const float* x,
208
+ idx_t k,
209
+ float* distances,
210
+ idx_t* labels) const
211
+ {
212
+ using Cfloat = typename std::conditional<is_max,
213
+ CMax<float, int64_t>, CMin<float, int64_t> >::type;
214
+ // Cfloat is used by the float-table path (implem 2-4), C (uint16 values, int ids) by search_implem_12 / 14
215
+ using C = typename std::conditional<is_max,
216
+ CMax<uint16_t, int>, CMin<uint16_t, int> >::type;
217
+
218
+ if (n == 0) {
219
+ return;
220
+ }
221
+
222
+ // actual implementation used
223
+ int impl = implem;
224
+ // auto-select: 12 when bbs == 32, else 14; add 1 (giving 13 / 15) when k > 20
225
+ if (impl == 0) {
226
+ if (bbs == 32) {
227
+ impl = 12;
228
+ } else {
229
+ impl = 14;
230
+ }
231
+ if (k > 20) {
232
+ impl ++;
233
+ }
234
+ }
235
+
236
+ // implem 1: delegate to pq.search on the original (unpacked) codes; only allowed when is_max
237
+ if (implem == 1) {
238
+ FAISS_THROW_IF_NOT(orig_codes);
239
+ FAISS_THROW_IF_NOT(is_max);
240
+ float_maxheap_array_t res = {
241
+ size_t(n), size_t(k), labels, distances };
242
+ pq.search (x, n, orig_codes, ntotal, &res, true);
243
+ } else if (implem == 2 || implem == 3 || implem == 4) {
244
+ FAISS_THROW_IF_NOT(orig_codes);
245
+ // implem 2-4: float look-up tables scanned over the original codes
246
+ size_t dim12 = pq.ksub * pq.M;
247
+ std::unique_ptr<float[]> dis_tables(new float [n * dim12]);
248
+ if (is_max) {
249
+ pq.compute_distance_tables (n, x, dis_tables.get());
250
+ } else {
251
+ pq.compute_inner_prod_tables (n, x, dis_tables.get());
252
+ }
253
+
254
+ std::vector<float> normalizers(n * 2);
255
+
256
+ if (implem == 2) {
257
+ // default float
258
+ } else if (implem == 3 || implem == 4) {
259
+ for(uint64_t i = 0; i < n; i++) {
260
+ round_uint8_per_column(
261
+ dis_tables.get() + i * dim12, pq.M,
262
+ pq.ksub,
263
+ &normalizers[2 * i], &normalizers[2 * i + 1]
264
+ );
265
+ }
266
+ }
267
+
268
+ for (int64_t i = 0; i < n; i++) {
269
+ int64_t *heap_ids = labels + i * k;
270
+ float *heap_dis = distances + i * k;
271
+
272
+ heap_heapify<Cfloat> (k, heap_dis, heap_ids);
273
+
274
+ pq_estimators_from_tables_generic<Cfloat>(
275
+ pq, pq.nbits, orig_codes, ntotal,
276
+ dis_tables.get() + i * dim12,
277
+ k, heap_dis, heap_ids
278
+ );
279
+
280
+ heap_reorder<Cfloat> (k, heap_dis, heap_ids);
281
+ // implem 4 only: map the rounded-table distances back with the two normalizers (implem 3 leaves them as is)
282
+ if (implem == 4) {
283
+ float a = normalizers[2 * i];
284
+ float b = normalizers[2 * i + 1];
285
+
286
+ for(int j = 0; j < k; j++) {
287
+ heap_dis[j] = heap_dis[j] / a + b;
288
+ }
289
+ }
290
+ }
291
+ } else if (impl >= 12 && impl <= 15) {
292
+ FAISS_THROW_IF_NOT(ntotal < INT_MAX); // C uses int ids
293
+ int nt = std::min(omp_get_max_threads(), int(n));
294
+ if (nt < 2) {
295
+ if (impl == 12 || impl == 13) {
296
+ search_implem_12<C>(n, x, k, distances, labels, impl);
297
+ } else {
298
+ search_implem_14<C>(n, x, k, distances, labels, impl);
299
+ }
300
+ } else {
301
+ // explicitly slice over threads
302
+ #pragma omp parallel for num_threads(nt)
303
+ for (int slice = 0; slice < nt; slice++) {
304
+ idx_t i0 = n * slice / nt; // this slice handles queries [i0, i1)
305
+ idx_t i1 = n * (slice + 1) / nt;
306
+ float *dis_i = distances + i0 * k;
307
+ idx_t *lab_i = labels + i0 * k;
308
+ if (impl == 12 || impl == 13) {
309
+ search_implem_12<C>(
310
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
311
+ } else {
312
+ search_implem_14<C>(
313
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
314
+ }
315
+ }
316
+ }
317
+ } else {
318
+ FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
319
+ }
320
+
321
+ }
322
+ /// Quantized-LUT search using the pq4_*_qbs kernels (requires bbs == 32); impl selects heap (12) or reservoir (otherwise) result collection.
323
+ template<class C>
324
+ void IndexPQFastScan::search_implem_12(
325
+ idx_t n, const float* x, idx_t k,
326
+ float* distances, idx_t* labels,
327
+ int impl) const
328
+ {
329
+
330
+ FAISS_THROW_IF_NOT(bbs == 32);
331
+
332
+ // handle qbs2 blocking by recursive call (at most qbs2 queries per call; 11 when qbs is 0)
333
+ int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs);
334
+ if (n > qbs2) {
335
+ for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
336
+ int64_t i1 = std::min(i0 + qbs2, n);
337
+ search_implem_12<C>(
338
+ i1 - i0, x + d * i0, k,
339
+ distances + i0 * k, labels + i0 * k, impl
340
+ );
341
+ }
342
+ return;
343
+ }
344
+
345
+ size_t dim12 = pq.ksub * M2;
346
+ AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
347
+ std::unique_ptr<float []> normalizers(new float[2 * n]);
348
+ // skip bit 1: do not compute the tables. NOTE(review): normalizers then stays uninitialized but is still passed to to_flat_arrays
349
+ if (skip & 1) {
350
+ quantized_dis_tables.clear();
351
+ } else {
352
+ compute_quantized_LUT(
353
+ n, x, quantized_dis_tables.get(), normalizers.get()
354
+ );
355
+ }
356
+
357
+ AlignedTable<uint8_t> LUT(n * dim12);
358
+
359
+ // block sizes are encoded in qbs, 4 bits at a time
360
+
361
+ // caution: this local variable shadows the object field qbs
362
+ int qbs = this->qbs;
363
+
364
+ if (n != pq4_qbs_to_nq(qbs)) {
365
+ qbs = pq4_preferred_qbs(n);
366
+ }
367
+
368
+ int LUT_nq = pq4_pack_LUT_qbs(
369
+ qbs, M2, quantized_dis_tables.get(), LUT.get()
370
+ );
371
+ FAISS_THROW_IF_NOT(LUT_nq == n);
372
+ // skip bits below: 2 sets handler.disable, 4 skips the scan loop, 8 skips to_flat_arrays (not checked in the k == 1 branch)
373
+ if (k == 1) {
374
+ SingleResultHandler<C> handler(n, ntotal);
375
+ if (skip & 4) {
376
+ // pass
377
+ } else {
378
+ handler.disable = bool(skip & 2);
379
+ pq4_accumulate_loop_qbs(
380
+ qbs, ntotal2, M2,
381
+ codes.get(), LUT.get(),
382
+ handler
383
+ );
384
+ }
385
+
386
+ handler.to_flat_arrays(distances, labels, normalizers.get());
387
+
388
+ } else if (impl == 12) {
389
+ // temporary uint16 distances and int32 ids handed to the HeapHandler
390
+ std::vector<uint16_t> tmp_dis(n * k);
391
+ std::vector<int32_t> tmp_ids(n * k);
392
+
393
+ if (skip & 4) {
394
+ // skip
395
+ } else {
396
+ HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
397
+ handler.disable = bool(skip & 2);
398
+
399
+ pq4_accumulate_loop_qbs(
400
+ qbs, ntotal2, M2,
401
+ codes.get(), LUT.get(),
402
+ handler
403
+ );
404
+
405
+ if (!(skip & 8)) {
406
+ handler.to_flat_arrays(distances, labels, normalizers.get());
407
+ }
408
+ }
409
+
410
+
411
+ } else { // impl == 13
412
+
413
+ ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
414
+ handler.disable = bool(skip & 2);
415
+
416
+ if (skip & 4) {
417
+ // skip
418
+ } else {
419
+ pq4_accumulate_loop_qbs(
420
+ qbs, ntotal2, M2,
421
+ codes.get(), LUT.get(),
422
+ handler
423
+ );
424
+ }
425
+
426
+ if (!(skip & 8)) {
427
+ handler.to_flat_arrays(distances, labels, normalizers.get());
428
+ }
429
+ // NOTE(review): global counters updated without synchronization, although search_dispatch_implem calls this inside an OpenMP parallel loop
430
+ FastScan_stats.t0 += handler.times[0];
431
+ FastScan_stats.t1 += handler.times[1];
432
+ FastScan_stats.t2 += handler.times[2];
433
+ FastScan_stats.t3 += handler.times[3];
434
+
435
+ }
436
+ }
437
+ // Global statistics object; search_implem_12 adds its ReservoirHandler times to t0..t3.
438
+ FastScanStats FastScan_stats;
439
+ /// Quantized-LUT search using pq4_pack_LUT / pq4_accumulate_loop (bbs any multiple of 32); impl selects heap (14) or reservoir (otherwise) result collection.
440
+ template<class C>
441
+ void IndexPQFastScan::search_implem_14(
442
+ idx_t n, const float* x, idx_t k,
443
+ float* distances, idx_t* labels, int impl) const
444
+ {
445
+
446
+ FAISS_THROW_IF_NOT(bbs % 32 == 0);
447
+ // here qbs is used directly as the number of queries per call (4 when qbs is 0)
448
+ int qbs2 = qbs == 0 ? 4 : qbs;
449
+
450
+ // handle qbs2 blocking by recursive call
451
+ if (n > qbs2) {
452
+ for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
453
+ int64_t i1 = std::min(i0 + qbs2, n);
454
+ search_implem_14<C>(
455
+ i1 - i0, x + d * i0, k,
456
+ distances + i0 * k, labels + i0 * k, impl
457
+ );
458
+ }
459
+ return;
460
+ }
461
+
462
+ size_t dim12 = pq.ksub * M2;
463
+ AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
464
+ std::unique_ptr<float []> normalizers(new float[2 * n]);
465
+ // skip bit 1: do not compute the tables. NOTE(review): normalizers then stays uninitialized but is still passed to to_flat_arrays
466
+ if (skip & 1) {
467
+ quantized_dis_tables.clear();
468
+ } else {
469
+ compute_quantized_LUT(
470
+ n, x, quantized_dis_tables.get(), normalizers.get()
471
+ );
472
+ }
473
+
474
+ AlignedTable<uint8_t> LUT(n * dim12);
475
+ pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
476
+ // skip bits below: 2 sets handler.disable, 4 skips the scan loop, 8 skips to_flat_arrays (not checked in the k == 1 branch)
477
+ if (k == 1) {
478
+ SingleResultHandler<C> handler(n, ntotal);
479
+ if (skip & 4) {
480
+ // pass
481
+ } else {
482
+ handler.disable = bool(skip & 2);
483
+ pq4_accumulate_loop (
484
+ n, ntotal2, bbs, M2,
485
+ codes.get(), LUT.get(),
486
+ handler
487
+ );
488
+ }
489
+ handler.to_flat_arrays(distances, labels, normalizers.get());
490
+
491
+ } else if (impl == 14) {
492
+ // temporary uint16 distances and int32 ids handed to the HeapHandler
493
+ std::vector<uint16_t> tmp_dis(n * k);
494
+ std::vector<int32_t> tmp_ids(n * k);
495
+
496
+ if (skip & 4) {
497
+ // skip
498
+ } else if (k > 1) { // k == 1 was handled above, so this only excludes k <= 0
499
+ HeapHandler<C> handler(n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
500
+ handler.disable = bool(skip & 2);
501
+
502
+ pq4_accumulate_loop (
503
+ n, ntotal2, bbs, M2,
504
+ codes.get(), LUT.get(),
505
+ handler
506
+ );
507
+
508
+ if (!(skip & 8)) {
509
+ handler.to_flat_arrays(distances, labels, normalizers.get());
510
+ }
511
+ }
512
+
513
+
514
+ } else { // impl == 15
515
+
516
+ ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
517
+ handler.disable = bool(skip & 2);
518
+
519
+ if (skip & 4) {
520
+ // skip
521
+ } else {
522
+ pq4_accumulate_loop (
523
+ n, ntotal2, bbs, M2,
524
+ codes.get(), LUT.get(),
525
+ handler
526
+ );
527
+ }
528
+
529
+ if (!(skip & 8)) {
530
+ handler.to_flat_arrays(distances, labels, normalizers.get());
531
+ }
532
+ }
533
+ }
534
+
535
+
536
+ } // namespace faiss