faiss 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -0,0 +1,367 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <algorithm>
11
+ #include <limits>
12
+ #include <utility>
13
+
14
+ #include <faiss/utils/Heap.h>
15
+ #include <faiss/utils/simdlib.h>
16
+
17
+ namespace faiss {
18
+
19
+ // HeapWithBucketsForHamming32 uses simd8uint32 under the hood.
20
+
21
+ template <typename C, uint32_t NBUCKETS, uint32_t N, typename HammingComputerT>
22
+ struct HeapWithBucketsForHamming32 {
23
+ // this case was not implemented yet.
24
+ };
25
+
26
+ template <uint32_t NBUCKETS, uint32_t N, typename HammingComputerT>
27
+ struct HeapWithBucketsForHamming32<
28
+ CMax<int, int64_t>,
29
+ NBUCKETS,
30
+ N,
31
+ HammingComputerT> {
32
+ static constexpr uint32_t NBUCKETS_8 = NBUCKETS / 8;
33
+ static_assert(
34
+ (NBUCKETS) > 0 && ((NBUCKETS % 8) == 0),
35
+ "Number of buckets needs to be 8, 16, 24, ...");
36
+
37
+ static void addn(
38
+ // number of elements
39
+ const uint32_t n,
40
+ // Hamming computer
41
+ const HammingComputerT& hc,
42
+ // n elements that can be used with hc
43
+ const uint8_t* const __restrict binaryVectors,
44
+ // number of best elements to keep
45
+ const uint32_t k,
46
+ // output distances
47
+ int* const __restrict bh_val,
48
+ // output indices, each being within [0, n) range
49
+ int64_t* const __restrict bh_ids) {
50
+ // forward a call to bs_addn with 1 beam
51
+ bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
52
+ }
53
+
54
+ static void bs_addn(
55
+ // beam_size parameter of Beam Search algorithm
56
+ const uint32_t beam_size,
57
+ // number of elements per beam
58
+ const uint32_t n_per_beam,
59
+ // Hamming computer
60
+ const HammingComputerT& hc,
61
+ // n elements that can be used against hc
62
+ const uint8_t* const __restrict binary_vectors,
63
+ // number of best elements to keep
64
+ const uint32_t k,
65
+ // output distances
66
+ int* const __restrict bh_val,
67
+ // output indices, each being within [0, n_per_beam * beam_size)
68
+ // range
69
+ int64_t* const __restrict bh_ids) {
70
+ //
71
+ using C = CMax<int, int64_t>;
72
+
73
+ // Hamming code size
74
+ const size_t code_size = hc.get_code_size();
75
+
76
+ // main loop
77
+ for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
78
+ simd8uint32 min_distances_i[NBUCKETS_8][N];
79
+ simd8uint32 min_indices_i[NBUCKETS_8][N];
80
+
81
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
82
+ for (uint32_t p = 0; p < N; p++) {
83
+ min_distances_i[j][p] =
84
+ simd8uint32(std::numeric_limits<int32_t>::max());
85
+ min_indices_i[j][p] = simd8uint32(0, 1, 2, 3, 4, 5, 6, 7);
86
+ }
87
+ }
88
+
89
+ simd8uint32 current_indices(0, 1, 2, 3, 4, 5, 6, 7);
90
+ const simd8uint32 indices_delta(NBUCKETS);
91
+
92
+ const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;
93
+
94
+ // put the data into buckets
95
+ for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
96
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
97
+ uint32_t hamming_distances[8];
98
+ for (size_t j8 = 0; j8 < 8; j8++) {
99
+ hamming_distances[j8] = hc.hamming(
100
+ binary_vectors +
101
+ (j8 + j * 8 + ip + n_per_beam * beam_index) *
102
+ code_size);
103
+ }
104
+
105
+ // loop. Compiler should get rid of unneeded ops
106
+ simd8uint32 distance_candidate;
107
+ distance_candidate.loadu(hamming_distances);
108
+ simd8uint32 indices_candidate = current_indices;
109
+
110
+ for (uint32_t p = 0; p < N; p++) {
111
+ simd8uint32 min_distances_new;
112
+ simd8uint32 min_indices_new;
113
+ simd8uint32 max_distances_new;
114
+ simd8uint32 max_indices_new;
115
+
116
+ faiss::cmplt_min_max_fast(
117
+ distance_candidate,
118
+ indices_candidate,
119
+ min_distances_i[j][p],
120
+ min_indices_i[j][p],
121
+ min_distances_new,
122
+ min_indices_new,
123
+ max_distances_new,
124
+ max_indices_new);
125
+
126
+ distance_candidate = max_distances_new;
127
+ indices_candidate = max_indices_new;
128
+
129
+ min_distances_i[j][p] = min_distances_new;
130
+ min_indices_i[j][p] = min_indices_new;
131
+ }
132
+ }
133
+
134
+ current_indices += indices_delta;
135
+ }
136
+
137
+ // fix the indices
138
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
139
+ const simd8uint32 offset(n_per_beam * beam_index + j * 8);
140
+ for (uint32_t p = 0; p < N; p++) {
141
+ min_indices_i[j][p] += offset;
142
+ }
143
+ }
144
+
145
+ // merge every bucket into the regular heap
146
+ for (uint32_t p = 0; p < N; p++) {
147
+ for (uint32_t j = 0; j < NBUCKETS_8; j++) {
148
+ uint32_t min_indices_scalar[8];
149
+ uint32_t min_distances_scalar[8];
150
+
151
+ min_indices_i[j][p].storeu(min_indices_scalar);
152
+ min_distances_i[j][p].storeu(min_distances_scalar);
153
+
154
+ // this exact way is needed to maintain the order as if the
155
+ // input elements were pushed to the heap sequentially
156
+ for (size_t j8 = 0; j8 < 8; j8++) {
157
+ const auto value = min_distances_scalar[j8];
158
+ const auto index = min_indices_scalar[j8];
159
+
160
+ if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
161
+ heap_replace_top<C>(
162
+ k, bh_val, bh_ids, value, index);
163
+ }
164
+ }
165
+ }
166
+ }
167
+
168
+ // process leftovers
169
+ for (uint32_t ip = nb; ip < n_per_beam; ip++) {
170
+ const auto index = ip + n_per_beam * beam_index;
171
+ const auto value =
172
+ hc.hamming(binary_vectors + (index)*code_size);
173
+
174
+ if (C::cmp(bh_val[0], value)) {
175
+ heap_replace_top<C>(k, bh_val, bh_ids, value, index);
176
+ }
177
+ }
178
+ }
179
+ }
180
+ };
181
+
182
+ // HeapWithBucketsForHamming16 uses simd16uint16 under the hood.
183
+ // Less registers needed in total, so higher values of NBUCKETS/N can be used,
184
+ // but somewhat slower.
185
+ // No more than 32K elements currently, but it can be reorganized a bit
186
+ // to be limited to 32K elements per beam.
187
+
188
+ template <typename C, uint32_t NBUCKETS, uint32_t N, typename HammingComputerT>
189
+ struct HeapWithBucketsForHamming16 {
190
+ // this case was not implemented yet.
191
+ };
192
+
193
+ template <uint32_t NBUCKETS, uint32_t N, typename HammingComputerT>
194
+ struct HeapWithBucketsForHamming16<
195
+ CMax<int, int64_t>,
196
+ NBUCKETS,
197
+ N,
198
+ HammingComputerT> {
199
+ static constexpr uint32_t NBUCKETS_16 = NBUCKETS / 16;
200
+ static_assert(
201
+ (NBUCKETS) > 0 && ((NBUCKETS % 16) == 0),
202
+ "Number of buckets needs to be 16, 32, 48...");
203
+
204
+ static void addn(
205
+ // number of elements
206
+ const uint32_t n,
207
+ // Hamming computer
208
+ const HammingComputerT& hc,
209
+ // n elements that can be used with hc
210
+ const uint8_t* const __restrict binaryVectors,
211
+ // number of best elements to keep
212
+ const uint32_t k,
213
+ // output distances
214
+ int* const __restrict bh_val,
215
+ // output indices, each being within [0, n) range
216
+ int64_t* const __restrict bh_ids) {
217
+ // forward a call to bs_addn with 1 beam
218
+ bs_addn(1, n, hc, binaryVectors, k, bh_val, bh_ids);
219
+ }
220
+
221
+ static void bs_addn(
222
+ // beam_size parameter of Beam Search algorithm
223
+ const uint32_t beam_size,
224
+ // number of elements per beam
225
+ const uint32_t n_per_beam,
226
+ // Hamming computer
227
+ const HammingComputerT& hc,
228
+ // n elements that can be used against hc
229
+ const uint8_t* const __restrict binary_vectors,
230
+ // number of best elements to keep
231
+ const uint32_t k,
232
+ // output distances
233
+ int* const __restrict bh_val,
234
+ // output indices, each being within [0, n_per_beam * beam_size)
235
+ // range
236
+ int64_t* const __restrict bh_ids) {
237
+ //
238
+ using C = CMax<int, int64_t>;
239
+
240
+ // Hamming code size
241
+ const size_t code_size = hc.get_code_size();
242
+
243
+ // main loop
244
+ for (uint32_t beam_index = 0; beam_index < beam_size; beam_index++) {
245
+ simd16uint16 min_distances_i[NBUCKETS_16][N];
246
+ simd16uint16 min_indices_i[NBUCKETS_16][N];
247
+
248
+ for (uint32_t j = 0; j < NBUCKETS_16; j++) {
249
+ for (uint32_t p = 0; p < N; p++) {
250
+ min_distances_i[j][p] =
251
+ simd16uint16(std::numeric_limits<int16_t>::max());
252
+ min_indices_i[j][p] = simd16uint16(
253
+ 0,
254
+ 1,
255
+ 2,
256
+ 3,
257
+ 4,
258
+ 5,
259
+ 6,
260
+ 7,
261
+ 8,
262
+ 9,
263
+ 10,
264
+ 11,
265
+ 12,
266
+ 13,
267
+ 14,
268
+ 15);
269
+ }
270
+ }
271
+
272
+ simd16uint16 current_indices(
273
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
274
+ const simd16uint16 indices_delta((uint16_t)NBUCKETS);
275
+
276
+ const uint32_t nb = (n_per_beam / NBUCKETS) * NBUCKETS;
277
+
278
+ // put the data into buckets
279
+ for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
280
+ for (uint32_t j = 0; j < NBUCKETS_16; j++) {
281
+ uint16_t hamming_distances[16];
282
+ for (size_t j16 = 0; j16 < 16; j16++) {
283
+ hamming_distances[j16] = hc.hamming(
284
+ binary_vectors +
285
+ (j16 + j * 16 + ip + n_per_beam * beam_index) *
286
+ code_size);
287
+ }
288
+
289
+ // loop. Compiler should get rid of unneeded ops
290
+ simd16uint16 distance_candidate;
291
+ distance_candidate.loadu(hamming_distances);
292
+ simd16uint16 indices_candidate = current_indices;
293
+
294
+ for (uint32_t p = 0; p < N; p++) {
295
+ simd16uint16 min_distances_new;
296
+ simd16uint16 min_indices_new;
297
+ simd16uint16 max_distances_new;
298
+ simd16uint16 max_indices_new;
299
+
300
+ faiss::cmplt_min_max_fast(
301
+ distance_candidate,
302
+ indices_candidate,
303
+ min_distances_i[j][p],
304
+ min_indices_i[j][p],
305
+ min_distances_new,
306
+ min_indices_new,
307
+ max_distances_new,
308
+ max_indices_new);
309
+
310
+ distance_candidate = max_distances_new;
311
+ indices_candidate = max_indices_new;
312
+
313
+ min_distances_i[j][p] = min_distances_new;
314
+ min_indices_i[j][p] = min_indices_new;
315
+ }
316
+ }
317
+
318
+ current_indices += indices_delta;
319
+ }
320
+
321
+ // fix the indices
322
+ for (uint32_t j = 0; j < NBUCKETS_16; j++) {
323
+ const simd16uint16 offset(
324
+ (uint16_t)(n_per_beam * beam_index + j * 16));
325
+ for (uint32_t p = 0; p < N; p++) {
326
+ min_indices_i[j][p] += offset;
327
+ }
328
+ }
329
+
330
+ // merge every bucket into the regular heap
331
+ for (uint32_t p = 0; p < N; p++) {
332
+ for (uint32_t j = 0; j < NBUCKETS_16; j++) {
333
+ uint16_t min_indices_scalar[16];
334
+ uint16_t min_distances_scalar[16];
335
+
336
+ min_indices_i[j][p].storeu(min_indices_scalar);
337
+ min_distances_i[j][p].storeu(min_distances_scalar);
338
+
339
+ // this exact way is needed to maintain the order as if the
340
+ // input elements were pushed to the heap sequentially
341
+ for (size_t j16 = 0; j16 < 16; j16++) {
342
+ const auto value = min_distances_scalar[j16];
343
+ const auto index = min_indices_scalar[j16];
344
+
345
+ if (C::cmp2(bh_val[0], value, bh_ids[0], index)) {
346
+ heap_replace_top<C>(
347
+ k, bh_val, bh_ids, value, index);
348
+ }
349
+ }
350
+ }
351
+ }
352
+
353
+ // process leftovers
354
+ for (uint32_t ip = nb; ip < n_per_beam; ip++) {
355
+ const auto index = ip + n_per_beam * beam_index;
356
+ const auto value =
357
+ hc.hamming(binary_vectors + (index)*code_size);
358
+
359
+ if (C::cmp(bh_val[0], value)) {
360
+ heap_replace_top<C>(k, bh_val, bh_ids, value, index);
361
+ }
362
+ }
363
+ }
364
+ }
365
+ };
366
+
367
+ } // namespace faiss
@@ -26,6 +26,8 @@
26
26
  #include <faiss/impl/IDSelector.h>
27
27
  #include <faiss/impl/ResultHandler.h>
28
28
 
29
+ #include <faiss/utils/distances_fused/distances_fused.h>
30
+
29
31
  #ifndef FINTEGER
30
32
  #define FINTEGER long
31
33
  #endif
@@ -229,7 +231,7 @@ void exhaustive_inner_product_blas(
229
231
  // distance correction is an operator that can be applied to transform
230
232
  // the distances
231
233
  template <class ResultHandler>
232
- void exhaustive_L2sqr_blas(
234
+ void exhaustive_L2sqr_blas_default_impl(
233
235
  const float* x,
234
236
  const float* y,
235
237
  size_t d,
@@ -311,10 +313,20 @@ void exhaustive_L2sqr_blas(
311
313
  }
312
314
  }
313
315
 
316
+ template <class ResultHandler>
317
+ void exhaustive_L2sqr_blas(
318
+ const float* x,
319
+ const float* y,
320
+ size_t d,
321
+ size_t nx,
322
+ size_t ny,
323
+ ResultHandler& res,
324
+ const float* y_norms = nullptr) {
325
+ exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
326
+ }
327
+
314
328
  #ifdef __AVX2__
315
- // an override for AVX2 if only a single closest point is needed.
316
- template <>
317
- void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
329
+ void exhaustive_L2sqr_blas_cmax_avx2(
318
330
  const float* x,
319
331
  const float* y,
320
332
  size_t d,
@@ -513,11 +525,53 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
513
525
  res.add_result(i, current_min_distance, current_min_index);
514
526
  }
515
527
  }
528
+ // Does nothing for SingleBestResultHandler, but
529
+ // keeping the call for the consistency.
530
+ res.end_multiple();
516
531
  InterruptCallback::check();
517
532
  }
518
533
  }
519
534
  #endif
520
535
 
536
+ // an override if only a single closest point is needed
537
+ template <>
538
+ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
539
+ const float* x,
540
+ const float* y,
541
+ size_t d,
542
+ size_t nx,
543
+ size_t ny,
544
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
545
+ const float* y_norms) {
546
+ #if defined(__AVX2__)
547
+ // use a faster fused kernel if available
548
+ if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
549
+ // the kernel is available and it is complete, we're done.
550
+ return;
551
+ }
552
+
553
+ // run the specialized AVX2 implementation
554
+ exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms);
555
+
556
+ #elif defined(__aarch64__)
557
+ // use a faster fused kernel if available
558
+ if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
559
+ // the kernel is available and it is complete, we're done.
560
+ return;
561
+ }
562
+
563
+ // run the default implementation
564
+ exhaustive_L2sqr_blas_default_impl<
565
+ SingleBestResultHandler<CMax<float, int64_t>>>(
566
+ x, y, d, nx, ny, res, y_norms);
567
+ #else
568
+ // run the default implementation
569
+ exhaustive_L2sqr_blas_default_impl<
570
+ SingleBestResultHandler<CMax<float, int64_t>>>(
571
+ x, y, d, nx, ny, res, y_norms);
572
+ #endif
573
+ }
574
+
521
575
  template <class ResultHandler>
522
576
  void knn_L2sqr_select(
523
577
  const float* x,
@@ -770,7 +824,7 @@ void pairwise_indexed_L2sqr(
770
824
  const float* y,
771
825
  const int64_t* iy,
772
826
  float* dis) {
773
- #pragma omp parallel for
827
+ #pragma omp parallel for if (n > 1)
774
828
  for (int64_t j = 0; j < n; j++) {
775
829
  if (ix[j] >= 0 && iy[j] >= 0) {
776
830
  dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
@@ -786,7 +840,7 @@ void pairwise_indexed_inner_product(
786
840
  const float* y,
787
841
  const int64_t* iy,
788
842
  float* dis) {
789
- #pragma omp parallel for
843
+ #pragma omp parallel for if (n > 1)
790
844
  for (int64_t j = 0; j < n; j++) {
791
845
  if (ix[j] >= 0 && iy[j] >= 0) {
792
846
  dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
@@ -887,7 +941,7 @@ void pairwise_L2sqr(
887
941
  // store in beginning of distance matrix to avoid malloc
888
942
  float* b_norms = dis;
889
943
 
890
- #pragma omp parallel for
944
+ #pragma omp parallel for if (nb > 1)
891
945
  for (int64_t i = 0; i < nb; i++)
892
946
  b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
893
947
 
@@ -73,6 +73,17 @@ void fvec_L2sqr_ny(
73
73
  size_t d,
74
74
  size_t ny);
75
75
 
76
+ /* compute ny square L2 distance between x and a set of transposed contiguous
77
+ y vectors. squared lengths of y should be provided as well */
78
+ void fvec_L2sqr_ny_transposed(
79
+ float* dis,
80
+ const float* x,
81
+ const float* y,
82
+ const float* y_sqlen,
83
+ size_t d,
84
+ size_t d_offset,
85
+ size_t ny);
86
+
76
87
  /* compute ny square L2 distance between x and a set of contiguous y vectors
77
88
  and return the index of the nearest vector.
78
89
  return 0 if ny == 0. */