faiss 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -10,7 +10,10 @@
10
10
 
11
11
  #include <faiss/impl/FaissAssert.h>
12
12
  #include <faiss/Index.h>
13
+ #include <faiss/invlists/InvertedLists.h>
13
14
  #include <initializer_list>
15
+ #include <gtest/gtest.h>
16
+ #include <cstring>
14
17
  #include <memory>
15
18
  #include <string>
16
19
  #include <vector>
@@ -90,4 +93,34 @@ void compareLists(const float* refDist,
90
93
  float pctMaxDiff1 = 0.1f,
91
94
  float pctMaxDiffN = 0.005f);
92
95
 
96
+ /// Compare IVF lists between a CPU and GPU index
97
+ template <typename A, typename B>
98
+ void testIVFEquality(A& cpuIndex, B& gpuIndex) {
99
+ // Ensure equality of the inverted lists
100
+ EXPECT_EQ(cpuIndex.nlist, gpuIndex.nlist);
101
+
102
+ for (int i = 0; i < cpuIndex.nlist; ++i) {
103
+ auto cpuLists = cpuIndex.invlists;
104
+
105
+ // Code equality
106
+ EXPECT_EQ(cpuLists->list_size(i), gpuIndex.getListLength(i));
107
+ std::vector<uint8_t> cpuCodes(cpuLists->list_size(i) * cpuLists->code_size);
108
+
109
+ auto sc = faiss::InvertedLists::ScopedCodes(cpuLists, i);
110
+ std::memcpy(cpuCodes.data(), sc.get(),
111
+ cpuLists->list_size(i) * cpuLists->code_size);
112
+
113
+ auto gpuCodes = gpuIndex.getListVectorData(i, false);
114
+ EXPECT_EQ(cpuCodes, gpuCodes);
115
+
116
+ // Index equality
117
+ std::vector<Index::idx_t> cpuIndices(cpuLists->list_size(i));
118
+
119
+ auto si = faiss::InvertedLists::ScopedIds(cpuLists, i);
120
+ std::memcpy(cpuIndices.data(), si.get(),
121
+ cpuLists->list_size(i) * sizeof(faiss::Index::idx_t));
122
+ EXPECT_EQ(cpuIndices, gpuIndex.getListIndices(i));
123
+ }
124
+ }
125
+
93
126
  } }
@@ -10,6 +10,7 @@
10
10
  #include <faiss/gpu/utils/DeviceUtils.h>
11
11
  #include <faiss/gpu/utils/StaticUtils.h>
12
12
  #include <faiss/impl/FaissAssert.h>
13
+ #include <algorithm>
13
14
  #include <sstream>
14
15
 
15
16
  namespace faiss { namespace gpu {
@@ -10,6 +10,12 @@
10
10
 
11
11
  #include <cuda.h>
12
12
 
13
+ // allow usage for non-CUDA files
14
+ #ifndef __host__
15
+ #define __host__
16
+ #define __device__
17
+ #endif
18
+
13
19
  namespace faiss { namespace gpu { namespace utils {
14
20
 
15
21
  template <typename U, typename V>
@@ -9,6 +9,7 @@
9
9
  #include <faiss/gpu/utils/Timer.h>
10
10
  #include <faiss/gpu/utils/DeviceUtils.h>
11
11
  #include <faiss/impl/FaissAssert.h>
12
+ #include <chrono>
12
13
 
13
14
  namespace faiss { namespace gpu {
14
15
 
@@ -43,18 +44,16 @@ KernelTimer::elapsedMilliseconds() {
43
44
  }
44
45
 
45
46
  CpuTimer::CpuTimer() {
46
- clock_gettime(CLOCK_REALTIME, &start_);
47
+ start_ = std::chrono::steady_clock::now();
47
48
  }
48
49
 
49
50
  float
50
51
  CpuTimer::elapsedMilliseconds() {
51
- struct timespec end;
52
- clock_gettime(CLOCK_REALTIME, &end);
52
+ auto end = std::chrono::steady_clock::now();
53
53
 
54
- auto diffS = end.tv_sec - start_.tv_sec;
55
- auto diffNs = end.tv_nsec - start_.tv_nsec;
54
+ std::chrono::duration<float, std::milli> duration = end - start_;
56
55
 
57
- return 1000.0f * (float) diffS + ((float) diffNs) / 1000000.0f;
56
+ return duration.count();
58
57
  }
59
58
 
60
59
  } } // namespace
@@ -9,7 +9,7 @@
9
9
  #pragma once
10
10
 
11
11
  #include <cuda_runtime.h>
12
- #include <time.h>
12
+ #include <chrono>
13
13
 
14
14
  namespace faiss { namespace gpu {
15
15
 
@@ -46,7 +46,7 @@ class CpuTimer {
46
46
  float elapsedMilliseconds();
47
47
 
48
48
  private:
49
- struct timespec start_;
49
+ std::chrono::time_point<std::chrono::steady_clock> start_;
50
50
  };
51
51
 
52
52
  } } // namespace
@@ -199,12 +199,13 @@ struct RangeSearchPartialResult: BufferList {
199
199
  *
200
200
  * The DistanceComputer is not intended to be thread-safe (eg. because
201
201
  * it maintains counters) so the distance functions are not const,
202
- * instanciate one from each thread if needed.
202
+ * instantiate one from each thread if needed.
203
203
  ***********************************************************/
204
204
  struct DistanceComputer {
205
205
  using idx_t = Index::idx_t;
206
206
 
207
- /// called before computing distances
207
+ /// called before computing distances. Pointer x should remain valid
208
+ /// while operator () is called
208
209
  virtual void set_query(const float *x) = 0;
209
210
 
210
211
  /// compute distance of vector i to current query
@@ -233,9 +234,9 @@ struct FAISS_API InterruptCallback {
233
234
 
234
235
  /** check if:
235
236
  * - an interrupt callback is set
236
- * - the callback retuns true
237
+ * - the callback returns true
237
238
  * if this is the case, then throw an exception. Should not be called
238
- * from multiple threds.
239
+ * from multiple threads.
239
240
  */
240
241
  static void check ();
241
242
 
@@ -539,8 +539,7 @@ int HNSW::search_from_candidates(
539
539
  if (nres < k) {
540
540
  faiss::maxheap_push(++nres, D, I, d, v1);
541
541
  } else if (d < D[0]) {
542
- faiss::maxheap_pop(nres--, D, I);
543
- faiss::maxheap_push(++nres, D, I, d, v1);
542
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
544
543
  }
545
544
  vt.set(v1);
546
545
  }
@@ -578,8 +577,7 @@ int HNSW::search_from_candidates(
578
577
  if (nres < k) {
579
578
  faiss::maxheap_push(++nres, D, I, d, v1);
580
579
  } else if (d < D[0]) {
581
- faiss::maxheap_pop(nres--, D, I);
582
- faiss::maxheap_push(++nres, D, I, d, v1);
580
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
583
581
  }
584
582
  candidates.push(v1, d);
585
583
  }
@@ -21,14 +21,14 @@ namespace faiss {
21
21
  struct SimulatedAnnealingParameters {
22
22
 
23
23
  // optimization parameters
24
- double init_temperature; // init probaility of accepting a bad swap
24
+ double init_temperature; // init probability of accepting a bad swap
25
25
  double temperature_decay; // at each iteration the temp is multiplied by this
26
26
  int n_iter; // nb of iterations
27
27
  int n_redo; // nb of runs of the simulation
28
28
  int seed; // random seed
29
29
  int verbose;
30
30
  bool only_bit_flips; // restrict permutation changes to bit flips
31
- bool init_random; // intialize with a random permutation (not identity)
31
+ bool init_random; // initialize with a random permutation (not identity)
32
32
 
33
33
  // set reasonable defaults
34
34
  SimulatedAnnealingParameters ();
@@ -57,7 +57,7 @@ struct ReproduceDistancesObjective : PermutationObjective {
57
57
 
58
58
  static double sqr (double x) { return x * x; }
59
59
 
60
- // weihgting of distances: it is more important to reproduce small
60
+ // weighting of distances: it is more important to reproduce small
61
61
  // distances well
62
62
  double dis_weight (double x) const;
63
63
 
@@ -139,7 +139,7 @@ struct PolysemousTraining: SimulatedAnnealingParameters {
139
139
  // sets default values
140
140
  PolysemousTraining ();
141
141
 
142
- /// reorder the centroids so that the Hamming distace becomes a
142
+ /// reorder the centroids so that the Hamming distance becomes a
143
143
  /// good approximation of the SDC distance (called by train)
144
144
  void optimize_pq_for_hamming (ProductQuantizer & pq,
145
145
  size_t n, const float *x) const;
@@ -63,8 +63,7 @@ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
63
63
  }
64
64
 
65
65
  if (C::cmp (heap_dis[0], dis)) {
66
- heap_pop<C> (k, heap_dis, heap_ids);
67
- heap_push<C> (k, heap_dis, heap_ids, dis, j);
66
+ heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
68
67
  }
69
68
  }
70
69
  }
@@ -89,8 +88,7 @@ void pq_estimators_from_tables_M4 (const CT * codes,
89
88
  dis += dt[*codes++];
90
89
 
91
90
  if (C::cmp (heap_dis[0], dis)) {
92
- heap_pop<C> (k, heap_dis, heap_ids);
93
- heap_push<C> (k, heap_dis, heap_ids, dis, j);
91
+ heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
94
92
  }
95
93
  }
96
94
  }
@@ -132,8 +130,7 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
132
130
  dt += ksub;
133
131
  }
134
132
  if (C::cmp (heap_dis[0], dis)) {
135
- heap_pop<C> (k, heap_dis, heap_ids);
136
- heap_push<C> (k, heap_dis, heap_ids, dis, j);
133
+ heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
137
134
  }
138
135
  }
139
136
  }
@@ -163,8 +160,7 @@ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
163
160
  }
164
161
 
165
162
  if (C::cmp(heap_dis[0], dis)) {
166
- heap_pop<C>(k, heap_dis, heap_ids);
167
- heap_push<C>(k, heap_dis, heap_ids, dis, j);
163
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
168
164
  }
169
165
  }
170
166
  }
@@ -186,7 +182,7 @@ ProductQuantizer::ProductQuantizer ()
186
182
 
187
183
  void ProductQuantizer::set_derived_values () {
188
184
  // quite a few derived values
189
- FAISS_THROW_IF_NOT (d % M == 0);
185
+ FAISS_THROW_IF_NOT_MSG (d % M == 0, "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
190
186
  dsub = d / M;
191
187
  code_size = (nbits * M + 7) / 8;
192
188
  ksub = 1 << nbits;
@@ -549,6 +545,14 @@ void ProductQuantizer::compute_distance_tables (
549
545
  float * dis_tables) const
550
546
  {
551
547
 
548
+ #ifdef __AVX2__
549
+ if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
550
+ compute_PQ_dis_tables_dsub2(
551
+ d, ksub, centroids.data(),
552
+ nx, x, false, dis_tables
553
+ );
554
+ } else
555
+ #endif
552
556
  if (dsub < 16) {
553
557
 
554
558
  #pragma omp parallel for
@@ -573,7 +577,14 @@ void ProductQuantizer::compute_inner_prod_tables (
573
577
  const float * x,
574
578
  float * dis_tables) const
575
579
  {
576
-
580
+ #ifdef __AVX2__
581
+ if (dsub == 2 && nbits < 8) {
582
+ compute_PQ_dis_tables_dsub2(
583
+ d, ksub, centroids.data(),
584
+ nx, x, true, dis_tables
585
+ );
586
+ } else
587
+ #endif
577
588
  if (dsub < 16) {
578
589
 
579
590
  #pragma omp parallel for
@@ -747,8 +758,7 @@ void ProductQuantizer::search_sdc (const uint8_t * qcodes,
747
758
  tab += ksub * ksub;
748
759
  }
749
760
  if (dis < heap_dis[0]) {
750
- maxheap_pop (k, heap_dis, heap_ids);
751
- maxheap_push (k, heap_dis, heap_ids, dis, j);
761
+ maxheap_replace_top (k, heap_dis, heap_ids, dis, j);
752
762
  }
753
763
  bcode += code_size;
754
764
  }
@@ -219,12 +219,14 @@ struct PQDecoderGeneric {
219
219
  };
220
220
 
221
221
  struct PQDecoder8 {
222
+ static const int nbits = 8;
222
223
  const uint8_t *code;
223
224
  PQDecoder8(const uint8_t *code, int nbits);
224
225
  uint64_t decode();
225
226
  };
226
227
 
227
228
  struct PQDecoder16 {
229
+ static const int nbits = 16;
228
230
  const uint16_t *code;
229
231
  PQDecoder16(const uint8_t *code, int nbits);
230
232
  uint64_t decode();
@@ -0,0 +1,452 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ /*
10
+ * Structures that collect search results from distance computations
11
+ */
12
+
13
+ #pragma once
14
+
15
+
16
+ #include <faiss/utils/Heap.h>
17
+ #include <faiss/utils/partitioning.h>
18
+ #include <faiss/impl/AuxIndexStructures.h>
19
+
20
+
21
+ namespace faiss {
22
+
23
+
24
+
25
+ /*****************************************************************
26
+ * Heap based result handler
27
+ *****************************************************************/
28
+
29
+
30
+ template<class C>
31
+ struct HeapResultHandler {
32
+
33
+ using T = typename C::T;
34
+ using TI = typename C::TI;
35
+
36
+ int nq;
37
+ T *heap_dis_tab;
38
+ TI *heap_ids_tab;
39
+
40
+ int64_t k; // number of results to keep
41
+
42
+ HeapResultHandler(
43
+ size_t nq,
44
+ T * heap_dis_tab, TI * heap_ids_tab,
45
+ size_t k):
46
+ nq(nq),
47
+ heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
48
+ {
49
+
50
+ }
51
+
52
+ /******************************************************
53
+ * API for 1 result at a time (each SingleResultHandler is
54
+ * called from 1 thread)
55
+ */
56
+
57
+ struct SingleResultHandler {
58
+ HeapResultHandler & hr;
59
+ size_t k;
60
+
61
+ T *heap_dis;
62
+ TI *heap_ids;
63
+ T thresh;
64
+
65
+ SingleResultHandler(HeapResultHandler &hr): hr(hr), k(hr.k) {}
66
+
67
+ /// begin results for query # i
68
+ void begin(size_t i) {
69
+ heap_dis = hr.heap_dis_tab + i * k;
70
+ heap_ids = hr.heap_ids_tab + i * k;
71
+ heap_heapify<C> (k, heap_dis, heap_ids);
72
+ thresh = heap_dis[0];
73
+ }
74
+
75
+ /// add one result for query i
76
+ void add_result(T dis, TI idx) {
77
+ if (C::cmp(heap_dis[0], dis)) {
78
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
79
+ thresh = heap_dis[0];
80
+ }
81
+ }
82
+
83
+ /// series of results for query i is done
84
+ void end() {
85
+ heap_reorder<C> (k, heap_dis, heap_ids);
86
+ }
87
+ };
88
+
89
+
90
+ /******************************************************
91
+ * API for multiple results (called from 1 thread)
92
+ */
93
+
94
+ size_t i0, i1;
95
+
96
+ /// begin
97
+ void begin_multiple(size_t i0, size_t i1) {
98
+ this->i0 = i0;
99
+ this->i1 = i1;
100
+ for(size_t i = i0; i < i1; i++) {
101
+ heap_heapify<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
102
+ }
103
+ }
104
+
105
+ /// add results for query i0..i1 and j0..j1
106
+ void add_results(size_t j0, size_t j1, const T *dis_tab) {
107
+ // maybe parallel for
108
+ for (size_t i = i0; i < i1; i++) {
109
+ T * heap_dis = heap_dis_tab + i * k;
110
+ TI * heap_ids = heap_ids_tab + i * k;
111
+ T thresh = heap_dis[0];
112
+ for (size_t j = j0; j < j1; j++) {
113
+ T dis = *dis_tab++;
114
+ if (C::cmp(thresh, dis)) {
115
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
116
+ thresh = heap_dis[0];
117
+ }
118
+ }
119
+ }
120
+ }
121
+
122
+ /// series of results for queries i0..i1 is done
123
+ void end_multiple() {
124
+ // maybe parallel for
125
+ for(size_t i = i0; i < i1; i++) {
126
+ heap_reorder<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
127
+ }
128
+ }
129
+
130
+ };
131
+
132
+ /*****************************************************************
133
+ * Reservoir result handler
134
+ *
135
+ * A reservoir is a result array of size capacity > n (number of requested
136
+ * results) all results below a threshold are stored in an arbitrary order. When
137
+ * the capacity is reached, a new threshold is chosen by partitionning the
138
+ * distance array.
139
+ *****************************************************************/
140
+
141
+
142
+
143
+ /// Reservoir for a single query
144
+ template<class C>
145
+ struct ReservoirTopN {
146
+ using T = typename C::T;
147
+ using TI = typename C::TI;
148
+
149
+ T *vals;
150
+ TI *ids;
151
+
152
+ size_t i; // number of stored elements
153
+ size_t n; // number of requested elements
154
+ size_t capacity; // size of storage
155
+
156
+ T threshold; // current threshold
157
+
158
+ ReservoirTopN() {}
159
+
160
+ ReservoirTopN(
161
+ size_t n, size_t capacity,
162
+ T *vals, TI *ids
163
+ ):
164
+ vals(vals), ids(ids),
165
+ i(0), n(n), capacity(capacity) {
166
+ assert(n < capacity);
167
+ threshold = C::neutral();
168
+ }
169
+
170
+ void add(T val, TI id) {
171
+ if (C::cmp(threshold, val)) {
172
+ if (i == capacity) {
173
+ shrink_fuzzy();
174
+ }
175
+ vals[i] = val;
176
+ ids[i] = id;
177
+ i++;
178
+ }
179
+ }
180
+
181
+ // reduce storage from capacity to anything
182
+ // between n and (capacity + n) / 2
183
+ void shrink_fuzzy() {
184
+ assert(i == capacity);
185
+
186
+ threshold = partition_fuzzy<C>(
187
+ vals, ids, capacity, n, (capacity + n) / 2,
188
+ &i);
189
+ }
190
+
191
+ void to_result(T *heap_dis, TI *heap_ids) const {
192
+
193
+ for (int j = 0; j < std::min(i, n); j++) {
194
+ heap_push<C>(
195
+ j + 1, heap_dis, heap_ids,
196
+ vals[j], ids[j]
197
+ );
198
+ }
199
+
200
+ if (i < n) {
201
+ heap_reorder<C> (i, heap_dis, heap_ids);
202
+ // add empty results
203
+ heap_heapify<C> (n - i, heap_dis + i, heap_ids + i);
204
+ } else {
205
+ // add remaining elements
206
+ heap_addn<C> (n, heap_dis, heap_ids, vals + n, ids + n, i - n);
207
+ heap_reorder<C> (n, heap_dis, heap_ids);
208
+ }
209
+
210
+ }
211
+
212
+ };
213
+
214
+
215
+
216
+ template<class C>
217
+ struct ReservoirResultHandler {
218
+
219
+ using T = typename C::T;
220
+ using TI = typename C::TI;
221
+
222
+ int nq;
223
+ T *heap_dis_tab;
224
+ TI *heap_ids_tab;
225
+
226
+ int64_t k; // number of results to keep
227
+ size_t capacity; // capacity of the reservoirs
228
+
229
+ ReservoirResultHandler(
230
+ size_t nq,
231
+ T * heap_dis_tab, TI * heap_ids_tab,
232
+ size_t k):
233
+ nq(nq),
234
+ heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
235
+ {
236
+ // double then round up to multiple of 16 (for SIMD alignment)
237
+ capacity = (2 * k + 15) & ~15;
238
+ }
239
+
240
+ /******************************************************
241
+ * API for 1 result at a time (each SingleResultHandler is
242
+ * called from 1 thread)
243
+ */
244
+
245
+ struct SingleResultHandler {
246
+ ReservoirResultHandler & hr;
247
+
248
+ std::vector<T> reservoir_dis;
249
+ std::vector<TI> reservoir_ids;
250
+ ReservoirTopN<C> res1;
251
+
252
+ SingleResultHandler(ReservoirResultHandler &hr):
253
+ hr(hr), reservoir_dis(hr.capacity), reservoir_ids(hr.capacity)
254
+ {
255
+ }
256
+
257
+ size_t i;
258
+
259
+ /// begin results for query # i
260
+ void begin(size_t i) {
261
+ res1 = ReservoirTopN<C>(
262
+ hr.k, hr.capacity, reservoir_dis.data(), reservoir_ids.data());
263
+ this->i = i;
264
+ }
265
+
266
+ /// add one result for query i
267
+ void add_result(T dis, TI idx) {
268
+ res1.add(dis, idx);
269
+ }
270
+
271
+ /// series of results for query i is done
272
+ void end() {
273
+ T * heap_dis = hr.heap_dis_tab + i * hr.k;
274
+ TI * heap_ids = hr.heap_ids_tab + i * hr.k;
275
+ res1.to_result(heap_dis, heap_ids);
276
+ }
277
+ };
278
+
279
+ /******************************************************
280
+ * API for multiple results (called from 1 thread)
281
+ */
282
+
283
+ size_t i0, i1;
284
+
285
+ std::vector<T> reservoir_dis;
286
+ std::vector<TI> reservoir_ids;
287
+ std::vector<ReservoirTopN<C>> reservoirs;
288
+
289
+ /// begin
290
+ void begin_multiple(size_t i0, size_t i1) {
291
+ this->i0 = i0;
292
+ this->i1 = i1;
293
+ reservoir_dis.resize((i1 - i0) * capacity);
294
+ reservoir_ids.resize((i1 - i0) * capacity);
295
+ reservoirs.clear();
296
+ for (size_t i = i0; i < i1; i++) {
297
+ reservoirs.emplace_back(
298
+ k, capacity,
299
+ reservoir_dis.data() + (i - i0) * capacity,
300
+ reservoir_ids.data() + (i - i0) * capacity
301
+ );
302
+ }
303
+ }
304
+
305
+ /// add results for query i0..i1 and j0..j1
306
+ void add_results(size_t j0, size_t j1, const T *dis_tab) {
307
+ // maybe parallel for
308
+ for (size_t i = i0; i < i1; i++) {
309
+ ReservoirTopN<C> & reservoir = reservoirs[i - i0];
310
+ for (size_t j = j0; j < j1; j++) {
311
+ T dis = *dis_tab++;
312
+ reservoir.add(dis, j);
313
+ }
314
+ }
315
+ }
316
+
317
+ /// series of results for queries i0..i1 is done
318
+ void end_multiple() {
319
+ // maybe parallel for
320
+ for(size_t i = i0; i < i1; i++) {
321
+ reservoirs[i - i0].to_result(
322
+ heap_dis_tab + i * k, heap_ids_tab + i * k);
323
+ }
324
+ }
325
+
326
+ };
327
+
328
+
329
+ /*****************************************************************
330
+ * Result handler for range searches
331
+ *****************************************************************/
332
+
333
+
334
+
335
+ template<class C>
336
+ struct RangeSearchResultHandler {
337
+ using T = typename C::T;
338
+ using TI = typename C::TI;
339
+
340
+ RangeSearchResult *res;
341
+ float radius;
342
+
343
+ RangeSearchResultHandler(RangeSearchResult *res, float radius):
344
+ res(res), radius(radius)
345
+ {}
346
+
347
+ /******************************************************
348
+ * API for 1 result at a time (each SingleResultHandler is
349
+ * called from 1 thread)
350
+ ******************************************************/
351
+
352
+ struct SingleResultHandler {
353
+ // almost the same interface as RangeSearchResultHandler
354
+ RangeSearchPartialResult pres;
355
+ float radius;
356
+ RangeQueryResult *qr = nullptr;
357
+
358
+ SingleResultHandler(RangeSearchResultHandler &rh):
359
+ pres(rh.res), radius(rh.radius)
360
+ {}
361
+
362
+ /// begin results for query # i
363
+ void begin(size_t i) {
364
+ qr = &pres.new_result(i);
365
+ }
366
+
367
+ /// add one result for query i
368
+ void add_result(T dis, TI idx) {
369
+
370
+ if (C::cmp(radius, dis)) {
371
+ qr->add(dis, idx);
372
+ }
373
+ }
374
+
375
+ /// series of results for query i is done
376
+ void end() {
377
+ }
378
+
379
+ ~SingleResultHandler() {
380
+ pres.finalize();
381
+ }
382
+ };
383
+
384
+ /******************************************************
385
+ * API for multiple results (called from 1 thread)
386
+ ******************************************************/
387
+
388
+ size_t i0, i1;
389
+
390
+ std::vector <RangeSearchPartialResult *> partial_results;
391
+ std::vector <size_t> j0s;
392
+ int pr = 0;
393
+
394
+ /// begin
395
+ void begin_multiple(size_t i0, size_t i1) {
396
+ this->i0 = i0;
397
+ this->i1 = i1;
398
+ }
399
+
400
+ /// add results for query i0..i1 and j0..j1
401
+
402
+ void add_results(size_t j0, size_t j1, const T *dis_tab) {
403
+ RangeSearchPartialResult *pres;
404
+ // there is one RangeSearchPartialResult structure per j0
405
+ // (= block of columns of the large distance matrix)
406
+ // it is a bit tricky to find the poper PartialResult structure
407
+ // because the inner loop is on db not on queries.
408
+
409
+ if (pr < j0s.size() && j0 == j0s[pr]) {
410
+ pres = partial_results[pr];
411
+ pr++;
412
+ } else if (j0 == 0 && j0s.size() > 0) {
413
+ pr = 0;
414
+ pres = partial_results[pr];
415
+ pr++;
416
+ } else { // did not find this j0
417
+ pres = new RangeSearchPartialResult (res);
418
+ partial_results.push_back(pres);
419
+ j0s.push_back(j0);
420
+ pr = partial_results.size();
421
+ }
422
+
423
+ for (size_t i = i0; i < i1; i++) {
424
+ const float *ip_line = dis_tab + (i - i0) * (j1 - j0);
425
+ RangeQueryResult & qres = pres->new_result (i);
426
+
427
+ for (size_t j = j0; j < j1; j++) {
428
+ float dis = *ip_line++;
429
+ if (C::cmp(radius, dis)) {
430
+ qres.add (dis, j);
431
+ }
432
+ }
433
+ }
434
+ }
435
+
436
+ void end_multiple() {
437
+
438
+ }
439
+
440
+ ~RangeSearchResultHandler() {
441
+ if (partial_results.size() > 0) {
442
+ RangeSearchPartialResult::merge (partial_results);
443
+ }
444
+ }
445
+
446
+ };
447
+
448
+
449
+
450
+
451
+ } // namespace faiss
452
+