faiss 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -10,7 +10,10 @@
10
10
 
11
11
  #include <faiss/impl/FaissAssert.h>
12
12
  #include <faiss/Index.h>
13
+ #include <faiss/invlists/InvertedLists.h>
13
14
  #include <initializer_list>
15
+ #include <gtest/gtest.h>
16
+ #include <cstring>
14
17
  #include <memory>
15
18
  #include <string>
16
19
  #include <vector>
@@ -90,4 +93,34 @@ void compareLists(const float* refDist,
90
93
  float pctMaxDiff1 = 0.1f,
91
94
  float pctMaxDiffN = 0.005f);
92
95
 
96
+ /// Compare IVF lists between a CPU and GPU index
97
+ template <typename A, typename B>
98
+ void testIVFEquality(A& cpuIndex, B& gpuIndex) {
99
+ // Ensure equality of the inverted lists
100
+ EXPECT_EQ(cpuIndex.nlist, gpuIndex.nlist);
101
+
102
+ for (int i = 0; i < cpuIndex.nlist; ++i) {
103
+ auto cpuLists = cpuIndex.invlists;
104
+
105
+ // Code equality
106
+ EXPECT_EQ(cpuLists->list_size(i), gpuIndex.getListLength(i));
107
+ std::vector<uint8_t> cpuCodes(cpuLists->list_size(i) * cpuLists->code_size);
108
+
109
+ auto sc = faiss::InvertedLists::ScopedCodes(cpuLists, i);
110
+ std::memcpy(cpuCodes.data(), sc.get(),
111
+ cpuLists->list_size(i) * cpuLists->code_size);
112
+
113
+ auto gpuCodes = gpuIndex.getListVectorData(i, false);
114
+ EXPECT_EQ(cpuCodes, gpuCodes);
115
+
116
+ // Index equality
117
+ std::vector<Index::idx_t> cpuIndices(cpuLists->list_size(i));
118
+
119
+ auto si = faiss::InvertedLists::ScopedIds(cpuLists, i);
120
+ std::memcpy(cpuIndices.data(), si.get(),
121
+ cpuLists->list_size(i) * sizeof(faiss::Index::idx_t));
122
+ EXPECT_EQ(cpuIndices, gpuIndex.getListIndices(i));
123
+ }
124
+ }
125
+
93
126
  } }
@@ -10,6 +10,7 @@
10
10
  #include <faiss/gpu/utils/DeviceUtils.h>
11
11
  #include <faiss/gpu/utils/StaticUtils.h>
12
12
  #include <faiss/impl/FaissAssert.h>
13
+ #include <algorithm>
13
14
  #include <sstream>
14
15
 
15
16
  namespace faiss { namespace gpu {
@@ -10,6 +10,12 @@
10
10
 
11
11
  #include <cuda.h>
12
12
 
13
+ // allow usage for non-CUDA files
14
+ #ifndef __host__
15
+ #define __host__
16
+ #define __device__
17
+ #endif
18
+
13
19
  namespace faiss { namespace gpu { namespace utils {
14
20
 
15
21
  template <typename U, typename V>
@@ -9,6 +9,7 @@
9
9
  #include <faiss/gpu/utils/Timer.h>
10
10
  #include <faiss/gpu/utils/DeviceUtils.h>
11
11
  #include <faiss/impl/FaissAssert.h>
12
+ #include <chrono>
12
13
 
13
14
  namespace faiss { namespace gpu {
14
15
 
@@ -43,18 +44,16 @@ KernelTimer::elapsedMilliseconds() {
43
44
  }
44
45
 
45
46
  CpuTimer::CpuTimer() {
46
- clock_gettime(CLOCK_REALTIME, &start_);
47
+ start_ = std::chrono::steady_clock::now();
47
48
  }
48
49
 
49
50
  float
50
51
  CpuTimer::elapsedMilliseconds() {
51
- struct timespec end;
52
- clock_gettime(CLOCK_REALTIME, &end);
52
+ auto end = std::chrono::steady_clock::now();
53
53
 
54
- auto diffS = end.tv_sec - start_.tv_sec;
55
- auto diffNs = end.tv_nsec - start_.tv_nsec;
54
+ std::chrono::duration<float, std::milli> duration = end - start_;
56
55
 
57
- return 1000.0f * (float) diffS + ((float) diffNs) / 1000000.0f;
56
+ return duration.count();
58
57
  }
59
58
 
60
59
  } } // namespace
@@ -9,7 +9,7 @@
9
9
  #pragma once
10
10
 
11
11
  #include <cuda_runtime.h>
12
- #include <time.h>
12
+ #include <chrono>
13
13
 
14
14
  namespace faiss { namespace gpu {
15
15
 
@@ -46,7 +46,7 @@ class CpuTimer {
46
46
  float elapsedMilliseconds();
47
47
 
48
48
  private:
49
- struct timespec start_;
49
+ std::chrono::time_point<std::chrono::steady_clock> start_;
50
50
  };
51
51
 
52
52
  } } // namespace
@@ -199,12 +199,13 @@ struct RangeSearchPartialResult: BufferList {
199
199
  *
200
200
  * The DistanceComputer is not intended to be thread-safe (eg. because
201
201
  * it maintains counters) so the distance functions are not const,
202
- * instanciate one from each thread if needed.
202
+ * instantiate one from each thread if needed.
203
203
  ***********************************************************/
204
204
  struct DistanceComputer {
205
205
  using idx_t = Index::idx_t;
206
206
 
207
- /// called before computing distances
207
+ /// called before computing distances. Pointer x should remain valid
208
+ /// while operator () is called
208
209
  virtual void set_query(const float *x) = 0;
209
210
 
210
211
  /// compute distance of vector i to current query
@@ -233,9 +234,9 @@ struct FAISS_API InterruptCallback {
233
234
 
234
235
  /** check if:
235
236
  * - an interrupt callback is set
236
- * - the callback retuns true
237
+ * - the callback returns true
237
238
  * if this is the case, then throw an exception. Should not be called
238
- * from multiple threds.
239
+ * from multiple threads.
239
240
  */
240
241
  static void check ();
241
242
 
@@ -539,8 +539,7 @@ int HNSW::search_from_candidates(
539
539
  if (nres < k) {
540
540
  faiss::maxheap_push(++nres, D, I, d, v1);
541
541
  } else if (d < D[0]) {
542
- faiss::maxheap_pop(nres--, D, I);
543
- faiss::maxheap_push(++nres, D, I, d, v1);
542
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
544
543
  }
545
544
  vt.set(v1);
546
545
  }
@@ -578,8 +577,7 @@ int HNSW::search_from_candidates(
578
577
  if (nres < k) {
579
578
  faiss::maxheap_push(++nres, D, I, d, v1);
580
579
  } else if (d < D[0]) {
581
- faiss::maxheap_pop(nres--, D, I);
582
- faiss::maxheap_push(++nres, D, I, d, v1);
580
+ faiss::maxheap_replace_top(nres, D, I, d, v1);
583
581
  }
584
582
  candidates.push(v1, d);
585
583
  }
@@ -21,14 +21,14 @@ namespace faiss {
21
21
  struct SimulatedAnnealingParameters {
22
22
 
23
23
  // optimization parameters
24
- double init_temperature; // init probaility of accepting a bad swap
24
+ double init_temperature; // init probability of accepting a bad swap
25
25
  double temperature_decay; // at each iteration the temp is multiplied by this
26
26
  int n_iter; // nb of iterations
27
27
  int n_redo; // nb of runs of the simulation
28
28
  int seed; // random seed
29
29
  int verbose;
30
30
  bool only_bit_flips; // restrict permutation changes to bit flips
31
- bool init_random; // intialize with a random permutation (not identity)
31
+ bool init_random; // initialize with a random permutation (not identity)
32
32
 
33
33
  // set reasonable defaults
34
34
  SimulatedAnnealingParameters ();
@@ -57,7 +57,7 @@ struct ReproduceDistancesObjective : PermutationObjective {
57
57
 
58
58
  static double sqr (double x) { return x * x; }
59
59
 
60
- // weihgting of distances: it is more important to reproduce small
60
+ // weighting of distances: it is more important to reproduce small
61
61
  // distances well
62
62
  double dis_weight (double x) const;
63
63
 
@@ -139,7 +139,7 @@ struct PolysemousTraining: SimulatedAnnealingParameters {
139
139
  // sets default values
140
140
  PolysemousTraining ();
141
141
 
142
- /// reorder the centroids so that the Hamming distace becomes a
142
+ /// reorder the centroids so that the Hamming distance becomes a
143
143
  /// good approximation of the SDC distance (called by train)
144
144
  void optimize_pq_for_hamming (ProductQuantizer & pq,
145
145
  size_t n, const float *x) const;
@@ -63,8 +63,7 @@ void pq_estimators_from_tables_Mmul4 (int M, const CT * codes,
63
63
  }
64
64
 
65
65
  if (C::cmp (heap_dis[0], dis)) {
66
- heap_pop<C> (k, heap_dis, heap_ids);
67
- heap_push<C> (k, heap_dis, heap_ids, dis, j);
66
+ heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
68
67
  }
69
68
  }
70
69
  }
@@ -89,8 +88,7 @@ void pq_estimators_from_tables_M4 (const CT * codes,
89
88
  dis += dt[*codes++];
90
89
 
91
90
  if (C::cmp (heap_dis[0], dis)) {
92
- heap_pop<C> (k, heap_dis, heap_ids);
93
- heap_push<C> (k, heap_dis, heap_ids, dis, j);
91
+ heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
94
92
  }
95
93
  }
96
94
  }
@@ -132,8 +130,7 @@ static inline void pq_estimators_from_tables (const ProductQuantizer& pq,
132
130
  dt += ksub;
133
131
  }
134
132
  if (C::cmp (heap_dis[0], dis)) {
135
- heap_pop<C> (k, heap_dis, heap_ids);
136
- heap_push<C> (k, heap_dis, heap_ids, dis, j);
133
+ heap_replace_top<C> (k, heap_dis, heap_ids, dis, j);
137
134
  }
138
135
  }
139
136
  }
@@ -163,8 +160,7 @@ static inline void pq_estimators_from_tables_generic(const ProductQuantizer& pq,
163
160
  }
164
161
 
165
162
  if (C::cmp(heap_dis[0], dis)) {
166
- heap_pop<C>(k, heap_dis, heap_ids);
167
- heap_push<C>(k, heap_dis, heap_ids, dis, j);
163
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
168
164
  }
169
165
  }
170
166
  }
@@ -186,7 +182,7 @@ ProductQuantizer::ProductQuantizer ()
186
182
 
187
183
  void ProductQuantizer::set_derived_values () {
188
184
  // quite a few derived values
189
- FAISS_THROW_IF_NOT (d % M == 0);
185
+ FAISS_THROW_IF_NOT_MSG (d % M == 0, "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
190
186
  dsub = d / M;
191
187
  code_size = (nbits * M + 7) / 8;
192
188
  ksub = 1 << nbits;
@@ -549,6 +545,14 @@ void ProductQuantizer::compute_distance_tables (
549
545
  float * dis_tables) const
550
546
  {
551
547
 
548
+ #ifdef __AVX2__
549
+ if (dsub == 2 && nbits < 8) { // interesting for a narrow range of settings
550
+ compute_PQ_dis_tables_dsub2(
551
+ d, ksub, centroids.data(),
552
+ nx, x, false, dis_tables
553
+ );
554
+ } else
555
+ #endif
552
556
  if (dsub < 16) {
553
557
 
554
558
  #pragma omp parallel for
@@ -573,7 +577,14 @@ void ProductQuantizer::compute_inner_prod_tables (
573
577
  const float * x,
574
578
  float * dis_tables) const
575
579
  {
576
-
580
+ #ifdef __AVX2__
581
+ if (dsub == 2 && nbits < 8) {
582
+ compute_PQ_dis_tables_dsub2(
583
+ d, ksub, centroids.data(),
584
+ nx, x, true, dis_tables
585
+ );
586
+ } else
587
+ #endif
577
588
  if (dsub < 16) {
578
589
 
579
590
  #pragma omp parallel for
@@ -747,8 +758,7 @@ void ProductQuantizer::search_sdc (const uint8_t * qcodes,
747
758
  tab += ksub * ksub;
748
759
  }
749
760
  if (dis < heap_dis[0]) {
750
- maxheap_pop (k, heap_dis, heap_ids);
751
- maxheap_push (k, heap_dis, heap_ids, dis, j);
761
+ maxheap_replace_top (k, heap_dis, heap_ids, dis, j);
752
762
  }
753
763
  bcode += code_size;
754
764
  }
@@ -219,12 +219,14 @@ struct PQDecoderGeneric {
219
219
  };
220
220
 
221
221
  struct PQDecoder8 {
222
+ static const int nbits = 8;
222
223
  const uint8_t *code;
223
224
  PQDecoder8(const uint8_t *code, int nbits);
224
225
  uint64_t decode();
225
226
  };
226
227
 
227
228
  struct PQDecoder16 {
229
+ static const int nbits = 16;
228
230
  const uint16_t *code;
229
231
  PQDecoder16(const uint8_t *code, int nbits);
230
232
  uint64_t decode();
@@ -0,0 +1,452 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+ /*
10
+ * Structures that collect search results from distance computations
11
+ */
12
+
13
+ #pragma once
14
+
15
+
16
+ #include <faiss/utils/Heap.h>
17
+ #include <faiss/utils/partitioning.h>
18
+ #include <faiss/impl/AuxIndexStructures.h>
19
+
20
+
21
+ namespace faiss {
22
+
23
+
24
+
25
+ /*****************************************************************
26
+ * Heap based result handler
27
+ *****************************************************************/
28
+
29
+
30
+ template<class C>
31
+ struct HeapResultHandler {
32
+
33
+ using T = typename C::T;
34
+ using TI = typename C::TI;
35
+
36
+ int nq;
37
+ T *heap_dis_tab;
38
+ TI *heap_ids_tab;
39
+
40
+ int64_t k; // number of results to keep
41
+
42
+ HeapResultHandler(
43
+ size_t nq,
44
+ T * heap_dis_tab, TI * heap_ids_tab,
45
+ size_t k):
46
+ nq(nq),
47
+ heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
48
+ {
49
+
50
+ }
51
+
52
+ /******************************************************
53
+ * API for 1 result at a time (each SingleResultHandler is
54
+ * called from 1 thread)
55
+ */
56
+
57
+ struct SingleResultHandler {
58
+ HeapResultHandler & hr;
59
+ size_t k;
60
+
61
+ T *heap_dis;
62
+ TI *heap_ids;
63
+ T thresh;
64
+
65
+ SingleResultHandler(HeapResultHandler &hr): hr(hr), k(hr.k) {}
66
+
67
+ /// begin results for query # i
68
+ void begin(size_t i) {
69
+ heap_dis = hr.heap_dis_tab + i * k;
70
+ heap_ids = hr.heap_ids_tab + i * k;
71
+ heap_heapify<C> (k, heap_dis, heap_ids);
72
+ thresh = heap_dis[0];
73
+ }
74
+
75
+ /// add one result for query i
76
+ void add_result(T dis, TI idx) {
77
+ if (C::cmp(heap_dis[0], dis)) {
78
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
79
+ thresh = heap_dis[0];
80
+ }
81
+ }
82
+
83
+ /// series of results for query i is done
84
+ void end() {
85
+ heap_reorder<C> (k, heap_dis, heap_ids);
86
+ }
87
+ };
88
+
89
+
90
+ /******************************************************
91
+ * API for multiple results (called from 1 thread)
92
+ */
93
+
94
+ size_t i0, i1;
95
+
96
+ /// begin
97
+ void begin_multiple(size_t i0, size_t i1) {
98
+ this->i0 = i0;
99
+ this->i1 = i1;
100
+ for(size_t i = i0; i < i1; i++) {
101
+ heap_heapify<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
102
+ }
103
+ }
104
+
105
+ /// add results for query i0..i1 and j0..j1
106
+ void add_results(size_t j0, size_t j1, const T *dis_tab) {
107
+ // maybe parallel for
108
+ for (size_t i = i0; i < i1; i++) {
109
+ T * heap_dis = heap_dis_tab + i * k;
110
+ TI * heap_ids = heap_ids_tab + i * k;
111
+ T thresh = heap_dis[0];
112
+ for (size_t j = j0; j < j1; j++) {
113
+ T dis = *dis_tab++;
114
+ if (C::cmp(thresh, dis)) {
115
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
116
+ thresh = heap_dis[0];
117
+ }
118
+ }
119
+ }
120
+ }
121
+
122
+ /// series of results for queries i0..i1 is done
123
+ void end_multiple() {
124
+ // maybe parallel for
125
+ for(size_t i = i0; i < i1; i++) {
126
+ heap_reorder<C> (k, heap_dis_tab + i * k, heap_ids_tab + i * k);
127
+ }
128
+ }
129
+
130
+ };
131
+
132
+ /*****************************************************************
133
+ * Reservoir result handler
134
+ *
135
+ * A reservoir is a result array of size capacity > n (number of requested
136
+ * results); all results below a threshold are stored in an arbitrary order. When
137
+ * the capacity is reached, a new threshold is chosen by partitioning the
138
+ * distance array.
139
+ *****************************************************************/
140
+
141
+
142
+
143
+ /// Reservoir for a single query
144
+ template<class C>
145
+ struct ReservoirTopN {
146
+ using T = typename C::T;
147
+ using TI = typename C::TI;
148
+
149
+ T *vals;
150
+ TI *ids;
151
+
152
+ size_t i; // number of stored elements
153
+ size_t n; // number of requested elements
154
+ size_t capacity; // size of storage
155
+
156
+ T threshold; // current threshold
157
+
158
+ ReservoirTopN() {}
159
+
160
+ ReservoirTopN(
161
+ size_t n, size_t capacity,
162
+ T *vals, TI *ids
163
+ ):
164
+ vals(vals), ids(ids),
165
+ i(0), n(n), capacity(capacity) {
166
+ assert(n < capacity);
167
+ threshold = C::neutral();
168
+ }
169
+
170
+ void add(T val, TI id) {
171
+ if (C::cmp(threshold, val)) {
172
+ if (i == capacity) {
173
+ shrink_fuzzy();
174
+ }
175
+ vals[i] = val;
176
+ ids[i] = id;
177
+ i++;
178
+ }
179
+ }
180
+
181
+ // reduce storage from capacity to anything
182
+ // between n and (capacity + n) / 2
183
+ void shrink_fuzzy() {
184
+ assert(i == capacity);
185
+
186
+ threshold = partition_fuzzy<C>(
187
+ vals, ids, capacity, n, (capacity + n) / 2,
188
+ &i);
189
+ }
190
+
191
+ void to_result(T *heap_dis, TI *heap_ids) const {
192
+
193
+ for (int j = 0; j < std::min(i, n); j++) {
194
+ heap_push<C>(
195
+ j + 1, heap_dis, heap_ids,
196
+ vals[j], ids[j]
197
+ );
198
+ }
199
+
200
+ if (i < n) {
201
+ heap_reorder<C> (i, heap_dis, heap_ids);
202
+ // add empty results
203
+ heap_heapify<C> (n - i, heap_dis + i, heap_ids + i);
204
+ } else {
205
+ // add remaining elements
206
+ heap_addn<C> (n, heap_dis, heap_ids, vals + n, ids + n, i - n);
207
+ heap_reorder<C> (n, heap_dis, heap_ids);
208
+ }
209
+
210
+ }
211
+
212
+ };
213
+
214
+
215
+
216
+ template<class C>
217
+ struct ReservoirResultHandler {
218
+
219
+ using T = typename C::T;
220
+ using TI = typename C::TI;
221
+
222
+ int nq;
223
+ T *heap_dis_tab;
224
+ TI *heap_ids_tab;
225
+
226
+ int64_t k; // number of results to keep
227
+ size_t capacity; // capacity of the reservoirs
228
+
229
+ ReservoirResultHandler(
230
+ size_t nq,
231
+ T * heap_dis_tab, TI * heap_ids_tab,
232
+ size_t k):
233
+ nq(nq),
234
+ heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
235
+ {
236
+ // double then round up to multiple of 16 (for SIMD alignment)
237
+ capacity = (2 * k + 15) & ~15;
238
+ }
239
+
240
+ /******************************************************
241
+ * API for 1 result at a time (each SingleResultHandler is
242
+ * called from 1 thread)
243
+ */
244
+
245
+ struct SingleResultHandler {
246
+ ReservoirResultHandler & hr;
247
+
248
+ std::vector<T> reservoir_dis;
249
+ std::vector<TI> reservoir_ids;
250
+ ReservoirTopN<C> res1;
251
+
252
+ SingleResultHandler(ReservoirResultHandler &hr):
253
+ hr(hr), reservoir_dis(hr.capacity), reservoir_ids(hr.capacity)
254
+ {
255
+ }
256
+
257
+ size_t i;
258
+
259
+ /// begin results for query # i
260
+ void begin(size_t i) {
261
+ res1 = ReservoirTopN<C>(
262
+ hr.k, hr.capacity, reservoir_dis.data(), reservoir_ids.data());
263
+ this->i = i;
264
+ }
265
+
266
+ /// add one result for query i
267
+ void add_result(T dis, TI idx) {
268
+ res1.add(dis, idx);
269
+ }
270
+
271
+ /// series of results for query i is done
272
+ void end() {
273
+ T * heap_dis = hr.heap_dis_tab + i * hr.k;
274
+ TI * heap_ids = hr.heap_ids_tab + i * hr.k;
275
+ res1.to_result(heap_dis, heap_ids);
276
+ }
277
+ };
278
+
279
+ /******************************************************
280
+ * API for multiple results (called from 1 thread)
281
+ */
282
+
283
+ size_t i0, i1;
284
+
285
+ std::vector<T> reservoir_dis;
286
+ std::vector<TI> reservoir_ids;
287
+ std::vector<ReservoirTopN<C>> reservoirs;
288
+
289
+ /// begin
290
+ void begin_multiple(size_t i0, size_t i1) {
291
+ this->i0 = i0;
292
+ this->i1 = i1;
293
+ reservoir_dis.resize((i1 - i0) * capacity);
294
+ reservoir_ids.resize((i1 - i0) * capacity);
295
+ reservoirs.clear();
296
+ for (size_t i = i0; i < i1; i++) {
297
+ reservoirs.emplace_back(
298
+ k, capacity,
299
+ reservoir_dis.data() + (i - i0) * capacity,
300
+ reservoir_ids.data() + (i - i0) * capacity
301
+ );
302
+ }
303
+ }
304
+
305
+ /// add results for query i0..i1 and j0..j1
306
+ void add_results(size_t j0, size_t j1, const T *dis_tab) {
307
+ // maybe parallel for
308
+ for (size_t i = i0; i < i1; i++) {
309
+ ReservoirTopN<C> & reservoir = reservoirs[i - i0];
310
+ for (size_t j = j0; j < j1; j++) {
311
+ T dis = *dis_tab++;
312
+ reservoir.add(dis, j);
313
+ }
314
+ }
315
+ }
316
+
317
+ /// series of results for queries i0..i1 is done
318
+ void end_multiple() {
319
+ // maybe parallel for
320
+ for(size_t i = i0; i < i1; i++) {
321
+ reservoirs[i - i0].to_result(
322
+ heap_dis_tab + i * k, heap_ids_tab + i * k);
323
+ }
324
+ }
325
+
326
+ };
327
+
328
+
329
+ /*****************************************************************
330
+ * Result handler for range searches
331
+ *****************************************************************/
332
+
333
+
334
+
335
+ template<class C>
336
+ struct RangeSearchResultHandler {
337
+ using T = typename C::T;
338
+ using TI = typename C::TI;
339
+
340
+ RangeSearchResult *res;
341
+ float radius;
342
+
343
+ RangeSearchResultHandler(RangeSearchResult *res, float radius):
344
+ res(res), radius(radius)
345
+ {}
346
+
347
+ /******************************************************
348
+ * API for 1 result at a time (each SingleResultHandler is
349
+ * called from 1 thread)
350
+ ******************************************************/
351
+
352
+ struct SingleResultHandler {
353
+ // almost the same interface as RangeSearchResultHandler
354
+ RangeSearchPartialResult pres;
355
+ float radius;
356
+ RangeQueryResult *qr = nullptr;
357
+
358
+ SingleResultHandler(RangeSearchResultHandler &rh):
359
+ pres(rh.res), radius(rh.radius)
360
+ {}
361
+
362
+ /// begin results for query # i
363
+ void begin(size_t i) {
364
+ qr = &pres.new_result(i);
365
+ }
366
+
367
+ /// add one result for query i
368
+ void add_result(T dis, TI idx) {
369
+
370
+ if (C::cmp(radius, dis)) {
371
+ qr->add(dis, idx);
372
+ }
373
+ }
374
+
375
+ /// series of results for query i is done
376
+ void end() {
377
+ }
378
+
379
+ ~SingleResultHandler() {
380
+ pres.finalize();
381
+ }
382
+ };
383
+
384
+ /******************************************************
385
+ * API for multiple results (called from 1 thread)
386
+ ******************************************************/
387
+
388
+ size_t i0, i1;
389
+
390
+ std::vector <RangeSearchPartialResult *> partial_results;
391
+ std::vector <size_t> j0s;
392
+ int pr = 0;
393
+
394
+ /// begin
395
+ void begin_multiple(size_t i0, size_t i1) {
396
+ this->i0 = i0;
397
+ this->i1 = i1;
398
+ }
399
+
400
+ /// add results for query i0..i1 and j0..j1
401
+
402
+ void add_results(size_t j0, size_t j1, const T *dis_tab) {
403
+ RangeSearchPartialResult *pres;
404
+ // there is one RangeSearchPartialResult structure per j0
405
+ // (= block of columns of the large distance matrix)
406
+ // it is a bit tricky to find the poper PartialResult structure
407
+ // because the inner loop is on db not on queries.
408
+
409
+ if (pr < j0s.size() && j0 == j0s[pr]) {
410
+ pres = partial_results[pr];
411
+ pr++;
412
+ } else if (j0 == 0 && j0s.size() > 0) {
413
+ pr = 0;
414
+ pres = partial_results[pr];
415
+ pr++;
416
+ } else { // did not find this j0
417
+ pres = new RangeSearchPartialResult (res);
418
+ partial_results.push_back(pres);
419
+ j0s.push_back(j0);
420
+ pr = partial_results.size();
421
+ }
422
+
423
+ for (size_t i = i0; i < i1; i++) {
424
+ const float *ip_line = dis_tab + (i - i0) * (j1 - j0);
425
+ RangeQueryResult & qres = pres->new_result (i);
426
+
427
+ for (size_t j = j0; j < j1; j++) {
428
+ float dis = *ip_line++;
429
+ if (C::cmp(radius, dis)) {
430
+ qres.add (dis, j);
431
+ }
432
+ }
433
+ }
434
+ }
435
+
436
+ void end_multiple() {
437
+
438
+ }
439
+
440
+ ~RangeSearchResultHandler() {
441
+ if (partial_results.size() > 0) {
442
+ RangeSearchPartialResult::merge (partial_results);
443
+ }
444
+ }
445
+
446
+ };
447
+
448
+
449
+
450
+
451
+ } // namespace faiss
452
+