faiss 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp
@@ -0,0 +1,346 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ // -*- c++ -*-
+
+ #include <faiss/utils/distances_fused/avx512.h>
+
+ #ifdef __AVX512__
+
+ #include <immintrin.h>
+
+ namespace faiss {
+
+ namespace {
+
+ // It makes sense to like to overload certain cases because the further
+ // kernels are in need of AVX512 registers. So, let's tell compiler
+ // not to waste registers on a bit faster code, if needed.
+ template <size_t DIM>
+ float l2_sqr(const float* const x) {
+     // compiler should be smart enough to handle that
+     float output = x[0] * x[0];
+     for (size_t i = 1; i < DIM; i++) {
+         output += x[i] * x[i];
+     }
+
+     return output;
+ }
+
+ template <>
+ float l2_sqr<4>(const float* const x) {
+     __m128 v = _mm_loadu_ps(x);
+     __m128 v2 = _mm_mul_ps(v, v);
+     v2 = _mm_hadd_ps(v2, v2);
+     v2 = _mm_hadd_ps(v2, v2);
+
+     return _mm_cvtss_f32(v2);
+ }
+
+ template <size_t DIM>
+ float dot_product(
+         const float* const __restrict x,
+         const float* const __restrict y) {
+     // compiler should be smart enough to handle that
+     float output = x[0] * y[0];
+     for (size_t i = 1; i < DIM; i++) {
+         output += x[i] * y[i];
+     }
+
+     return output;
+ }
+
+ // The kernel for low dimensionality vectors.
+ // Finds the closest one from y for every given NX_POINTS_PER_LOOP points from x
+ //
+ // DIM is the dimensionality of the data
+ // NX_POINTS_PER_LOOP is the number of x points that get processed
+ // simultaneously.
+ // NY_POINTS_PER_LOOP is the number of y points that get processed
+ // simultaneously.
+ template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
+ void kernel(
+         const float* const __restrict x,
+         const float* const __restrict y,
+         const float* const __restrict y_transposed,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* __restrict y_norms,
+         size_t i) {
+     const size_t ny_p =
+             (ny / (16 * NY_POINTS_PER_LOOP)) * (16 * NY_POINTS_PER_LOOP);
+
+     // compute
+     const float* const __restrict xd_0 = x + i * DIM;
+
+     // prefetch the next point
+     _mm_prefetch(xd_0 + DIM * sizeof(float), _MM_HINT_NTA);
+
+     // load a single point from x
+     // load -2 * value
+     __m512 x_i[NX_POINTS_PER_LOOP][DIM];
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         for (size_t dd = 0; dd < DIM; dd++) {
+             x_i[nx_k][dd] = _mm512_set1_ps(-2 * *(xd_0 + nx_k * DIM + dd));
+         }
+     }
+
+     // compute x_norm
+     float x_norm_i[NX_POINTS_PER_LOOP];
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         x_norm_i[nx_k] = l2_sqr<DIM>(xd_0 + nx_k * DIM);
+     }
+
+     // distances and indices
+     __m512 min_distances_i[NX_POINTS_PER_LOOP];
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         min_distances_i[nx_k] =
+                 _mm512_set1_ps(res.dis_tab[i + nx_k] - x_norm_i[nx_k]);
+     }
+
+     __m512i min_indices_i[NX_POINTS_PER_LOOP];
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         min_indices_i[nx_k] = _mm512_set1_epi32(0);
+     }
+
+     //
+     __m512i current_indices = _mm512_setr_epi32(
+             0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+     const __m512i indices_delta = _mm512_set1_epi32(16);
+
+     // main loop
+     size_t j = 0;
+     for (; j < ny_p; j += NY_POINTS_PER_LOOP * 16) {
+         // compute dot products for NX_POINTS from x and NY_POINTS from y
+         // technically, we're multiplying -2x and y
+         __m512 dp_i[NX_POINTS_PER_LOOP][NY_POINTS_PER_LOOP];
+
+         // DIM 0 that uses MUL
+         for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+             __m512 y_i = _mm512_loadu_ps(y_transposed + j + ny_k * 16 + ny * 0);
+             for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                 dp_i[nx_k][ny_k] = _mm512_mul_ps(x_i[nx_k][0], y_i);
+             }
+         }
+
+         // other DIMs that use FMA
+         for (size_t dd = 1; dd < DIM; dd++) {
+             for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+                 __m512 y_i =
+                         _mm512_loadu_ps(y_transposed + j + ny_k * 16 + ny * dd);
+
+                 for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                     dp_i[nx_k][ny_k] = _mm512_fmadd_ps(
+                             x_i[nx_k][dd], y_i, dp_i[nx_k][ny_k]);
+                 }
+             }
+         }
+
+         // compute y^2 - 2 * (x,y)
+         for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+             __m512 y_l2_sqr = _mm512_loadu_ps(y_norms + j + ny_k * 16);
+
+             for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                 dp_i[nx_k][ny_k] = _mm512_add_ps(dp_i[nx_k][ny_k], y_l2_sqr);
+             }
+         }
+
+         // do the comparisons and alter the min indices
+         for (size_t ny_k = 0; ny_k < NY_POINTS_PER_LOOP; ny_k++) {
+             for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+                 const __mmask16 comparison = _mm512_cmp_ps_mask(
+                         dp_i[nx_k][ny_k], min_distances_i[nx_k], _CMP_LT_OS);
+                 min_distances_i[nx_k] = _mm512_mask_blend_ps(
+                         comparison, min_distances_i[nx_k], dp_i[nx_k][ny_k]);
+                 min_indices_i[nx_k] = _mm512_castps_si512(_mm512_mask_blend_ps(
+                         comparison,
+                         _mm512_castsi512_ps(min_indices_i[nx_k]),
+                         _mm512_castsi512_ps(current_indices)));
+             }
+
+             current_indices = _mm512_add_epi32(current_indices, indices_delta);
+         }
+     }
+
+     // dump values and find the minimum distance / minimum index
+     for (size_t nx_k = 0; nx_k < NX_POINTS_PER_LOOP; nx_k++) {
+         float min_distances_scalar[16];
+         uint32_t min_indices_scalar[16];
+         _mm512_storeu_ps(min_distances_scalar, min_distances_i[nx_k]);
+         _mm512_storeu_si512(
+                 (__m512i*)(min_indices_scalar), min_indices_i[nx_k]);
+
+         float current_min_distance = res.dis_tab[i + nx_k];
+         uint32_t current_min_index = res.ids_tab[i + nx_k];
+
+         // This unusual comparison is needed to maintain the behavior
+         // of the original implementation: if two indices are
+         // represented with equal distance values, then
+         // the index with the min value is returned.
+         for (size_t jv = 0; jv < 16; jv++) {
+             // add missing x_norms[i]
+             float distance_candidate =
+                     min_distances_scalar[jv] + x_norm_i[nx_k];
+
+             // negative values can occur for identical vectors
+             // due to roundoff errors.
+             if (distance_candidate < 0)
+                 distance_candidate = 0;
+
+             const int64_t index_candidate = min_indices_scalar[jv];
+
+             if (current_min_distance > distance_candidate) {
+                 current_min_distance = distance_candidate;
+                 current_min_index = index_candidate;
+             } else if (
+                     current_min_distance == distance_candidate &&
+                     current_min_index > index_candidate) {
+                 current_min_index = index_candidate;
+             }
+         }
+
+         // process leftovers
+         for (size_t j0 = j; j0 < ny; j0++) {
+             const float dp =
+                     dot_product<DIM>(x + (i + nx_k) * DIM, y + j0 * DIM);
+             float dis = x_norm_i[nx_k] + y_norms[j0] - 2 * dp;
+             // negative values can occur for identical vectors
+             // due to roundoff errors.
+             if (dis < 0) {
+                 dis = 0;
+             }
+
+             if (current_min_distance > dis) {
+                 current_min_distance = dis;
+                 current_min_index = j0;
+             }
+         }
+
+         // done
+         res.add_result(i + nx_k, current_min_distance, current_min_index);
+     }
+ }
+
+ template <size_t DIM, size_t NX_POINTS_PER_LOOP, size_t NY_POINTS_PER_LOOP>
+ void exhaustive_L2sqr_fused_cmax(
+         const float* const __restrict x,
+         const float* const __restrict y,
+         size_t nx,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* __restrict y_norms) {
+     // BLAS does not like empty matrices
+     if (nx == 0 || ny == 0) {
+         return;
+     }
+
+     // compute norms for y
+     std::unique_ptr<float[]> del2;
+     if (!y_norms) {
+         float* y_norms2 = new float[ny];
+         del2.reset(y_norms2);
+
+         for (size_t i = 0; i < ny; i++) {
+             y_norms2[i] = l2_sqr<DIM>(y + i * DIM);
+         }
+
+         y_norms = y_norms2;
+     }
+
+     // initialize res
+     res.begin_multiple(0, nx);
+
+     // transpose y
+     std::vector<float> y_transposed(DIM * ny);
+     for (size_t j = 0; j < DIM; j++) {
+         for (size_t i = 0; i < ny; i++) {
+             y_transposed[j * ny + i] = y[j + i * DIM];
+         }
+     }
+
+     const size_t nx_p = (nx / NX_POINTS_PER_LOOP) * NX_POINTS_PER_LOOP;
+     // the main loop.
+ #pragma omp parallel for schedule(dynamic)
+     for (size_t i = 0; i < nx_p; i += NX_POINTS_PER_LOOP) {
+         kernel<DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP>(
+                 x, y, y_transposed.data(), ny, res, y_norms, i);
+     }
+
+     for (size_t i = nx_p; i < nx; i++) {
+         kernel<DIM, 1, NY_POINTS_PER_LOOP>(
+                 x, y, y_transposed.data(), ny, res, y_norms, i);
+     }
+
+     // Does nothing for SingleBestResultHandler, but
+     // keeping the call for the consistency.
+     res.end_multiple();
+     InterruptCallback::check();
+ }
+
+ } // namespace
+
+ bool exhaustive_L2sqr_fused_cmax_AVX512(
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t nx,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* y_norms) {
+     // process only cases with certain dimensionalities
+
+ #define DISPATCH(DIM, NX_POINTS_PER_LOOP, NY_POINTS_PER_LOOP) \
+     case DIM: { \
+         exhaustive_L2sqr_fused_cmax< \
+                 DIM, \
+                 NX_POINTS_PER_LOOP, \
+                 NY_POINTS_PER_LOOP>(x, y, nx, ny, res, y_norms); \
+         return true; \
+     }
+
+     switch (d) {
+         DISPATCH(1, 8, 1)
+         DISPATCH(2, 8, 1)
+         DISPATCH(3, 8, 1)
+         DISPATCH(4, 8, 1)
+         DISPATCH(5, 8, 1)
+         DISPATCH(6, 8, 1)
+         DISPATCH(7, 8, 1)
+         DISPATCH(8, 8, 1)
+         DISPATCH(9, 8, 1)
+         DISPATCH(10, 8, 1)
+         DISPATCH(11, 8, 1)
+         DISPATCH(12, 8, 1)
+         DISPATCH(13, 8, 1)
+         DISPATCH(14, 8, 1)
+         DISPATCH(15, 8, 1)
+         DISPATCH(16, 8, 1)
+         DISPATCH(17, 8, 1)
+         DISPATCH(18, 8, 1)
+         DISPATCH(19, 8, 1)
+         DISPATCH(20, 8, 1)
+         DISPATCH(21, 8, 1)
+         DISPATCH(22, 8, 1)
+         DISPATCH(23, 8, 1)
+         DISPATCH(24, 8, 1)
+         DISPATCH(25, 8, 1)
+         DISPATCH(26, 8, 1)
+         DISPATCH(27, 8, 1)
+         DISPATCH(28, 8, 1)
+         DISPATCH(29, 8, 1)
+         DISPATCH(30, 8, 1)
+         DISPATCH(31, 8, 1)
+         DISPATCH(32, 8, 1)
+     }
+
+     return false;
+ #undef DISPATCH
+ }
+
+ } // namespace faiss
+
+ #endif
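
A note on the arithmetic in the hunk above, summarizing its own comments: the main SIMD loop broadcasts -2*x and accumulates only the y-dependent part of the squared L2 distance, so the per-query norm is added back only when the 16-lane minima are reduced to scalars ("add missing x_norms[i]"). The identity being exploited is

\[ \|x_i - y_j\|^2 \;=\; \underbrace{\|x_i\|^2}_{\text{added during the scalar reduction}} \;+\; \underbrace{\|y_j\|^2 - 2\,\langle x_i, y_j\rangle}_{\text{tracked in the AVX512 registers}} , \]

which is also why slightly negative intermediate values can appear for (near-)identical vectors and are clamped to zero.
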
data/vendor/faiss/faiss/utils/distances_fused/avx512.h
@@ -0,0 +1,36 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ // AVX512 might be not used, but this version provides ~2x speedup
+ // over AVX2 kernel, say, for training PQx10 or PQx12, and speeds up
+ // additional cases with larger dimensionalities.
+
+ #pragma once
+
+ #include <faiss/impl/ResultHandler.h>
+ #include <faiss/impl/platform_macros.h>
+
+ #include <faiss/utils/Heap.h>
+
+ #ifdef __AVX512__
+
+ namespace faiss {
+
+ // Returns true if the fused kernel is available and the data was processed.
+ // Returns false if the fused kernel is not available.
+ bool exhaustive_L2sqr_fused_cmax_AVX512(
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t nx,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* y_norms);
+
+ } // namespace faiss
+
+ #endif
data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp
@@ -0,0 +1,42 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ #include <faiss/utils/distances_fused/distances_fused.h>
+
+ #include <faiss/impl/platform_macros.h>
+
+ #include <faiss/utils/distances_fused/avx512.h>
+ #include <faiss/utils/distances_fused/simdlib_based.h>
+
+ namespace faiss {
+
+ bool exhaustive_L2sqr_fused_cmax(
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t nx,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* y_norms) {
+     if (nx == 0 || ny == 0) {
+         // nothing to do
+         return true;
+     }
+
+ #ifdef __AVX512__
+     // avx512 kernel
+     return exhaustive_L2sqr_fused_cmax_AVX512(x, y, d, nx, ny, res, y_norms);
+ #elif defined(__AVX2__) || defined(__aarch64__)
+     // avx2 or arm neon kernel
+     return exhaustive_L2sqr_fused_cmax_simdlib(x, y, d, nx, ny, res, y_norms);
+ #else
+     // not supported, please use a general-purpose kernel
+     return false;
+ #endif
+ }
+
+ } // namespace faiss
data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h
@@ -0,0 +1,40 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ // This file contains a fused kernel that combines distance computation
+ // and the search for the CLOSEST point. Currently, this is done for small
+ // dimensionality vectors when it is beneficial to avoid storing temporary
+ // dot product information in RAM. This is particularly effective
+ // when training PQx10 or PQx12 with the default parameters.
+ //
+ // InterruptCallback::check() is not used, because it is assumed that the
+ // kernel takes a little time because of a tiny dimensionality.
+ //
+ // Later on, similar optimization can be implemented for large size vectors,
+ // but a different kernel is needed.
+ //
+
+ #pragma once
+
+ #include <faiss/impl/ResultHandler.h>
+
+ #include <faiss/utils/Heap.h>
+
+ namespace faiss {
+
+ // Returns true if the fused kernel is available and the data was processed.
+ // Returns false if the fused kernel is not available.
+ bool exhaustive_L2sqr_fused_cmax(
+         const float* x,
+         const float* y,
+         size_t d,
+         size_t nx,
+         size_t ny,
+         SingleBestResultHandler<CMax<float, int64_t>>& res,
+         const float* y_norms);
+
+ } // namespace faiss
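
The header above only declares the dispatcher, so a minimal caller-side sketch of the true/false contract it documents may help. This sketch is not part of the gem or of upstream faiss: the (nq, dis, ids) shape of the SingleBestResultHandler constructor and the suggested fallback routine are assumptions based on faiss/impl/ResultHandler.h and faiss/utils/distances.h, not something this diff shows.

    // Hypothetical usage sketch: try the fused kernel first; if it reports
    // false, fall back to a general-purpose kNN routine.
    #include <faiss/utils/distances_fused/distances_fused.h>

    #include <cstdint>

    void closest_neighbor_sketch(
            const float* x,     // nx query vectors of dimension d
            const float* y,     // ny database vectors of dimension d
            size_t d,
            size_t nx,
            size_t ny,
            float* out_dis,     // output: best distance per query
            int64_t* out_ids) { // output: best id per query
        using namespace faiss;

        // Keeps one best (distance, id) pair per query; the (nq, dis, ids)
        // constructor shape is assumed from faiss/impl/ResultHandler.h.
        SingleBestResultHandler<CMax<float, int64_t>> res(nx, out_dis, out_ids);

        // Passing nullptr for y_norms lets the kernel compute ||y_j||^2 itself.
        if (!exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, nullptr)) {
            // Fused path unavailable (e.g. d > 32 or no AVX512/AVX2/NEON build):
            // fall back to a general-purpose routine such as faiss::knn_L2sqr.
        }
    }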