faiss 0.2.6 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -0,0 +1,84 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // This file contains an implementation of approximate top-k search
9
+ // using heap. It was initially created for a beam search.
10
+ //
11
+ // The core idea is the following.
12
+ // Say we need to find beam_size indices with the minimal distance
13
+ // values. It is done via heap (priority_queue) using the following
14
+ // pseudocode:
15
+ //
16
+ // def baseline():
17
+ // distances = np.empty([beam_size * n], dtype=float)
18
+ // indices = np.empty([beam_size * n], dtype=int)
19
+ //
20
+ // heap = Heap(max_heap_size=beam_size)
21
+ //
22
+ // for i in range(0, beam_size * n):
23
+ // heap.push(distances[i], indices[i])
24
+ //
25
+ // Basically, this is what heap_addn() function from utils/Heap.h does.
26
+ //
27
+ // The following scheme can be used for approximate beam search.
28
+ // Say, we need to find elements with min distance.
29
+ // Basically, we split n elements of every beam into NBUCKETS buckets
30
+ // and track the index with the minimal distance for every bucket.
31
+ // This can be effectively SIMD-ed and significantly lowers the number
32
+ // of operations, but yields approximate results for beam_size >= 2.
33
+ //
34
+ // def approximate_v1():
35
+ // distances = np.empty([beam_size * n], dtype=float)
36
+ // indices = np.empty([beam_size * n], dtype=int)
37
+ //
38
+ // heap = Heap(max_heap_size=beam_size)
39
+ //
40
+ // for beam in range(0, beam_size):
41
+ // # The value of 32 is just an example.
42
+ // # The value may be varied: the larger the value, the slower
43
+ // # and the more precise (vs. the baseline beam search) it is.
44
+ // NBUCKETS = 32
45
+ //
46
+ // local_min_distances = [HUGE_VALF] * NBUCKETS
47
+ // local_min_indices = [0] * NBUCKETS
48
+ //
49
+ // for i in range(0, n // NBUCKETS):
50
+ // for j in range(0, NBUCKETS):
51
+ // idx = beam * n + i * NBUCKETS + j
52
+ // if distances[idx] < local_min_distances[j]:
53
+ // local_min_distances[j] = distances[idx]
54
+ // local_min_indices[j] = indices[idx]
55
+ //
56
+ // for j in range(0, NBUCKETS):
57
+ // heap.push(local_min_distances[j], local_min_indices[j])
58
+ //
59
+ // The accuracy can be improved by tracking min-2 elements for every
60
+ // bucket. Such a min-2 implementation with NBUCKETS buckets provides
61
+ // better accuracy than top-1 implementation with 2 * NBUCKETS buckets.
62
+ // Min-3 is also doable. One can use min-N approach, but I'm not sure
63
+ // whether min-4 and above are practical, because of the lack of SIMD
64
+ // registers (unless AVX-512 version is used).
65
+ //
66
+ // A C++ template for the top-N implementation is provided. The code
67
+ // assumes that indices[idx] == idx. Code that lifts this assumption
68
+ // is easy to write.
69
+ //
70
+ // Currently, only tracking of the elements with min distances is implemented
71
+ // (Max Heap). A Min Heap option can be added easily.
72
+
73
+ #pragma once
74
+
75
+ #include <faiss/impl/platform_macros.h>
76
+
77
+ // the list of available modes is in the following file
78
+ #include <faiss/utils/approx_topk/mode.h>
79
+
80
+ #ifdef __AVX2__
81
+ #include <faiss/utils/approx_topk/avx2-inl.h>
82
+ #else
83
+ #include <faiss/utils/approx_topk/generic.h>
84
+ #endif
@@ -0,0 +1,196 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <immintrin.h>
11
+
12
+ #include <limits>
13
+
14
+ #include <faiss/impl/FaissAssert.h>
15
+ #include <faiss/utils/Heap.h>
16
+
17
+ namespace faiss {
18
+
19
+ template <typename C, uint32_t NBUCKETS, uint32_t N>
20
+ struct HeapWithBuckets {
21
+ // this case was not implemented yet.
22
+ };
23
+
24
+ template <uint32_t NBUCKETS, uint32_t N>
25
+ struct HeapWithBuckets<CMax<float, int>, NBUCKETS, N> {
26
+ static constexpr uint32_t NBUCKETS_8 = NBUCKETS / 8;
27
+ static_assert(
28
+ (NBUCKETS) > 0 && ((NBUCKETS % 8) == 0),
29
+ "Number of buckets needs to be 8, 16, 24, ...");
30
+
31
+ static void addn(
32
+ // number of elements
33
+ const uint32_t n,
34
+ // distances. It is assumed to have n elements.
35
+ const float* const __restrict distances,
36
+ // number of best elements to keep
37
+ const uint32_t k,
38
+ // output distances
39
+ float* const __restrict bh_val,
40
+ // output indices, each being within [0, n) range
41
+ int32_t* const __restrict bh_ids) {
42
+ // forward a call to bs_addn with 1 beam
43
+ bs_addn(1, n, distances, k, bh_val, bh_ids);
44
+ }
45
+
46
+ static void bs_addn(
47
+ // beam_size parameter of Beam Search algorithm
48
+ const uint32_t beam_size,
49
+ // number of elements per beam
50
+ const uint32_t n_per_beam,
51
+ // distances. It is assumed to have (n_per_beam * beam_size)
52
+ // elements.
53
+ const float* const __restrict distances,
54
+ // number of best elements to keep
55
+ const uint32_t k,
56
+ // output distances
57
+ float* const __restrict bh_val,
58
+ // output indices, each being within [0, n_per_beam * beam_size)
59
+ // range
60
+ int32_t* const __restrict bh_ids) {
61
+ // // Basically, the function runs beam_size iterations.
62
+ // // Every iteration NBUCKETS * N elements are added to a regular heap.
63
+ // // So, maximum number of added elements is beam_size * NBUCKETS * N.
64
+ // // This number is expected to be less or equal than k.
65
+ // FAISS_THROW_IF_NOT_FMT(
66
+ // beam_size * NBUCKETS * N >= k,
67
+ // "Cannot pick %d elements, only %d. "
68
+ // "Check the function and template arguments values.",
69
+ // k,
70
+ // beam_size * NBUCKETS * N);
71
+
72
+ using C = CMax<float, int>;
73
+
74
+ // main loop
75
+ for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
76
+ __m256 min_distances_i[NBUCKETS_8][N];
77
+ __m256i min_indices_i[NBUCKETS_8][N];
78
+
79
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
80
+ for (uint32_t p = 0; p < N; p++) {
81
+ min_distances_i[j][p] =
82
+ _mm256_set1_ps(std::numeric_limits<float>::max());
83
+ min_indices_i[j][p] =
84
+ _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
85
+ }
86
+ }
87
+
88
+ __m256i current_indices = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
89
+ __m256i indices_delta = _mm256_set1_epi32(NBUCKETS);
90
+
91
+ const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;
92
+
93
+ // put the data into buckets
94
+ for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
95
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
96
+ const __m256 distances_reg = _mm256_loadu_ps(
97
+ distances + j * 8 + ip + n_per_beam * beam_index);
98
+
99
+ // loop. Compiler should get rid of unneeded ops
100
+ __m256 distance_candidate = distances_reg;
101
+ __m256i indices_candidate = current_indices;
102
+
103
+ for (uint32_t p = 0; p < N; p++) {
104
+ const __m256 comparison = _mm256_cmp_ps(
105
+ min_distances_i[j][p],
106
+ distance_candidate,
107
+ _CMP_LE_OS);
108
+
109
+ // // blend seems to be slower that min
110
+ // const __m256 min_distances_new = _mm256_blendv_ps(
111
+ // distance_candidate,
112
+ // min_distances_i[j][p],
113
+ // comparison);
114
+ const __m256 min_distances_new = _mm256_min_ps(
115
+ distance_candidate, min_distances_i[j][p]);
116
+ const __m256i min_indices_new =
117
+ _mm256_castps_si256(_mm256_blendv_ps(
118
+ _mm256_castsi256_ps(indices_candidate),
119
+ _mm256_castsi256_ps(
120
+ min_indices_i[j][p]),
121
+ comparison));
122
+
123
+ // // blend seems to be slower that min
124
+ // const __m256 max_distances_new = _mm256_blendv_ps(
125
+ // min_distances_i[j][p],
126
+ // distance_candidate,
127
+ // comparison);
128
+ const __m256 max_distances_new = _mm256_max_ps(
129
+ min_distances_i[j][p], distances_reg);
130
+ const __m256i max_indices_new =
131
+ _mm256_castps_si256(_mm256_blendv_ps(
132
+ _mm256_castsi256_ps(
133
+ min_indices_i[j][p]),
134
+ _mm256_castsi256_ps(indices_candidate),
135
+ comparison));
136
+
137
+ distance_candidate = max_distances_new;
138
+ indices_candidate = max_indices_new;
139
+
140
+ min_distances_i[j][p] = min_distances_new;
141
+ min_indices_i[j][p] = min_indices_new;
142
+ }
143
+ }
144
+
145
+ current_indices =
146
+ _mm256_add_epi32(current_indices, indices_delta);
147
+ }
148
+
149
+ // fix the indices
150
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
151
+ const __m256i offset =
152
+ _mm256_set1_epi32(n_per_beam * beam_index + j * 8);
153
+ for (uint32_t p = 0; p < N; p++) {
154
+ min_indices_i[j][p] =
155
+ _mm256_add_epi32(min_indices_i[j][p], offset);
156
+ }
157
+ }
158
+
159
+ // merge every bucket into the regular heap
160
+ for (uint32_t p = 0; p < N; p++) {
161
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
162
+ int32_t min_indices_scalar[8];
163
+ float min_distances_scalar[8];
164
+
165
+ _mm256_storeu_si256(
166
+ (__m256i*)min_indices_scalar, min_indices_i[j][p]);
167
+ _mm256_storeu_ps(
168
+ min_distances_scalar, min_distances_i[j][p]);
169
+
170
+ // this exact way is needed to maintain the order as if the
171
+ // input elements were pushed to the heap sequentially
172
+ for (size_t j8 = 0; j8 < 8; j8++) {
173
+ const auto value = min_distances_scalar[j8];
174
+ const auto index = min_indices_scalar[j8];
175
+ if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
176
+ heap_replace_top<C>(
177
+ k, bh_val, bh_ids, value, index);
178
+ }
179
+ }
180
+ }
181
+ }
182
+
183
+ // process leftovers
184
+ for (uint32_t ip = nb; ip < n_per_beam; ip++) {
185
+ const int32_t index = ip + n_per_beam * beam_index;
186
+ const float value = distances[index];
187
+
188
+ if (C::cmp(bh_val[0], value)) {
189
+ heap_replace_top<C>(k, bh_val, bh_ids, value, index);
190
+ }
191
+ }
192
+ }
193
+ }
194
+ };
195
+
196
+ } // namespace faiss
@@ -0,0 +1,138 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <algorithm>
11
+ #include <limits>
12
+ #include <utility>
13
+
14
+ #include <faiss/impl/FaissAssert.h>
15
+ #include <faiss/utils/Heap.h>
16
+
17
+ namespace faiss {
18
+
19
+ // This is the implementation of the idea and it is very slow,
20
+ // because a compiler is unable to vectorize it properly.
21
+
22
+ template <typename C, uint32_t NBUCKETS, uint32_t N>
23
+ struct HeapWithBuckets {
24
+ // this case was not implemented yet.
25
+ };
26
+
27
+ template <uint32_t NBUCKETS, uint32_t N>
28
+ struct HeapWithBuckets<CMax<float, int>, NBUCKETS, N> {
29
+ static void addn(
30
+ // number of elements
31
+ const uint32_t n,
32
+ // distances. It is assumed to have n elements.
33
+ const float* const __restrict distances,
34
+ // number of best elements to keep
35
+ const uint32_t k,
36
+ // output distances
37
+ float* const __restrict bh_val,
38
+ // output indices, each being within [0, n) range
39
+ int32_t* const __restrict bh_ids) {
40
+ // forward a call to bs_addn with 1 beam
41
+ bs_addn(1, n, distances, k, bh_val, bh_ids);
42
+ }
43
+
44
+ static void bs_addn(
45
+ // beam_size parameter of Beam Search algorithm
46
+ const uint32_t beam_size,
47
+ // number of elements per beam
48
+ const uint32_t n_per_beam,
49
+ // distances. It is assumed to have (n_per_beam * beam_size)
50
+ // elements.
51
+ const float* const __restrict distances,
52
+ // number of best elements to keep
53
+ const uint32_t k,
54
+ // output distances
55
+ float* const __restrict bh_val,
56
+ // output indices, each being within [0, n_per_beam * beam_size)
57
+ // range
58
+ int32_t* const __restrict bh_ids) {
59
+ // // Basically, the function runs beam_size iterations.
60
+ // // Every iteration NBUCKETS * N elements are added to a regular heap.
61
+ // // So, maximum number of added elements is beam_size * NBUCKETS * N.
62
+ // // This number is expected to be less or equal than k.
63
+ // FAISS_THROW_IF_NOT_FMT(
64
+ // beam_size * NBUCKETS * N >= k,
65
+ // "Cannot pick %d elements, only %d. "
66
+ // "Check the function and template arguments values.",
67
+ // k,
68
+ // beam_size * NBUCKETS * N);
69
+
70
+ using C = CMax<float, int>;
71
+
72
+ // main loop
73
+ for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
74
+ float min_distances_i[N][NBUCKETS];
75
+ int min_indices_i[N][NBUCKETS];
76
+
77
+ for (uint32_t p = 0; p < N; p++) {
78
+ for (uint32_t j = 0; j < NBUCKETS; j++) {
79
+ min_distances_i[p][j] = std::numeric_limits<float>::max();
80
+ min_indices_i[p][j] = 0;
81
+ }
82
+ }
83
+
84
+ const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;
85
+
86
+ // put the data into buckets
87
+ for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
88
+ for (uint32_t j = 0; j < NBUCKETS; j++) {
89
+ const int index = j + ip + n_per_beam * beam_index;
90
+ const float distance = distances[index];
91
+
92
+ int index_candidate = index;
93
+ float distance_candidate = distance;
94
+
95
+ for (uint32_t p = 0; p < N; p++) {
96
+ if (distance_candidate < min_distances_i[p][j]) {
97
+ std::swap(
98
+ distance_candidate, min_distances_i[p][j]);
99
+ std::swap(index_candidate, min_indices_i[p][j]);
100
+ }
101
+ }
102
+ }
103
+ }
104
+
105
+ // merge every bucket into the regular heap
106
+ for (uint32_t p = 0; p < N; p++) {
107
+ for (uint32_t j = 0; j < NBUCKETS; j++) {
108
+ // this exact way is needed to maintain the order as if the
109
+ // input elements were pushed to the heap sequentially
110
+
111
+ if (C::cmp2(bh_val[0],
112
+ min_distances_i[p][j],
113
+ bh_ids[0],
114
+ min_indices_i[p][j])) {
115
+ heap_replace_top<C>(
116
+ k,
117
+ bh_val,
118
+ bh_ids,
119
+ min_distances_i[p][j],
120
+ min_indices_i[p][j]);
121
+ }
122
+ }
123
+ }
124
+
125
+ // process leftovers
126
+ for (uint32_t ip = nb; ip < n_per_beam; ip++) {
127
+ const int32_t index = ip + n_per_beam * beam_index;
128
+ const float value = distances[index];
129
+
130
+ if (C::cmp(bh_val[0], value)) {
131
+ heap_replace_top<C>(k, bh_val, bh_ids, value, index);
132
+ }
133
+ }
134
+ }
135
+ }
136
+ };
137
+
138
+ } // namespace faiss
@@ -0,0 +1,34 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ /// Represents the mode of use of approximate top-k computations
11
+ /// that allows trading accuracy for speed. So, every option
12
+ /// besides EXACT_TOPK increases the speed.
13
+ ///
14
+ /// B represents the number of buckets.
15
+ /// D is the number of min-k elements to track within every bucket.
16
+ ///
17
+ /// Default option is EXACT_TOPK.
18
+ /// APPROX_TOPK_BUCKETS_B16_D2 is worth starting from, if you'd like
19
+ /// to experiment a bit.
20
+ ///
21
+ /// It seems that only a limited number of combinations is
22
+ /// meaningful, because of the limited supply of SIMD registers.
23
+ /// Also, certain combinations, such as B32_D1 and B16_D1, were found
24
+ /// to be not very precise in benchmarks, so they were not introduced.
25
+ ///
26
+ /// TODO: Consider d-ary SIMD heap.
27
+
28
// Naming scheme: APPROX_TOPK_BUCKETS_B<buckets>_D<tracked elements per bucket>.
enum ApproxTopK_mode_t : int {
    /// Exact top-k computation (the default option).
    EXACT_TOPK = 0,
    /// 32 buckets, 2 minimal elements tracked within every bucket.
    APPROX_TOPK_BUCKETS_B32_D2 = 1,
    /// 8 buckets, 3 minimal elements tracked within every bucket.
    APPROX_TOPK_BUCKETS_B8_D3 = 2,
    /// 16 buckets, 2 minimal elements tracked within every bucket.
    APPROX_TOPK_BUCKETS_B16_D2 = 3,
    /// 8 buckets, 2 minimal elements tracked within every bucket.
    APPROX_TOPK_BUCKETS_B8_D2 = 4
};