faiss 0.2.5 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/extconf.rb +1 -1
  5. data/ext/faiss/index.cpp +13 -0
  6. data/lib/faiss/version.rb +1 -1
  7. data/lib/faiss.rb +2 -2
  8. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  9. data/vendor/faiss/faiss/AutoTune.h +0 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  11. data/vendor/faiss/faiss/Clustering.h +0 -2
  12. data/vendor/faiss/faiss/IVFlib.h +0 -2
  13. data/vendor/faiss/faiss/Index.h +1 -2
  14. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  16. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  17. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  18. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  19. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  20. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  21. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  22. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  23. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  24. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  25. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  26. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  27. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  29. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  30. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  31. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  32. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  33. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  34. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  35. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  36. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  37. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  38. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  39. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  41. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  43. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  44. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  45. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  46. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  47. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  48. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  49. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  50. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  51. data/vendor/faiss/faiss/IndexShards.h +2 -3
  52. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  53. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  54. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  55. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  56. data/vendor/faiss/faiss/MetricType.h +14 -0
  57. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  58. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  59. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  60. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  61. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  62. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  69. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  70. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  71. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  72. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  73. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  74. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  75. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  76. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  77. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  78. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  81. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  82. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  83. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  84. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  85. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  86. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  87. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  91. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  92. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  93. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  94. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  95. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  96. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  97. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  98. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  99. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  100. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  101. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  102. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  104. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  105. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  106. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  111. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  112. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  113. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  114. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  115. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  116. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  117. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  118. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  119. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  121. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  125. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  128. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  129. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  130. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  131. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  132. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  133. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  134. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  139. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  140. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  141. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  142. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  143. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  144. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  145. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  146. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  147. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  148. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  149. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  150. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  151. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  152. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  153. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  155. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  156. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  157. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  158. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  159. data/vendor/faiss/faiss/utils/distances.h +11 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  164. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  165. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  166. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  167. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  168. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  169. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  170. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  171. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  172. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  173. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  174. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  175. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  176. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  179. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  180. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  181. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  182. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  183. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  184. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  185. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  186. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  187. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  188. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  189. data/vendor/faiss/faiss/utils/utils.h +2 -9
  190. metadata +30 -4
  191. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -0,0 +1,352 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/utils/distances_fused/simdlib_based.h>
11
+
12
+ #if defined(__AVX2__) || defined(__aarch64__)
13
+
14
+ #include <faiss/utils/simdlib.h>
15
+
16
+ #if defined(__AVX2__)
17
+ #include <immintrin.h>
18
+ #endif
19
+
20
+ namespace faiss {
21
+
22
+ namespace {
23
+
24
// Squared L2 norm of a DIM-dimensional vector.
//
// This is deliberately written as a plain scalar loop instead of hand-tuned
// SIMD: the fused kernels that follow need as many SIMD registers as they
// can get, so the compiler is left free to pick whatever code shape does
// not compete with them for registers.
//
// @tparam DIM  number of components in x; must be at least 1
// @param  x    pointer to DIM contiguous floats
// @return      x[0]^2 + x[1]^2 + ... + x[DIM-1]^2, accumulated left to right
template <size_t DIM>
float l2_sqr(const float* const x) {
    // x[0] is read unconditionally below, so an empty vector is not allowed.
    static_assert(DIM >= 1, "l2_sqr requires at least one dimension");

    // The trip count is a compile-time constant, so the compiler can fully
    // unroll this loop.
    float sum_of_squares = x[0] * x[0];
    for (size_t dd = 1; dd < DIM; dd++) {
        sum_of_squares += x[dd] * x[dd];
    }

    return sum_of_squares;
}
37
+
38
// Inner product of two DIM-dimensional vectors.
//
// Like l2_sqr, this is a plain scalar loop with a compile-time trip count
// so that the compiler decides how to unroll or vectorize it.
//
// @tparam DIM  number of components in x and y; must be at least 1
// @param  x    pointer to DIM contiguous floats (must not alias y)
// @param  y    pointer to DIM contiguous floats (must not alias x)
// @return      x[0]*y[0] + ... + x[DIM-1]*y[DIM-1], accumulated left to right
template <size_t DIM>
float dot_product(
        const float* const __restrict x,
        const float* const __restrict y) {
    // x[0] and y[0] are read unconditionally below, so empty vectors are
    // not allowed.
    static_assert(DIM >= 1, "dot_product requires at least one dimension");

    float sum_of_products = x[0] * y[0];
    for (size_t dd = 1; dd < DIM; dd++) {
        sum_of_products += x[dd] * y[dd];
    }

    return sum_of_products;
}
50
+
51
+ // The kernel for low dimensionality vectors.
52
+ // Finds the closest one from y for every given NX_POINTS_PER_LOOP points from x
53
+ //
54
+ // DIM is the dimensionality of the data
55
+ // NX_POINTS_PER_LOOP is the number of x points that get processed
56
+ // simultaneously.
57
+ // NY_POINTS_PER_LOOP is the number of y points that get processed
58
+ // simultaneously.
59
+ template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
60
+ void kernel(
61
+ const float* const __restrict x,
62
+ const float* const __restrict y,
63
+ const float* const __restrict y_transposed,
64
+ const size_t ny,
65
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
66
+ const float* __restrict y_norms,
67
+ const size_t i) {
68
+ const size_t ny_p =
69
+ (ny / (8 * NY_POINTS_PER_LOOP)) * (8 * NY_POINTS_PER_LOOP);
70
+
71
+ // compute
72
+ const float* const __restrict xd_0 = x + i * DIM;
73
+
74
+ // prefetch the next point
75
+ #if defined(__AVX2__)
76
+ _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA);
77
+ #endif
78
+
79
+ // load a single point from x
80
+ // load -2 * value
81
+ simd8float32 x_i[NX_POINTS_PER_LOOP][DIM];
82
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
83
+ for (size_t dd = 0; dd < DIM; dd++) {
84
+ x_i[nx_k][dd] = simd8float32(-2 * *(xd_0 + nx_k * DIM + dd));
85
+ }
86
+ }
87
+
88
+ // compute x_norm
89
+ float x_norm_i[NX_POINTS_PER_LOOP];
90
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
91
+ x_norm_i[nx_k] = l2_sqr<DIM>(xd_0 + nx_k * DIM);
92
+ }
93
+
94
+ // distances and indices
95
+ simd8float32 min_distances_i[NX_POINTS_PER_LOOP];
96
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
97
+ min_distances_i[nx_k] =
98
+ simd8float32(res.dis_tab[i + nx_k] - x_norm_i[nx_k]);
99
+ }
100
+
101
+ simd8uint32 min_indices_i[NX_POINTS_PER_LOOP];
102
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
103
+ min_indices_i[nx_k] = simd8uint32((uint32_t)0);
104
+ }
105
+
106
+ //
107
+ simd8uint32 current_indices = simd8uint32(0, 1, 2, 3, 4, 5, 6, 7);
108
+ const simd8uint32 indices_delta = simd8uint32(8);
109
+
110
+ // main loop
111
+ size_t j = 0;
112
+ for (; j < ny_p; j += NY_POINTS_PER_LOOP * 8) {
113
+ // compute dot products for NX_POINTS from x and NY_POINTS from y
114
+ // technically, we're multiplying -2x and y
115
+ simd8float32 dp_i[NX_POINTS_PER_LOOP][NY_POINTS_PER_LOOP];
116
+
117
+ // DIM 0 that uses MUL
118
+ for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
119
+ simd8float32 y_i =
120
+ simd8float32(y_transposed + j + ny_k * 8 + ny * 0);
121
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
122
+ dp_i[nx_k][ny_k] = x_i[nx_k][0] * y_i;
123
+ }
124
+ }
125
+
126
+ // other DIMs that use FMA
127
+ for (size_t dd = 1; dd < DIM; dd++) {
128
+ for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
129
+ simd8float32 y_i =
130
+ simd8float32(y_transposed + j + ny_k * 8 + ny * dd);
131
+
132
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
133
+ dp_i[nx_k][ny_k] =
134
+ fmadd(x_i[nx_k][dd], y_i, dp_i[nx_k][ny_k]);
135
+ }
136
+ }
137
+ }
138
+
139
+ // compute y^2 + (-2x,y)
140
+ for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
141
+ simd8float32 y_l2_sqr = simd8float32(y_norms + j + ny_k * 8);
142
+
143
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
144
+ dp_i[nx_k][ny_k] = dp_i[nx_k][ny_k] + y_l2_sqr;
145
+ }
146
+ }
147
+
148
+ // do the comparisons and alter the min indices
149
+ for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
150
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
151
+ // cmpps
152
+ cmplt_and_blend_inplace(
153
+ dp_i[nx_k][ny_k],
154
+ current_indices,
155
+ min_distances_i[nx_k],
156
+ min_indices_i[nx_k]);
157
+ }
158
+
159
+ current_indices = current_indices + indices_delta;
160
+ }
161
+ }
162
+
163
+ // dump values and find the minimum distance / minimum index
164
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
165
+ float min_distances_scalar[8];
166
+ uint32_t min_indices_scalar[8];
167
+
168
+ min_distances_i[nx_k].storeu(min_distances_scalar);
169
+ min_indices_i[nx_k].storeu(min_indices_scalar);
170
+
171
+ float current_min_distance = res.dis_tab[i + nx_k];
172
+ uint32_t current_min_index = res.ids_tab[i + nx_k];
173
+
174
+ // This unusual comparison is needed to maintain the behavior
175
+ // of the original implementation: if two indices are
176
+ // represented with equal distance values, then
177
+ // the index with the min value is returned.
178
+ for (size_t jv = 0; jv < 8; jv++) {
179
+ // add missing x_norms[i]
180
+ float distance_candidate =
181
+ min_distances_scalar[jv] + x_norm_i[nx_k];
182
+
183
+ // negative values can occur for identical vectors
184
+ // due to roundoff errors.
185
+ if (distance_candidate < 0) {
186
+ distance_candidate = 0;
187
+ }
188
+
189
+ const int64_t index_candidate = min_indices_scalar[jv];
190
+
191
+ if (current_min_distance > distance_candidate) {
192
+ current_min_distance = distance_candidate;
193
+ current_min_index = index_candidate;
194
+ } else if (
195
+ current_min_distance == distance_candidate &&
196
+ current_min_index > index_candidate) {
197
+ current_min_index = index_candidate;
198
+ }
199
+ }
200
+
201
+ // process leftovers
202
+ for (size_t j0 = j; j0 < ny; j0++) {
203
+ const float dp =
204
+ dot_product<DIM>(x + (i + nx_k) * DIM, y + j0 * DIM);
205
+ float dis = x_norm_i[nx_k] + y_norms[j0] - 2 * dp;
206
+ // negative values can occur for identical vectors
207
+ // due to roundoff errors.
208
+ if (dis < 0) {
209
+ dis = 0;
210
+ }
211
+
212
+ if (current_min_distance > dis) {
213
+ current_min_distance = dis;
214
+ current_min_index = j0;
215
+ }
216
+ }
217
+
218
+ // done
219
+ res.add_result(i + nx_k, current_min_distance, current_min_index);
220
+ }
221
+ }
222
+
223
+ template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
224
+ void exhaustive_L2sqr_fused_cmax(
225
+ const float* const __restrict x,
226
+ const float* const __restrict y,
227
+ size_t nx,
228
+ size_t ny,
229
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
230
+ const float* __restrict y_norms) {
231
+ // BLAS does not like empty matrices
232
+ if (nx == 0 || ny == 0) {
233
+ return;
234
+ }
235
+
236
+ // compute norms for y
237
+ std::unique_ptr<float[]> del2;
238
+ if (!y_norms) {
239
+ float* y_norms2 = new float[ny];
240
+ del2.reset(y_norms2);
241
+
242
+ for (size_t i = 0; i < ny; i++) {
243
+ y_norms2[i] = l2_sqr<DIM>(y + i * DIM);
244
+ }
245
+
246
+ y_norms = y_norms2;
247
+ }
248
+
249
+ // initialize res
250
+ res.begin_multiple(0, nx);
251
+
252
+ // transpose y
253
+ std::vector<float> y_transposed(DIM * ny);
254
+ for (size_t j = 0; j < DIM; j++) {
255
+ for (size_t i = 0; i < ny; i++) {
256
+ y_transposed[j * ny + i] = y[j + i * DIM];
257
+ }
258
+ }
259
+
260
+ const size_t nx_p = (nx / NX_POINTS_PER_LOOP) * NX_POINTS_PER_LOOP;
261
+ // the main loop.
262
+ #pragma omp parallel for schedule(dynamic)
263
+ for (size_t i = 0; i < nx_p; i += NX_POINTS_PER_LOOP) {
264
+ kernel<DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP>(
265
+ x, y, y_transposed.data(), ny, res, y_norms, i);
266
+ }
267
+
268
+ for (size_t i = nx_p; i < nx; i++) {
269
+ kernel<DIM, 1, NY_POINTS_PER_LOOP>(
270
+ x, y, y_transposed.data(), ny, res, y_norms, i);
271
+ }
272
+
273
+ // Does nothing for SingleBestResultHandler, but
274
+ // keeping the call for the consistency.
275
+ res.end_multiple();
276
+ InterruptCallback::check();
277
+ }
278
+
279
+ } // namespace
280
+
281
+ bool exhaustive_L2sqr_fused_cmax_simdlib(
282
+ const float* x,
283
+ const float* y,
284
+ size_t d,
285
+ size_t nx,
286
+ size_t ny,
287
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
288
+ const float* y_norms) {
289
+ // Process only cases with certain dimensionalities.
290
+ // An acceptable dimensionality value is limited by the number of
291
+ // available registers.
292
+
293
+ #define DISPATCH(DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP) \
294
+ case DIM: { \
295
+ exhaustive_L2sqr_fused_cmax< \
296
+ DIM, \
297
+ NX_POINTS_PER_LOOP, \
298
+ NY_POINTS_PER_LOOP>(x, y, nx, ny, res, y_norms); \
299
+ return true; \
300
+ }
301
+
302
+ // faiss/benchs/bench_quantizer.py was used for benchmarking
303
+ // and tuning 2nd and 3rd parameters values.
304
+ // Basically, the larger the values for 2nd and 3rd parameters are,
305
+ // the faster the execution is, but the more SIMD registers are needed.
306
+ // This can be compensated with L1 cache, this is why this
307
+ // code might operate with more registers than available
308
+ // because of concurrent ports operations for ALU and LOAD/STORE.
309
+
310
+ #if defined(__AVX2__)
311
+ // It was possible to tweak these parameters on x64 machine.
312
+ switch (d) {
313
+ DISPATCH(1, 6, 1)
314
+ DISPATCH(2, 6, 1)
315
+ DISPATCH(3, 6, 1)
316
+ DISPATCH(4, 8, 1)
317
+ DISPATCH(5, 8, 1)
318
+ DISPATCH(6, 8, 1)
319
+ DISPATCH(7, 8, 1)
320
+ DISPATCH(8, 8, 1)
321
+ DISPATCH(9, 8, 1)
322
+ DISPATCH(10, 8, 1)
323
+ DISPATCH(11, 8, 1)
324
+ DISPATCH(12, 8, 1)
325
+ DISPATCH(13, 6, 1)
326
+ DISPATCH(14, 6, 1)
327
+ DISPATCH(15, 6, 1)
328
+ DISPATCH(16, 6, 1)
329
+ }
330
+ #else
331
+ // Please feel free to alter 2nd and 3rd parameters if you have access
332
+ // to ARM-based machine so that you are able to benchmark this code.
333
+ // Or to enable other dimensions.
334
+ switch (d) {
335
+ DISPATCH(1, 4, 2)
336
+ DISPATCH(2, 2, 2)
337
+ DISPATCH(3, 2, 2)
338
+ DISPATCH(4, 2, 1)
339
+ DISPATCH(5, 1, 1)
340
+ DISPATCH(6, 1, 1)
341
+ DISPATCH(7, 1, 1)
342
+ DISPATCH(8, 1, 1)
343
+ }
344
+ #endif
345
+
346
+ return false;
347
+ #undef DISPATCH
348
+ }
349
+
350
+ } // namespace faiss
351
+
352
+ #endif
@@ -0,0 +1,32 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/impl/ResultHandler.h>
11
+ #include <faiss/impl/platform_macros.h>
12
+
13
+ #include <faiss/utils/Heap.h>
14
+
15
+ #if defined(__AVX2__) || defined(__aarch64__)
16
+
17
+ namespace faiss {
18
+
19
+ // Fused L2 nearest-neighbor search: returns true if a fused kernel is available for dimensionality d and the data was processed into res.
20
+ // Returns false if no fused kernel is available for d; nothing is computed in that case.
21
+ bool exhaustive_L2sqr_fused_cmax_simdlib(
22
+ const float* x,
23
+ const float* y,
24
+ size_t d,
25
+ size_t nx,
26
+ size_t ny,
27
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
28
+ const float* y_norms);
29
+
30
+ } // namespace faiss
31
+
32
+ #endif