faiss 0.2.6 → 0.2.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -0,0 +1,346 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/utils/distances_fused/avx512.h>
11
+
12
+ #ifdef __AVX512__
13
+
14
+ #include <immintrin.h>
15
+
16
+ namespace faiss {
17
+
18
+ namespace {
19
+
20
+ // It makes sense to like to overload certain cases because the further
21
+ // kernels are in need of AVX512 registers. So, let's tell compiler
22
+ // not to waste registers on a bit faster code, if needed.
23
+ template <size_t DIM>
24
+ float l2_sqr(const float* const x) {
25
+ // compiler should be smart enough to handle that
26
+ float output = x[0] * x[0];
27
+ for (size_t i = 1; i < DIM; i++) {
28
+ output += x[i] * x[i];
29
+ }
30
+
31
+ return output;
32
+ }
33
+
34
+ template <>
35
+ float l2_sqr<4>(const float* const x) {
36
+ __m128 v = _mm_loadu_ps(x);
37
+ __m128 v2 = _mm_mul_ps(v, v);
38
+ v2 = _mm_hadd_ps(v2, v2);
39
+ v2 = _mm_hadd_ps(v2, v2);
40
+
41
+ return _mm_cvtss_f32(v2);
42
+ }
43
+
44
// Inner product of two DIM-dimensional vectors.
// DIM is known at compile time, so the loop is expected to be
// fully unrolled by the compiler.
template <size_t DIM>
float dot_product(
        const float* const __restrict x,
        const float* const __restrict y) {
    float acc = x[0] * y[0];
    for (size_t d = 1; d < DIM; d++) {
        acc += x[d] * y[d];
    }

    return acc;
}
56
+
57
+ // The kernel for low dimensionality vectors.
58
+ // Finds the closest one from y for every given NX_POINTS_PER_LOOP points from x
59
+ //
60
+ // DIM is the dimensionality of the data
61
+ // NX_POINTS_PER_LOOP is the number of x points that get processed
62
+ // simultaneously.
63
+ // NY_POINTS_PER_LOOP is the number of y points that get processed
64
+ // simultaneously.
65
+ template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
66
+ void kernel(
67
+ const float* const __restrict x,
68
+ const float* const __restrict y,
69
+ const float* const __restrict y_transposed,
70
+ size_t ny,
71
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
72
+ const float* __restrict y_norms,
73
+ size_t i) {
74
+ const size_t ny_p =
75
+ (ny / (16 * NY_POINTS_PER_LOOP)) * (16 * NY_POINTS_PER_LOOP);
76
+
77
+ // compute
78
+ const float* const __restrict xd_0 = x + i * DIM;
79
+
80
+ // prefetch the next point
81
+ _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA);
82
+
83
+ // load a single point from x
84
+ // load -2 * value
85
+ __m512 x_i[NX_POINTS_PER_LOOP][DIM];
86
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
87
+ for (size_t dd = 0; dd < DIM; dd++) {
88
+ x_i[nx_k][dd] = _mm512_set1_ps(-2 * *(xd_0 + nx_k * DIM + dd));
89
+ }
90
+ }
91
+
92
+ // compute x_norm
93
+ float x_norm_i[NX_POINTS_PER_LOOP];
94
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
95
+ x_norm_i[nx_k] = l2_sqr<DIM>(xd_0 + nx_k * DIM);
96
+ }
97
+
98
+ // distances and indices
99
+ __m512 min_distances_i[NX_POINTS_PER_LOOP];
100
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
101
+ min_distances_i[nx_k] =
102
+ _mm512_set1_ps(res.dis_tab[i + nx_k] - x_norm_i[nx_k]);
103
+ }
104
+
105
+ __m512i min_indices_i[NX_POINTS_PER_LOOP];
106
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
107
+ min_indices_i[nx_k] = _mm512_set1_epi32(0);
108
+ }
109
+
110
+ //
111
+ __m512i current_indices = _mm512_setr_epi32(
112
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
113
+ const __m512i indices_delta = _mm512_set1_epi32(16);
114
+
115
+ // main loop
116
+ size_t j = 0;
117
+ for (; j < ny_p; j += NY_POINTS_PER_LOOP * 16) {
118
+ // compute dot products for NX_POINTS from x and NY_POINTS from y
119
+ // technically, we're multiplying -2x and y
120
+ __m512 dp_i[NX_POINTS_PER_LOOP][NY_POINTS_PER_LOOP];
121
+
122
+ // DIM 0 that uses MUL
123
+ for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
124
+ __m512 y_i = _mm512_loadu_ps(y_transposed + j + ny_k * 16 + ny * 0);
125
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
126
+ dp_i[nx_k][ny_k] = _mm512_mul_ps(x_i[nx_k][0], y_i);
127
+ }
128
+ }
129
+
130
+ // other DIMs that use FMA
131
+ for (size_t dd = 1; dd < DIM; dd++) {
132
+ for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
133
+ __m512 y_i =
134
+ _mm512_loadu_ps(y_transposed + j + ny_k * 16 + ny * dd);
135
+
136
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
137
+ dp_i[nx_k][ny_k] = _mm512_fmadd_ps(
138
+ x_i[nx_k][dd], y_i, dp_i[nx_k][ny_k]);
139
+ }
140
+ }
141
+ }
142
+
143
+ // compute y^2 - 2 * (x,y)
144
+ for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
145
+ __m512 y_l2_sqr = _mm512_loadu_ps(y_norms + j + ny_k * 16);
146
+
147
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
148
+ dp_i[nx_k][ny_k] = _mm512_add_ps(dp_i[nx_k][ny_k], y_l2_sqr);
149
+ }
150
+ }
151
+
152
+ // do the comparisons and alter the min indices
153
+ for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
154
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
155
+ const __mmask16 comparison = _mm512_cmp_ps_mask(
156
+ dp_i[nx_k][ny_k], min_distances_i[nx_k], _CMP_LT_OS);
157
+ min_distances_i[nx_k] = _mm512_mask_blend_ps(
158
+ comparison, min_distances_i[nx_k], dp_i[nx_k][ny_k]);
159
+ min_indices_i[nx_k] = _mm512_castps_si512(_mm512_mask_blend_ps(
160
+ comparison,
161
+ _mm512_castsi512_ps(min_indices_i[nx_k]),
162
+ _mm512_castsi512_ps(current_indices)));
163
+ }
164
+
165
+ current_indices = _mm512_add_epi32(current_indices, indices_delta);
166
+ }
167
+ }
168
+
169
+ // dump values and find the minimum distance / minimum index
170
+ for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
171
+ float min_distances_scalar[16];
172
+ uint32_t min_indices_scalar[16];
173
+ _mm512_storeu_ps(min_distances_scalar, min_distances_i[nx_k]);
174
+ _mm512_storeu_si512(
175
+ (__m512i*)(min_indices_scalar), min_indices_i[nx_k]);
176
+
177
+ float current_min_distance = res.dis_tab[i + nx_k];
178
+ uint32_t current_min_index = res.ids_tab[i + nx_k];
179
+
180
+ // This unusual comparison is needed to maintain the behavior
181
+ // of the original implementation: if two indices are
182
+ // represented with equal distance values, then
183
+ // the index with the min value is returned.
184
+ for (size_t jv = 0; jv < 16; jv++) {
185
+ // add missing x_norms[i]
186
+ float distance_candidate =
187
+ min_distances_scalar[jv] + x_norm_i[nx_k];
188
+
189
+ // negative values can occur for identical vectors
190
+ // due to roundoff errors.
191
+ if (distance_candidate < 0)
192
+ distance_candidate = 0;
193
+
194
+ const int64_t index_candidate = min_indices_scalar[jv];
195
+
196
+ if (current_min_distance > distance_candidate) {
197
+ current_min_distance = distance_candidate;
198
+ current_min_index = index_candidate;
199
+ } else if (
200
+ current_min_distance == distance_candidate &&
201
+ current_min_index > index_candidate) {
202
+ current_min_index = index_candidate;
203
+ }
204
+ }
205
+
206
+ // process leftovers
207
+ for (size_t j0 = j; j0 < ny; j0++) {
208
+ const float dp =
209
+ dot_product<DIM>(x + (i + nx_k) * DIM, y + j0 * DIM);
210
+ float dis = x_norm_i[nx_k] + y_norms[j0] - 2 * dp;
211
+ // negative values can occur for identical vectors
212
+ // due to roundoff errors.
213
+ if (dis < 0) {
214
+ dis = 0;
215
+ }
216
+
217
+ if (current_min_distance > dis) {
218
+ current_min_distance = dis;
219
+ current_min_index = j0;
220
+ }
221
+ }
222
+
223
+ // done
224
+ res.add_result(i + nx_k, current_min_distance, current_min_index);
225
+ }
226
+ }
227
+
228
+ template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
229
+ void exhaustive_L2sqr_fused_cmax(
230
+ const float* const __restrict x,
231
+ const float* const __restrict y,
232
+ size_t nx,
233
+ size_t ny,
234
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
235
+ const float* __restrict y_norms) {
236
+ // BLAS does not like empty matrices
237
+ if (nx == 0 || ny == 0) {
238
+ return;
239
+ }
240
+
241
+ // compute norms for y
242
+ std::unique_ptr<float[]> del2;
243
+ if (!y_norms) {
244
+ float* y_norms2 = new float[ny];
245
+ del2.reset(y_norms2);
246
+
247
+ for (size_t i = 0; i < ny; i++) {
248
+ y_norms2[i] = l2_sqr<DIM>(y + i * DIM);
249
+ }
250
+
251
+ y_norms = y_norms2;
252
+ }
253
+
254
+ // initialize res
255
+ res.begin_multiple(0, nx);
256
+
257
+ // transpose y
258
+ std::vector<float> y_transposed(DIM * ny);
259
+ for (size_t j = 0; j < DIM; j++) {
260
+ for (size_t i = 0; i < ny; i++) {
261
+ y_transposed[j * ny + i] = y[j + i * DIM];
262
+ }
263
+ }
264
+
265
+ const size_t nx_p = (nx / NX_POINTS_PER_LOOP) * NX_POINTS_PER_LOOP;
266
+ // the main loop.
267
+ #pragma omp parallel for schedule(dynamic)
268
+ for (size_t i = 0; i < nx_p; i += NX_POINTS_PER_LOOP) {
269
+ kernel<DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP>(
270
+ x, y, y_transposed.data(), ny, res, y_norms, i);
271
+ }
272
+
273
+ for (size_t i = nx_p; i < nx; i++) {
274
+ kernel<DIM, 1, NY_POINTS_PER_LOOP>(
275
+ x, y, y_transposed.data(), ny, res, y_norms, i);
276
+ }
277
+
278
+ // Does nothing for SingleBestResultHandler, but
279
+ // keeping the call for the consistency.
280
+ res.end_multiple();
281
+ InterruptCallback::check();
282
+ }
283
+
284
+ } // namespace
285
+
286
+ bool exhaustive_L2sqr_fused_cmax_AVX512(
287
+ const float* x,
288
+ const float* y,
289
+ size_t d,
290
+ size_t nx,
291
+ size_t ny,
292
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
293
+ const float* y_norms) {
294
+ // process only cases with certain dimensionalities
295
+
296
+ #define DISPATCH(DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP) \
297
+ case DIM: { \
298
+ exhaustive_L2sqr_fused_cmax< \
299
+ DIM, \
300
+ NX_POINTS_PER_LOOP, \
301
+ NY_POINTS_PER_LOOP>(x, y, nx, ny, res, y_norms); \
302
+ return true; \
303
+ }
304
+
305
+ switch (d) {
306
+ DISPATCH(1, 8, 1)
307
+ DISPATCH(2, 8, 1)
308
+ DISPATCH(3, 8, 1)
309
+ DISPATCH(4, 8, 1)
310
+ DISPATCH(5, 8, 1)
311
+ DISPATCH(6, 8, 1)
312
+ DISPATCH(7, 8, 1)
313
+ DISPATCH(8, 8, 1)
314
+ DISPATCH(9, 8, 1)
315
+ DISPATCH(10, 8, 1)
316
+ DISPATCH(11, 8, 1)
317
+ DISPATCH(12, 8, 1)
318
+ DISPATCH(13, 8, 1)
319
+ DISPATCH(14, 8, 1)
320
+ DISPATCH(15, 8, 1)
321
+ DISPATCH(16, 8, 1)
322
+ DISPATCH(17, 8, 1)
323
+ DISPATCH(18, 8, 1)
324
+ DISPATCH(19, 8, 1)
325
+ DISPATCH(20, 8, 1)
326
+ DISPATCH(21, 8, 1)
327
+ DISPATCH(22, 8, 1)
328
+ DISPATCH(23, 8, 1)
329
+ DISPATCH(24, 8, 1)
330
+ DISPATCH(25, 8, 1)
331
+ DISPATCH(26, 8, 1)
332
+ DISPATCH(27, 8, 1)
333
+ DISPATCH(28, 8, 1)
334
+ DISPATCH(29, 8, 1)
335
+ DISPATCH(30, 8, 1)
336
+ DISPATCH(31, 8, 1)
337
+ DISPATCH(32, 8, 1)
338
+ }
339
+
340
+ return false;
341
+ #undef DISPATCH
342
+ }
343
+
344
+ } // namespace faiss
345
+
346
+ #endif
@@ -0,0 +1,36 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // AVX512 might be not used, but this version provides ~2x speedup
9
+ // over AVX2 kernel, say, for training PQx10 or PQx12, and speeds up
10
+ // additional cases with larger dimensionalities.
11
+
12
+ #pragma once
13
+
14
+ #include <faiss/impl/ResultHandler.h>
15
+ #include <faiss/impl/platform_macros.h>
16
+
17
+ #include <faiss/utils/Heap.h>
18
+
19
+ #ifdef __AVX512__
20
+
21
+ namespace faiss {
22
+
23
+ // Returns true if the fused kernel is available and the data was processed.
24
+ // Returns false if the fused kernel is not available.
25
+ bool exhaustive_L2sqr_fused_cmax_AVX512(
26
+ const float* x,
27
+ const float* y,
28
+ size_t d,
29
+ size_t nx,
30
+ size_t ny,
31
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
32
+ const float* y_norms);
33
+
34
+ } // namespace faiss
35
+
36
+ #endif
@@ -0,0 +1,42 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/utils/distances_fused/distances_fused.h>
9
+
10
+ #include <faiss/impl/platform_macros.h>
11
+
12
+ #include <faiss/utils/distances_fused/avx512.h>
13
+ #include <faiss/utils/distances_fused/simdlib_based.h>
14
+
15
+ namespace faiss {
16
+
17
+ bool exhaustive_L2sqr_fused_cmax(
18
+ const float* x,
19
+ const float* y,
20
+ size_t d,
21
+ size_t nx,
22
+ size_t ny,
23
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
24
+ const float* y_norms) {
25
+ if (nx == 0 || ny == 0) {
26
+ // nothing to do
27
+ return true;
28
+ }
29
+
30
+ #ifdef __AVX512__
31
+ // avx512 kernel
32
+ return exhaustive_L2sqr_fused_cmax_AVX512(x, y, d, nx, ny, res, y_norms);
33
+ #elif defined(__AVX2__) || defined(__aarch64__)
34
+ // avx2 or arm neon kernel
35
+ return exhaustive_L2sqr_fused_cmax_simdlib(x, y, d, nx, ny, res, y_norms);
36
+ #else
37
+ // not supported, please use a general-purpose kernel
38
+ return false;
39
+ #endif
40
+ }
41
+
42
+ } // namespace faiss
@@ -0,0 +1,40 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // This file contains a fused kernel that combines distance computation
9
+ // and the search for the CLOSEST point. Currently, this is done for small
10
+ // dimensionality vectors when it is beneficial to avoid storing temporary
11
+ // dot product information in RAM. This is particularly effective
12
+ // when training PQx10 or PQx12 with the default parameters.
13
+ //
14
+ // InterruptCallback::check() is not used, because it is assumed that the
15
+ // kernel takes a little time because of a tiny dimensionality.
16
+ //
17
+ // Later on, similar optimization can be implemented for large size vectors,
18
+ // but a different kernel is needed.
19
+ //
20
+
21
+ #pragma once
22
+
23
+ #include <faiss/impl/ResultHandler.h>
24
+
25
+ #include <faiss/utils/Heap.h>
26
+
27
+ namespace faiss {
28
+
29
+ // Returns true if the fused kernel is available and the data was processed.
30
+ // Returns false if the fused kernel is not available.
31
+ bool exhaustive_L2sqr_fused_cmax(
32
+ const float* x,
33
+ const float* y,
34
+ size_t d,
35
+ size_t nx,
36
+ size_t ny,
37
+ SingleBestResultHandler<CMax<float, int64_t>>& res,
38
+ const float* y_norms);
39
+
40
+ } // namespace faiss