faiss 0.2.6 → 0.2.7

Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp
@@ -0,0 +1,352 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ // -*- c++ -*-
+
+ #include <faiss/utils/distances_fused/simdlib_based.h>
+
+ #if defined(__AVX2__) || defined(__aarch64__)
+
+ #include <faiss/utils/simdlib.h>
+
+ #if defined(__AVX2__)
+ #include <immintrin.h>
+ #endif
+
+ namespace faiss {
+
+ namespace {
+
+ // It makes sense to like to overload certain cases because the further
+ // kernels are in need of registers. So, let's tell compiler
+ // not to waste registers on a bit faster code, if needed.
+ template <size_t DIM>
+ float l2_sqr(const float* const x) {
+     // compiler should be smart enough to handle that
+     float output = x[0] * x[0];
+     for (size_t i = 1; i < DIM; i++) {
+         output += x[i] * x[i];
+     }
+
+     return output;
+ }
+
+ template <size_t DIM>
+ float dot_product(
+         const float* const __restrict x,
+         const float* const __restrict y) {
+     // compiler should be smart enough to handle that
+     float output = x[0] * y[0];
+     for (size_t i = 1; i < DIM; i++) {
+         output += x[i] * y[i];
+     }
+
+     return output;
+ }
+
+ // The kernel for low dimensionality vectors.
+ // Finds the closest one from y for every given NX_POINTS_PER_LOOP points from x
+ //
+ // DIM is the dimensionality of the data
+ // NX_POINTS_PER_LOOP is the number of x points that get processed
+ // simultaneously.
+ // NY_POINTS_PER_LOOP is the number of y points that get processed
+ // simultaneously.
+ template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
+ void kernel(
+         const float* const __restrict x,
+         const float* const __restrict y,
+         const float* const __restrict y_transposed,
+         const size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* __restrict y_norms,
+         const size_t i) {
+     const size_t ny_p =
+             (ny / (8 * NY_POINTS_PER_LOOP)) * (8 * NY_POINTS_PER_LOOP);
+
+     // compute
+     const float* const __restrict xd_0 = x + i * DIM;
+
+     // prefetch the next point
+ #if defined(__AVX2__)
+     _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA);
+ #endif
+
+     // load a single point from x
+     // load -2 * value
+     simd8float32 x_i[NX_POINTS_PER_LOOP][DIM];
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         for (size_t dd = 0; dd < DIM; dd++) {
+             x_i[nx_k][dd] = simd8float32(-2 * *(xd_0 + nx_k * DIM + dd));
+         }
+     }
+
+     // compute x_norm
+     float x_norm_i[NX_POINTS_PER_LOOP];
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         x_norm_i[nx_k] = l2_sqr<DIM>(xd_0 + nx_k * DIM);
+     }
+
+     // distances and indices
+     simd8float32 min_distances_i[NX_POINTS_PER_LOOP];
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         min_distances_i[nx_k] =
+                 simd8float32(res.dis_tab[i + nx_k] - x_norm_i[nx_k]);
+     }
+
+     simd8uint32 min_indices_i[NX_POINTS_PER_LOOP];
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         min_indices_i[nx_k] = simd8uint32((uint32_t)0);
+     }
+
+     //
+     simd8uint32 current_indices = simd8uint32(0, 1, 2, 3, 4, 5, 6, 7);
+     const simd8uint32 indices_delta = simd8uint32(8);
+
+     // main loop
+     size_t j = 0;
+     for (; j < ny_p; j += NY_POINTS_PER_LOOP * 8) {
+         // compute dot products for NX_POINTS from x and NY_POINTS from y
+         // technically, we're multiplying -2x and y
+         simd8float32 dp_i[NX_POINTS_PER_LOOP][NY_POINTS_PER_LOOP];
+
+         // DIM 0 that uses MUL
+         for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+             simd8float32 y_i =
+                     simd8float32(y_transposed + j + ny_k * 8 + ny * 0);
+             for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                 dp_i[nx_k][ny_k] = x_i[nx_k][0] * y_i;
+             }
+         }
+
+         // other DIMs that use FMA
+         for (size_t dd = 1; dd < DIM; dd++) {
+             for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+                 simd8float32 y_i =
+                         simd8float32(y_transposed + j + ny_k * 8 + ny * dd);
+
+                 for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                     dp_i[nx_k][ny_k] =
+                             fmadd(x_i[nx_k][dd], y_i, dp_i[nx_k][ny_k]);
+                 }
+             }
+         }
+
+         // compute y^2 + (-2x,y)
+         for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+             simd8float32 y_l2_sqr = simd8float32(y_norms + j + ny_k * 8);
+
+             for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                 dp_i[nx_k][ny_k] = dp_i[nx_k][ny_k] + y_l2_sqr;
+             }
+         }
+
+         // do the comparisons and alter the min indices
+         for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+             for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                 // cmpps
+                 cmplt_and_blend_inplace(
+                         dp_i[nx_k][ny_k],
+                         current_indices,
+                         min_distances_i[nx_k],
+                         min_indices_i[nx_k]);
+             }
+
+             current_indices = current_indices + indices_delta;
+         }
+     }
+
+     // dump values and find the minimum distance / minimum index
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         float min_distances_scalar[8];
+         uint32_t min_indices_scalar[8];
+
+         min_distances_i[nx_k].storeu(min_distances_scalar);
+         min_indices_i[nx_k].storeu(min_indices_scalar);
+
+         float current_min_distance = res.dis_tab[i + nx_k];
+         uint32_t current_min_index = res.ids_tab[i + nx_k];
+
+         // This unusual comparison is needed to maintain the behavior
+         // of the original implementation: if two indices are
+         // represented with equal distance values, then
+         // the index with the min value is returned.
+         for (size_t jv = 0; jv < 8; jv++) {
+             // add missing x_norms[i]
+             float distance_candidate =
+                     min_distances_scalar[jv] + x_norm_i[nx_k];
+
+             // negative values can occur for identical vectors
+             // due to roundoff errors.
+             if (distance_candidate < 0) {
+                 distance_candidate = 0;
+             }
+
+             const int64_t index_candidate = min_indices_scalar[jv];
+
+             if (current_min_distance > distance_candidate) {
+                 current_min_distance = distance_candidate;
+                 current_min_index = index_candidate;
+             } else if (
+                     current_min_distance == distance_candidate &&
+                     current_min_index > index_candidate) {
+                 current_min_index = index_candidate;
+             }
+         }
+
+         // process leftovers
+         for (size_t j0 = j; j0 < ny; j0++) {
+             const float dp =
+                     dot_product<DIM>(x + (i + nx_k) * DIM, y + j0 * DIM);
+             float dis = x_norm_i[nx_k] + y_norms[j0] - 2 * dp;
+             // negative values can occur for identical vectors
+             // due to roundoff errors.
+             if (dis < 0) {
+                 dis = 0;
+             }
+
+             if (current_min_distance > dis) {
+                 current_min_distance = dis;
+                 current_min_index = j0;
+             }
+         }
+
+         // done
+         res.add_result(i + nx_k, current_min_distance, current_min_index);
+     }
+ }
+
+ template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
+ void exhaustive_L2sqr_fused_cmax(
+         const float* const __restrict x,
+         const float* const __restrict y,
+         size_t nx,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* __restrict y_norms) {
+     // BLAS does not like empty matrices
+     if (nx == 0 || ny == 0) {
+         return;
+     }
+
+     // compute norms for y
+     std::unique_ptr<float[]> del2;
+     if (!y_norms) {
+         float* y_norms2 = new float[ny];
+         del2.reset(y_norms2);
+
+         for (size_t i = 0; i < ny; i++) {
+             y_norms2[i] = l2_sqr<DIM>(y + i * DIM);
+         }
+
+         y_norms = y_norms2;
+     }
+
+     // initialize res
+     res.begin_multiple(0, nx);
+
+     // transpose y
+     std::vector<float> y_transposed(DIM * ny);
+     for (size_t j = 0; j < DIM; j++) {
+         for (size_t i = 0; i < ny; i++) {
+             y_transposed[j * ny + i] = y[j + i * DIM];
+         }
+     }
+
+     const size_t nx_p = (nx / NX_POINTS_PER_LOOP) * NX_POINTS_PER_LOOP;
+     // the main loop.
+ #pragma omp parallel for schedule(dynamic)
+     for (size_t i = 0; i < nx_p; i += NX_POINTS_PER_LOOP) {
+         kernel<DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP>(
+                 x, y, y_transposed.data(), ny, res, y_norms, i);
+     }
+
+     for (size_t i = nx_p; i < nx; i++) {
+         kernel<DIM, 1, NY_POINTS_PER_LOOP>(
+                 x, y, y_transposed.data(), ny, res, y_norms, i);
+     }
+
+     // Does nothing for SingleBestResultHandler, but
+     // keeping the call for the consistency.
+     res.end_multiple();
+     InterruptCallback::check();
+ }
+
+ } // namespace
+
+ bool exhaustive_L2sqr_fused_cmax_simdlib(
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t nx,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* y_norms) {
+     // Process only cases with certain dimensionalities.
+     // An acceptable dimensionality value is limited by the number of
+     // available registers.
+
+ #define DISPATCH(DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP)     \
+     case DIM: {                                                   \
+         exhaustive_L2sqr_fused_cmax<                              \
+                 DIM,                                              \
+                 NX_POINTS_PER_LOOP,                               \
+                 NY_POINTS_PER_LOOP>(x, y, nx, ny, res, y_norms);  \
+         return true;                                              \
+     }
+
+     // faiss/benchs/bench_quantizer.py was used for benchmarking
+     // and tuning 2nd and 3rd parameters values.
+     // Basically, the larger the values for 2nd and 3rd parameters are,
+     // the faster the execution is, but the more SIMD registers are needed.
+     // This can be compensated with L1 cache, this is why this
+     // code might operate with more registers than available
+     // because of concurrent ports operations for ALU and LOAD/STORE.
+
+ #if defined(__AVX2__)
+     // It was possible to tweak these parameters on x64 machine.
+     switch (d) {
+         DISPATCH(1, 6, 1)
+         DISPATCH(2, 6, 1)
+         DISPATCH(3, 6, 1)
+         DISPATCH(4, 8, 1)
+         DISPATCH(5, 8, 1)
+         DISPATCH(6, 8, 1)
+         DISPATCH(7, 8, 1)
+         DISPATCH(8, 8, 1)
+         DISPATCH(9, 8, 1)
+         DISPATCH(10, 8, 1)
+         DISPATCH(11, 8, 1)
+         DISPATCH(12, 8, 1)
+         DISPATCH(13, 6, 1)
+         DISPATCH(14, 6, 1)
+         DISPATCH(15, 6, 1)
+         DISPATCH(16, 6, 1)
+     }
+ #else
+     // Please feel free to alter 2nd and 3rd parameters if you have access
+     // to ARM-based machine so that you are able to benchmark this code.
+     // Or to enable other dimensions.
+     switch (d) {
+         DISPATCH(1, 4, 2)
+         DISPATCH(2, 2, 2)
+         DISPATCH(3, 2, 2)
+         DISPATCH(4, 2, 1)
+         DISPATCH(5, 1, 1)
+         DISPATCH(6, 1, 1)
+         DISPATCH(7, 1, 1)
+         DISPATCH(8, 1, 1)
+     }
+ #endif
+
+     return false;
+ #undef DISPATCH
+ }
+
+ } // namespace faiss
+
+ #endif
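
For orientation (not part of the gem diff): the kernel above relies on the decomposition ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2. It preloads -2*x into SIMD registers, accumulates the dot products with FMA, adds the precomputed y norms inside the loop, and only adds the x norm back when the eight SIMD lanes are reduced to a scalar minimum. A minimal scalar sketch of the same computation follows; the helper name reference_nearest_l2sqr is hypothetical and purely illustrative.

// Scalar reference for what the fused kernel computes: for one query x,
// the single nearest vector among y under squared L2 distance, using the
// same norm/dot-product decomposition and the same clamp of tiny negative
// results caused by roundoff.
#include <cstddef>
#include <cstdint>
#include <limits>
#include <utility>

std::pair<float, int64_t> reference_nearest_l2sqr(
        const float* x, // one query vector, d floats
        const float* y, // ny database vectors, row-major, ny * d floats
        size_t d,
        size_t ny) {
    float x_norm = 0.0f;
    for (size_t k = 0; k < d; k++) {
        x_norm += x[k] * x[k];
    }

    float best_dis = std::numeric_limits<float>::max();
    int64_t best_id = -1;
    for (size_t j = 0; j < ny; j++) {
        float y_norm = 0.0f;
        float dp = 0.0f;
        for (size_t k = 0; k < d; k++) {
            y_norm += y[j * d + k] * y[j * d + k];
            dp += x[k] * y[j * d + k];
        }
        // same decomposition as the kernel: ||x||^2 + ||y||^2 - 2<x, y>
        float dis = x_norm + y_norm - 2.0f * dp;
        if (dis < 0.0f) {
            dis = 0.0f; // roundoff can push identical vectors below zero
        }
        if (dis < best_dis) {
            best_dis = dis;
            best_id = static_cast<int64_t>(j);
        }
    }
    return {best_dis, best_id};
}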
data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h
@@ -0,0 +1,32 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ #pragma once
+
+ #include <faiss/impl/ResultHandler.h>
+ #include <faiss/impl/platform_macros.h>
+
+ #include <faiss/utils/Heap.h>
+
+ #if defined(__AVX2__) || defined(__aarch64__)
+
+ namespace faiss {
+
+ // Returns true if the fused kernel is available and the data was processed.
+ // Returns false if the fused kernel is not available.
+ bool exhaustive_L2sqr_fused_cmax_simdlib(
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t nx,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* y_norms);
+
+ } // namespace faiss
+
+ #endif
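
For context (not part of the gem diff): a hedged sketch of how the declaration above might be driven. It assumes an AVX2 or aarch64 build (the declaration is guarded by the same #if) and that the SingleBestResultHandler constructor takes (nq, dis_tab, ids_tab) as in the vendored faiss/impl/ResultHandler.h; the helper name run_single_best is illustrative. On dimensionalities outside the dispatched range the function returns false and the caller is expected to fall back to a non-fused path.

// Hypothetical caller sketch, assumptions noted above.
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/distances_fused/simdlib_based.h>

#include <cstddef>
#include <cstdint>
#include <vector>

void run_single_best(
        const float* x, // nx query vectors, row-major, nx * d floats
        const float* y, // ny database vectors, row-major, ny * d floats
        size_t d,
        size_t nx,
        size_t ny,
        std::vector<float>& distances,
        std::vector<int64_t>& labels) {
    distances.assign(nx, 0.0f);
    labels.assign(nx, -1);

    // One best (distance, id) pair per query; the result handler writes
    // directly into the two output arrays.
    faiss::SingleBestResultHandler<faiss::CMax<float, int64_t>> res(
            nx, distances.data(), labels.data());

    // The fused kernel only handles small d (1..16 on AVX2, 1..8 otherwise)
    // and returns false when it did not process the data.
    bool done = faiss::exhaustive_L2sqr_fused_cmax_simdlib(
            x, y, d, nx, ny, res, /*y_norms=*/nullptr);
    if (!done) {
        // fall back to a generic (non-fused) nearest-neighbor scan here
    }
}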