faiss 0.2.6 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -0,0 +1,367 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <algorithm>
11
+ #include <limits>
12
+ #include <utility>
13
+
14
+ #include <faiss/utils/Heap.h>
15
+ #include <faiss/utils/simdlib.h>
16
+
17
+ namespace faiss {
18
+
19
+ // HeapWithBucketsForHamming32 uses simd8uint32 under the hood.
20
+
21
+ template <typename C, uint32_t NBUCKETS, uint32_t N, typename HammingComputerT>
22
+ struct HeapWithBucketsForHamming32 {
23
+ // this case was not implemented yet.
24
+ };
25
+
26
+ template <uint32_t NBUCKETS, uint32_t N, typename HammingComputerT>
27
+ struct HeapWithBucketsForHamming32<
28
+ CMax<int, int64_t>,
29
+ NBUCKETS,
30
+ N,
31
+ HammingComputerT> {
32
+ static constexpr uint32_t NBUCKETS_8 = NBUCKETS / 8;
33
+ static_assert(
34
+ (NBUCKETS) > 0 && ((NBUCKETS % 8) == 0),
35
+ "Number of buckets needs to be 8, 16, 24, ...");
36
+
37
+ static void addn(
38
+ // number of elements
39
+ const uint32_t n,
40
+ // Hamming computer
41
+ const HammingComputerT& hc,
42
+ // n elements that can be used with hc
43
+ const uint8_t* const __restrict binaryVectors,
44
+ // number of best elements to keep
45
+ const uint32_t k,
46
+ // output distances
47
+ int* const __restrict bh_val,
48
+ // output indices, each being within [0, n) range
49
+ int64_t* const __restrict bh_ids) {
50
+ // forward a call to bs_addn with 1 beam
51
+ bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
52
+ }
53
+
54
+ static void bs_addn(
55
+ // beam_size parameter of Beam Search algorithm
56
+ const uint32_t beam_size,
57
+ // number of elements per beam
58
+ const uint32_t n_per_beam,
59
+ // Hamming computer
60
+ const HammingComputerT& hc,
61
+ // n elements that can be used against hc
62
+ const uint8_t* const __restrict binary_vectors,
63
+ // number of best elements to keep
64
+ const uint32_t k,
65
+ // output distances
66
+ int* const __restrict bh_val,
67
+ // output indices, each being within [0, n_per_beam * beam_size)
68
+ // range
69
+ int64_t* const __restrict bh_ids) {
70
+ //
71
+ using C = CMax<int, int64_t>;
72
+
73
+ // Hamming code size
74
+ const size_t code_size = hc.get_code_size();
75
+
76
+ // main loop
77
+ for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
78
+ simd8uint32 min_distances_i[NBUCKETS_8][N];
79
+ simd8uint32 min_indices_i[NBUCKETS_8][N];
80
+
81
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
82
+ for (uint32_t p = 0; p < N; p++) {
83
+ min_distances_i[j][p] =
84
+ simd8uint32(std::numeric_limits<int32_t>::max());
85
+ min_indices_i[j][p] = simd8uint32(0, 1, 2, 3, 4, 5, 6, 7);
86
+ }
87
+ }
88
+
89
+ simd8uint32 current_indices(0, 1, 2, 3, 4, 5, 6, 7);
90
+ const simd8uint32 indices_delta(NBUCKETS);
91
+
92
+ const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;
93
+
94
+ // put the data into buckets
95
+ for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
96
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
97
+ uint32_t hamming_distances[8];
98
+ for (size_t j8 = 0; j8 < 8; j8++) {
99
+ hamming_distances[j8] = hc.hamming(
100
+ binary_vectors +
101
+ (j8 + j * 8 + ip + n_per_beam * beam_index) *
102
+ code_size);
103
+ }
104
+
105
+ // loop. Compiler should get rid of unneeded ops
106
+ simd8uint32 distance_candidate;
107
+ distance_candidate.loadu(hamming_distances);
108
+ simd8uint32 indices_candidate = current_indices;
109
+
110
+ for (uint32_t p = 0; p < N; p++) {
111
+ simd8uint32 min_distances_new;
112
+ simd8uint32 min_indices_new;
113
+ simd8uint32 max_distances_new;
114
+ simd8uint32 max_indices_new;
115
+
116
+ faiss::cmplt_min_max_fast(
117
+ distance_candidate,
118
+ indices_candidate,
119
+ min_distances_i[j][p],
120
+ min_indices_i[j][p],
121
+ min_distances_new,
122
+ min_indices_new,
123
+ max_distances_new,
124
+ max_indices_new);
125
+
126
+ distance_candidate = max_distances_new;
127
+ indices_candidate = max_indices_new;
128
+
129
+ min_distances_i[j][p] = min_distances_new;
130
+ min_indices_i[j][p] = min_indices_new;
131
+ }
132
+ }
133
+
134
+ current_indices += indices_delta;
135
+ }
136
+
137
+ // fix the indices
138
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
139
+ const simd8uint32 offset(n_per_beam * beam_index + j * 8);
140
+ for (uint32_t p = 0; p < N; p++) {
141
+ min_indices_i[j][p] += offset;
142
+ }
143
+ }
144
+
145
+ // merge every bucket into the regular heap
146
+ for (uint32_t p = 0; p < N; p++) {
147
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
148
+ uint32_t min_indices_scalar[8];
149
+ uint32_t min_distances_scalar[8];
150
+
151
+ min_indices_i[j][p].storeu(min_indices_scalar);
152
+ min_distances_i[j][p].storeu(min_distances_scalar);
153
+
154
+ // this exact way is needed to maintain the order as if the
155
+ // input elements were pushed to the heap sequentially
156
+ for (size_t j8 = 0; j8 < 8; j8++) {
157
+ const auto value = min_distances_scalar[j8];
158
+ const auto index = min_indices_scalar[j8];
159
+
160
+ if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
161
+ heap_replace_top<C>(
162
+ k, bh_val, bh_ids, value, index);
163
+ }
164
+ }
165
+ }
166
+ }
167
+
168
+ // process leftovers
169
+ for (uint32_t ip = nb; ip < n_per_beam; ip++) {
170
+ const auto index = ip + n_per_beam * beam_index;
171
+ const auto value =
172
+ hc.hamming(binary_vectors + (index)*code_size);
173
+
174
+ if (C::cmp(bh_val[0], value)) {
175
+ heap_replace_top<C>(k, bh_val, bh_ids, value, index);
176
+ }
177
+ }
178
+ }
179
+ }
180
+ };
181
+
182
+ // HeapWithBucketsForHamming16 uses simd16uint16 under the hood.
183
+ // Less registers needed in total, so higher values of NBUCKETS/N can be used,
184
+ // but somewhat slower.
185
+ // No more than 32K elements currently, but it can be reorganized a bit
186
+ // to be limited to 32K elements per beam.
187
+
188
+ template <typename C, uint32_t NBUCKETS, uint32_t N, typename HammingComputerT>
189
+ struct HeapWithBucketsForHamming16 {
190
+ // this case was not implemented yet.
191
+ };
192
+
193
+ template <uint32_t NBUCKETS, uint32_t N, typename HammingComputerT>
194
+ struct HeapWithBucketsForHamming16<
195
+ CMax<int, int64_t>,
196
+ NBUCKETS,
197
+ N,
198
+ HammingComputerT> {
199
+ static constexpr uint32_t NBUCKETS_16 = NBUCKETS / 16;
200
+ static_assert(
201
+ (NBUCKETS) > 0 && ((NBUCKETS % 16) == 0),
202
+ "Number of buckets needs to be 16, 32, 48...");
203
+
204
+ static void addn(
205
+ // number of elements
206
+ const uint32_t n,
207
+ // Hamming computer
208
+ const HammingComputerT& hc,
209
+ // n elements that can be used with hc
210
+ const uint8_t* const __restrict binaryVectors,
211
+ // number of best elements to keep
212
+ const uint32_t k,
213
+ // output distances
214
+ int* const __restrict bh_val,
215
+ // output indices, each being within [0, n) range
216
+ int64_t* const __restrict bh_ids) {
217
+ // forward a call to bs_addn with 1 beam
218
+ bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
219
+ }
220
+
221
+ static void bs_addn(
222
+ // beam_size parameter of Beam Search algorithm
223
+ const uint32_t beam_size,
224
+ // number of elements per beam
225
+ const uint32_t n_per_beam,
226
+ // Hamming computer
227
+ const HammingComputerT& hc,
228
+ // n elements that can be used against hc
229
+ const uint8_t* const __restrict binary_vectors,
230
+ // number of best elements to keep
231
+ const uint32_t k,
232
+ // output distances
233
+ int* const __restrict bh_val,
234
+ // output indices, each being within [0, n_per_beam * beam_size)
235
+ // range
236
+ int64_t* const __restrict bh_ids) {
237
+ //
238
+ using C = CMax<int, int64_t>;
239
+
240
+ // Hamming code size
241
+ const size_t code_size = hc.get_code_size();
242
+
243
+ // main loop
244
+ for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
245
+ simd16uint16 min_distances_i[NBUCKETS_16][N];
246
+ simd16uint16 min_indices_i[NBUCKETS_16][N];
247
+
248
+ for (uint32_t j = 0; j < NBUCKETS_16; j++) {
249
+ for (uint32_t p = 0; p < N; p++) {
250
+ min_distances_i[j][p] =
251
+ simd16uint16(std::numeric_limits<int16_t>::max());
252
+ min_indices_i[j][p] = simd16uint16(
253
+ 0,
254
+ 1,
255
+ 2,
256
+ 3,
257
+ 4,
258
+ 5,
259
+ 6,
260
+ 7,
261
+ 8,
262
+ 9,
263
+ 10,
264
+ 11,
265
+ 12,
266
+ 13,
267
+ 14,
268
+ 15);
269
+ }
270
+ }
271
+
272
+ simd16uint16 current_indices(
273
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
274
+ const simd16uint16 indices_delta((uint16_t)NBUCKETS);
275
+
276
+ const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;
277
+
278
+ // put the data into buckets
279
+ for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
280
+ for (uint32_t j = 0; j < NBUCKETS_16; j++) {
281
+ uint16_t hamming_distances[16];
282
+ for (size_t j16 = 0; j16 < 16; j16++) {
283
+ hamming_distances[j16] = hc.hamming(
284
+ binary_vectors +
285
+ (j16 + j * 16 + ip + n_per_beam * beam_index) *
286
+ code_size);
287
+ }
288
+
289
+ // loop. Compiler should get rid of unneeded ops
290
+ simd16uint16 distance_candidate;
291
+ distance_candidate.loadu(hamming_distances);
292
+ simd16uint16 indices_candidate = current_indices;
293
+
294
+ for (uint32_t p = 0; p < N; p++) {
295
+ simd16uint16 min_distances_new;
296
+ simd16uint16 min_indices_new;
297
+ simd16uint16 max_distances_new;
298
+ simd16uint16 max_indices_new;
299
+
300
+ faiss::cmplt_min_max_fast(
301
+ distance_candidate,
302
+ indices_candidate,
303
+ min_distances_i[j][p],
304
+ min_indices_i[j][p],
305
+ min_distances_new,
306
+ min_indices_new,
307
+ max_distances_new,
308
+ max_indices_new);
309
+
310
+ distance_candidate = max_distances_new;
311
+ indices_candidate = max_indices_new;
312
+
313
+ min_distances_i[j][p] = min_distances_new;
314
+ min_indices_i[j][p] = min_indices_new;
315
+ }
316
+ }
317
+
318
+ current_indices += indices_delta;
319
+ }
320
+
321
+ // fix the indices
322
+ for (uint32_t j = 0; j < NBUCKETS_16; j++) {
323
+ const simd16uint16 offset(
324
+ (uint16_t)(n_per_beam * beam_index + j * 16));
325
+ for (uint32_t p = 0; p < N; p++) {
326
+ min_indices_i[j][p] += offset;
327
+ }
328
+ }
329
+
330
+ // merge every bucket into the regular heap
331
+ for (uint32_t p = 0; p < N; p++) {
332
+ for (uint32_t j = 0; j < NBUCKETS_16; j++) {
333
+ uint16_t min_indices_scalar[16];
334
+ uint16_t min_distances_scalar[16];
335
+
336
+ min_indices_i[j][p].storeu(min_indices_scalar);
337
+ min_distances_i[j][p].storeu(min_distances_scalar);
338
+
339
+ // this exact way is needed to maintain the order as if the
340
+ // input elements were pushed to the heap sequentially
341
+ for (size_t j16 = 0; j16 < 16; j16++) {
342
+ const auto value = min_distances_scalar[j16];
343
+ const auto index = min_indices_scalar[j16];
344
+
345
+ if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
346
+ heap_replace_top<C>(
347
+ k, bh_val, bh_ids, value, index);
348
+ }
349
+ }
350
+ }
351
+ }
352
+
353
+ // process leftovers
354
+ for (uint32_t ip = nb; ip < n_per_beam; ip++) {
355
+ const auto index = ip + n_per_beam * beam_index;
356
+ const auto value =
357
+ hc.hamming(binary_vectors + (index)*code_size);
358
+
359
+ if (C::cmp(bh_val[0], value)) {
360
+ heap_replace_top<C>(k, bh_val, bh_ids, value, index);
361
+ }
362
+ }
363
+ }
364
+ }
365
+ };
366
+
367
+ } // namespace faiss
@@ -26,6 +26,8 @@
26
26
  #include <faiss/impl/IDSelector.h>
27
27
  #include <faiss/impl/ResultHandler.h>
28
28
 
29
+ #include <faiss/utils/distances_fused/distances_fused.h>
30
+
29
31
  #ifndef FINTEGER
30
32
  #define FINTEGER long
31
33
  #endif
@@ -229,7 +231,7 @@ void exhaustive_inner_product_blas(
229
231
  // distance correction is an operator that can be applied to transform
230
232
  // the distances
231
233
  template <class ResultHandler>
232
- void exhaustive_L2sqr_blas(
234
+ void exhaustive_L2sqr_blas_default_impl(
233
235
  const float* x,
234
236
  const float* y,
235
237
  size_t d,
@@ -311,10 +313,20 @@ void exhaustive_L2sqr_blas(
311
313
  }
312
314
  }
313
315
 
316
// Dispatcher for the blocked BLAS-based L2sqr search.
// NOTE(review): the original wrapper accepted y_norms but dropped it when
// forwarding, so precomputed squared norms of y were silently recomputed
// by the default implementation. Forward them instead; the default
// implementation demonstrably accepts a trailing y_norms argument (it is
// called with 7 arguments from the SingleBestResultHandler override) and
// a nullptr keeps the previous behavior.
template <class ResultHandler>
void exhaustive_L2sqr_blas(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        ResultHandler& res,
        const float* y_norms = nullptr) {
    exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res, y_norms);
}
327
+
314
328
  #ifdef __AVX2__
315
- // an override for AVX2 if only a single closest point is needed.
316
- template <>
317
- void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
329
+ void exhaustive_L2sqr_blas_cmax_avx2(
318
330
  const float* x,
319
331
  const float* y,
320
332
  size_t d,
@@ -513,11 +525,53 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
513
525
  res.add_result(i, current_min_distance, current_min_index);
514
526
  }
515
527
  }
528
+ // Does nothing for SingleBestResultHandler, but
529
+ // keeping the call for the consistency.
530
+ res.end_multiple();
516
531
  InterruptCallback::check();
517
532
  }
518
533
  }
519
534
  #endif
520
535
 
536
+ // an override if only a single closest point is needed
537
+ template <>
538
+ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
539
+ const float* x,
540
+ const float* y,
541
+ size_t d,
542
+ size_t nx,
543
+ size_t ny,
544
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
545
+ const float* y_norms) {
546
+ #if defined(__AVX2__)
547
+ // use a faster fused kernel if available
548
+ if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
549
+ // the kernel is available and it is complete, we're done.
550
+ return;
551
+ }
552
+
553
+ // run the specialized AVX2 implementation
554
+ exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms);
555
+
556
+ #elif defined(__aarch64__)
557
+ // use a faster fused kernel if available
558
+ if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
559
+ // the kernel is available and it is complete, we're done.
560
+ return;
561
+ }
562
+
563
+ // run the default implementation
564
+ exhaustive_L2sqr_blas_default_impl<
565
+ SingleBestResultHandler<CMax<float, int64_t>>>(
566
+ x, y, d, nx, ny, res, y_norms);
567
+ #else
568
+ // run the default implementation
569
+ exhaustive_L2sqr_blas_default_impl<
570
+ SingleBestResultHandler<CMax<float, int64_t>>>(
571
+ x, y, d, nx, ny, res, y_norms);
572
+ #endif
573
+ }
574
+
521
575
  template <class ResultHandler>
522
576
  void knn_L2sqr_select(
523
577
  const float* x,
@@ -770,7 +824,7 @@ void pairwise_indexed_L2sqr(
770
824
  const float* y,
771
825
  const int64_t* iy,
772
826
  float* dis) {
773
- #pragma omp parallel for
827
+ #pragma omp parallel for if (n > 1)
774
828
  for (int64_t j = 0; j < n; j++) {
775
829
  if (ix[j] >= 0 && iy[j] >= 0) {
776
830
  dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
@@ -786,7 +840,7 @@ void pairwise_indexed_inner_product(
786
840
  const float* y,
787
841
  const int64_t* iy,
788
842
  float* dis) {
789
- #pragma omp parallel for
843
+ #pragma omp parallel for if (n > 1)
790
844
  for (int64_t j = 0; j < n; j++) {
791
845
  if (ix[j] >= 0 && iy[j] >= 0) {
792
846
  dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
@@ -887,7 +941,7 @@ void pairwise_L2sqr(
887
941
  // store in beginning of distance matrix to avoid malloc
888
942
  float* b_norms = dis;
889
943
 
890
- #pragma omp parallel for
944
+ #pragma omp parallel for if (nb > 1)
891
945
  for (int64_t i = 0; i < nb; i++)
892
946
  b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
893
947
 
@@ -73,6 +73,17 @@ void fvec_L2sqr_ny(
73
73
  size_t d,
74
74
  size_t ny);
75
75
 
76
+ /* compute ny square L2 distance between x and a set of transposed contiguous
77
+ y vectors. squared lengths of y should be provided as well */
78
+ void fvec_L2sqr_ny_transposed(
79
+ float* dis,
80
+ const float* x,
81
+ const float* y,
82
+ const float* y_sqlen,
83
+ size_t d,
84
+ size_t d_offset,
85
+ size_t ny);
86
+
76
87
  /* compute ny square L2 distance between x and a set of contiguous y vectors
77
88
  and return the index of the nearest vector.
78
89
  return 0 if ny == 0. */