mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,329 @@
1
+ #pragma once
2
+
3
+ #include <arm_neon.h>
4
+ #include <simd/math.h>
5
+ #include <simd/vector.h>
6
+
7
+ #include <stdint.h>
8
+ #include <cmath>
9
+ #include <complex>
10
+
11
+ #include "mlx/backend/cpu/simd/base_simd.h"
12
+
13
// Work around what appears to be a bug in Apple's SIMD base header: when
// __XROS_2_0 is not defined, its version check evaluates to true instead of
// false, which sets the SIMD library version higher than it should be even
// on macOS < 15. The version is therefore derived here from the deployment
// target macros directly.
// NOTE(review): there is no visionOS clause in this check -- confirm against
// the SDK whether one is needed.
#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \
    __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \
    __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \
    __TV_OS_VERSION_MIN_REQUIRED >= 180000
#define MLX_SIMD_LIBRARY_VERSION 6
#else
#define MLX_SIMD_LIBRARY_VERSION 5
#endif
26
+
27
+ namespace mlx::core::simd {
28
+
29
+ // Apple simd namespace
30
+ namespace asd = ::simd;
31
+
32
+ // This indirection is needed to remap certain types to ones that accelerate
33
+ // SIMD can handle
34
+ template <typename T, int N>
35
+ struct ScalarT {
36
+ using v = T;
37
+ };
38
+ template <int N>
39
+ struct ScalarT<bool, N> {
40
+ using v = char;
41
+ };
42
+ template <int N>
43
+ struct ScalarT<int8_t, N> {
44
+ using v = char;
45
+ };
46
+ template <int N>
47
+ struct ScalarT<uint64_t, N> {
48
+ using v = unsigned long;
49
+ };
50
+ template <int N>
51
+ struct ScalarT<int64_t, N> {
52
+ using v = long;
53
+ };
54
+
55
+ template <typename T, int N>
56
+ struct Simd {
57
+ static constexpr int size = N;
58
+ using scalar_t = typename ScalarT<T, N>::v;
59
+
60
+ Simd<T, N>() {}
61
+
62
+ template <typename U>
63
+ Simd<T, N>(Simd<U, N> other) : value(asd::convert<scalar_t>(other.value)) {}
64
+
65
+ template <typename U>
66
+ Simd<T, N>(U v) : value(v){};
67
+
68
+ Simd<T, N>(Simd<T, N / 2> x, Simd<T, N / 2> y) {
69
+ value = asd::make<typename asd::Vector<scalar_t, N>::packed_t>(
70
+ x.value, y.value);
71
+ };
72
+
73
+ T operator[](int idx) const {
74
+ return reinterpret_cast<const T*>(&value)[idx];
75
+ }
76
+
77
+ T& operator[](int idx) {
78
+ return reinterpret_cast<T*>(&value)[idx];
79
+ }
80
+
81
+ typename asd::Vector<scalar_t, N>::packed_t value;
82
+ };
83
+
84
+ // Values chosen based on benchmarks on M3 Max
85
+ // TODO: consider choosing these more optimally
86
+ template <>
87
+ inline constexpr int max_size<int8_t> = 16;
88
+ template <>
89
+ inline constexpr int max_size<int16_t> = 16;
90
+ template <>
91
+ inline constexpr int max_size<int> = 8;
92
+ template <>
93
+ inline constexpr int max_size<int64_t> = 4;
94
+ template <>
95
+ inline constexpr int max_size<uint8_t> = 16;
96
+ template <>
97
+ inline constexpr int max_size<uint16_t> = 16;
98
+ template <>
99
+ inline constexpr int max_size<uint32_t> = 8;
100
+ template <>
101
+ inline constexpr int max_size<uint64_t> = 4;
102
+ template <>
103
+ inline constexpr int max_size<float> = 8;
104
+ template <>
105
+ inline constexpr int max_size<double> = 4;
106
+
107
+ #define SIMD_DEFAULT_UNARY(name, op) \
108
+ template <typename T, int N> \
109
+ Simd<T, N> name(Simd<T, N> v) { \
110
+ return op(v.value); \
111
+ }
112
+
113
+ SIMD_DEFAULT_UNARY(abs, asd::abs)
114
+ SIMD_DEFAULT_UNARY(floor, asd::floor)
115
+ SIMD_DEFAULT_UNARY(acos, asd::acos)
116
+ SIMD_DEFAULT_UNARY(acosh, asd::acosh)
117
+ SIMD_DEFAULT_UNARY(asin, asd::asin)
118
+ SIMD_DEFAULT_UNARY(asinh, asd::asinh)
119
+ SIMD_DEFAULT_UNARY(atan, asd::atan)
120
+ SIMD_DEFAULT_UNARY(atanh, asd::atanh)
121
+ SIMD_DEFAULT_UNARY(ceil, asd::ceil)
122
+ SIMD_DEFAULT_UNARY(cosh, asd::cosh)
123
+ SIMD_DEFAULT_UNARY(expm1, asd::expm1)
124
+ SIMD_DEFAULT_UNARY(log, asd::log)
125
+ SIMD_DEFAULT_UNARY(log2, asd::log2)
126
+ SIMD_DEFAULT_UNARY(log10, asd::log10)
127
+ SIMD_DEFAULT_UNARY(log1p, asd::log1p)
128
+ SIMD_DEFAULT_UNARY(rint, asd::rint)
129
+ SIMD_DEFAULT_UNARY(sinh, asd::sinh)
130
+ SIMD_DEFAULT_UNARY(sqrt, asd::sqrt)
131
+ SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt)
132
+ SIMD_DEFAULT_UNARY(recip, asd::recip)
133
+ SIMD_DEFAULT_UNARY(tan, asd::tan)
134
+ SIMD_DEFAULT_UNARY(tanh, asd::tanh)
135
+
136
+ template <typename T, int N>
137
+ Simd<T, N> operator-(Simd<T, N> v) {
138
+ return -v.value;
139
+ }
140
+
141
+ template <typename T, int N>
142
+ Simd<T, N> operator~(Simd<T, N> v) {
143
+ return ~v.value;
144
+ }
145
+
146
+ template <typename T, int N>
147
+ Simd<bool, N> isnan(Simd<T, N> v) {
148
+ return asd::convert<char>(v.value != v.value);
149
+ }
150
+
151
+ // No simd_boolN in accelerate, use int8_t instead
152
+ template <typename T, int N>
153
+ Simd<bool, N> operator!(Simd<T, N> v) {
154
+ return asd::convert<char>(!v.value);
155
+ }
156
+
157
+ #define SIMD_DEFAULT_BINARY(OP) \
158
+ template <typename T, typename U, int N> \
159
+ Simd<T, N> operator OP(Simd<T, N> x, U y) { \
160
+ return asd::convert<typename Simd<T, N>::scalar_t>(x.value OP y); \
161
+ } \
162
+ template <typename T1, typename T2, int N> \
163
+ Simd<T2, N> operator OP(T1 x, Simd<T2, N> y) { \
164
+ return asd::convert<typename Simd<T2, N>::scalar_t>(x OP y.value); \
165
+ } \
166
+ template <typename T1, typename T2, int N> \
167
+ Simd<T1, N> operator OP(Simd<T1, N> x, Simd<T2, N> y) { \
168
+ return asd::convert<typename Simd<T1, N>::scalar_t>(x.value OP y.value); \
169
+ }
170
+
171
+ SIMD_DEFAULT_BINARY(+)
172
+ SIMD_DEFAULT_BINARY(-)
173
+ SIMD_DEFAULT_BINARY(/)
174
+ SIMD_DEFAULT_BINARY(*)
175
+ SIMD_DEFAULT_BINARY(<<)
176
+ SIMD_DEFAULT_BINARY(>>)
177
+ SIMD_DEFAULT_BINARY(|)
178
+ SIMD_DEFAULT_BINARY(^)
179
+ SIMD_DEFAULT_BINARY(&)
180
+ SIMD_DEFAULT_BINARY(&&)
181
+ SIMD_DEFAULT_BINARY(||)
182
+
183
+ #define SIMD_DEFAULT_COMPARISONS(OP) \
184
+ template <int N, typename T, typename U> \
185
+ Simd<bool, N> operator OP(Simd<T, N> a, U b) { \
186
+ return asd::convert<char>(a.value OP b); \
187
+ } \
188
+ template <int N, typename T, typename U> \
189
+ Simd<bool, N> operator OP(T a, Simd<U, N> b) { \
190
+ return asd::convert<char>(a OP b.value); \
191
+ } \
192
+ template <int N, typename T1, typename T2> \
193
+ Simd<bool, N> operator OP(Simd<T1, N> a, Simd<T2, N> b) { \
194
+ return asd::convert<char>(a.value OP b.value); \
195
+ }
196
+
197
+ SIMD_DEFAULT_COMPARISONS(>)
198
+ SIMD_DEFAULT_COMPARISONS(<)
199
+ SIMD_DEFAULT_COMPARISONS(>=)
200
+ SIMD_DEFAULT_COMPARISONS(<=)
201
+ SIMD_DEFAULT_COMPARISONS(==)
202
+ SIMD_DEFAULT_COMPARISONS(!=)
203
+
204
+ template <typename T, int N>
205
+ Simd<T, N> clz(Simd<T, N> x) {
206
+ auto a = *(uint32x4_t*)(&x);
207
+ auto b = *((uint32x4_t*)(&x) + 1);
208
+ a = vclzq_u32(a);
209
+ b = vclzq_u32(b);
210
+ return asd::make_uint8(a, b);
211
+ }
212
+
213
+ template <typename T, int N>
214
+ Simd<T, N> atan2(Simd<T, N> a, Simd<T, N> b) {
215
+ return asd::atan2(a.value, b.value);
216
+ }
217
+
218
+ template <typename T, int N>
219
+ Simd<T, N> maximum(Simd<T, N> a, Simd<T, N> b) {
220
+ auto out = Simd<T, N>(asd::max(a.value, b.value));
221
+ if constexpr (!std::is_integral_v<T>) {
222
+ out = select(isnan(b), b, select(isnan(a), a, out));
223
+ }
224
+ return out;
225
+ }
226
+
227
+ template <typename T, int N>
228
+ Simd<T, N> minimum(Simd<T, N> a, Simd<T, N> b) {
229
+ auto out = Simd<T, N>(asd::min(a.value, b.value));
230
+ if constexpr (!std::is_integral_v<T>) {
231
+ out = select(isnan(b), b, select(isnan(a), a, out));
232
+ }
233
+ return out;
234
+ }
235
+
236
+ template <typename T, int N>
237
+ Simd<T, N> remainder(Simd<T, N> a, Simd<T, N> b) {
238
+ Simd<T, N> r;
239
+ if constexpr (!std::is_integral_v<T>) {
240
+ r = asd::remainder(a.value, b.value);
241
+ } else {
242
+ r = a - b * (a / b);
243
+ }
244
+ if constexpr (std::is_signed_v<T>) {
245
+ auto mask = r != 0 && (r < 0 != b < 0);
246
+ r = select(mask, r + b, r);
247
+ }
248
+ return r;
249
+ }
250
+
251
+ template <typename MaskT, typename T1, typename T2, int N>
252
+ Simd<T1, N> select(Simd<MaskT, N> mask, Simd<T1, N> x, Simd<T2, N> y) {
253
+ static_assert(std::is_same_v<MaskT, bool>);
254
+ if constexpr (sizeof(T1) == 1) {
255
+ return asd::bitselect(y.value, x.value, asd::convert<char>(mask.value));
256
+ } else if constexpr (sizeof(T1) == 2) {
257
+ return asd::bitselect(y.value, x.value, asd::convert<short>(mask.value));
258
+ } else if constexpr (sizeof(T1) == 4) {
259
+ return asd::bitselect(y.value, x.value, asd::convert<int>(mask.value));
260
+ } else {
261
+ return asd::bitselect(y.value, x.value, asd::convert<long>(mask.value));
262
+ }
263
+ }
264
+
265
+ template <typename T, int N>
266
+ Simd<T, N> pow(Simd<T, N> base, Simd<T, N> exp) {
267
+ if constexpr (!std::is_integral_v<T>) {
268
+ return asd::pow(base.value, exp.value);
269
+ } else {
270
+ Simd<T, N> res = 1;
271
+ // Raising an integer to a negative power is undefined
272
+ if (any(exp < 0)) {
273
+ return 0;
274
+ }
275
+ while (any(exp > 0)) {
276
+ res = select((exp & 1) != 0, res * base, res);
277
+ base = select(exp > 0, base * base, base);
278
+ exp = exp >> 1;
279
+ }
280
+ return res;
281
+ }
282
+ }
283
+
284
+ template <typename T, int N>
285
+ Simd<T, N> clamp(Simd<T, N> v, Simd<T, N> min, Simd<T, N> max) {
286
+ return asd::clamp(v.value, min.value, max.value);
287
+ }
288
+
289
+ template <typename T, typename U, int N>
290
+ Simd<T, N> fma(Simd<T, N> x, Simd<T, N> y, U z) {
291
+ return asd::muladd(x.value, y.value, Simd<T, N>(z).value);
292
+ }
293
+
294
+ // Reductions
295
+
296
+ template <typename T, int N>
297
+ bool all(Simd<T, N> x) {
298
+ return asd::all(x.value);
299
+ }
300
+ template <typename T, int N>
301
+ bool any(Simd<T, N> x) {
302
+ return asd::any(x.value);
303
+ }
304
+ template <typename T, int N>
305
+ T sum(Simd<T, N> x) {
306
+ return asd::reduce_add(x.value);
307
+ }
308
+ template <typename T, int N>
309
+ T max(Simd<T, N> x) {
310
+ return asd::reduce_max(x.value);
311
+ }
312
+ template <typename T, int N>
313
+ T min(Simd<T, N> x) {
314
+ return asd::reduce_min(x.value);
315
+ }
316
+
317
+ template <typename T, int N>
318
+ T prod(Simd<T, N> x) {
319
+ auto ptr = (T*)&x;
320
+ auto lhs = load<T, N / 2>(ptr);
321
+ auto rhs = load<T, N / 2>(ptr + N / 2);
322
+ return prod(lhs * rhs);
323
+ }
324
+
325
+ } // namespace mlx::core::simd
326
+
327
+ #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
328
+ #include "mlx/backend/cpu/simd/accelerate_fp16_simd.h"
329
+ #endif
@@ -0,0 +1,295 @@
1
+ #pragma once
2
+
3
+ #include <stdint.h>
4
+ #include <algorithm>
5
+ #include <cmath>
6
+ #include <complex>
7
+ #include <functional>
8
+
9
+ namespace mlx::core::simd {
10
+ template <typename T, int N>
11
+ struct Simd;
12
+
13
+ template <typename T>
14
+ static constexpr int max_size = 1;
15
+
16
+ template <typename T>
17
+ struct Simd<T, 1> {
18
+ static constexpr int size = 1;
19
+ T value;
20
+ Simd() {}
21
+ template <typename U>
22
+ Simd(Simd<U, 1> v) : value(v.value) {}
23
+ template <typename U>
24
+ Simd(U v) : value(v) {}
25
+ };
26
+
27
+ template <typename T, int N>
28
+ Simd<T, N> load(const T* x) {
29
+ return *(Simd<T, N>*)x;
30
+ }
31
+
32
+ template <typename T, int N>
33
+ void store(T* dst, Simd<T, N> x) {
34
+ // Maintain invariant that bool is either 0 or 1 as
35
+ // simd comparison ops set all bits in the result to 1
36
+ if constexpr (std::is_same_v<T, bool> && N > 1) {
37
+ x = x & 1;
38
+ }
39
+ *(Simd<T, N>*)dst = x;
40
+ }
41
+
42
+ template <typename, typename = void>
43
+ constexpr bool is_complex = false;
44
+
45
+ template <typename T>
46
+ constexpr bool is_complex<T, std::void_t<decltype(std::declval<T>().real())>> =
47
+ true;
48
+
49
+ template <typename T>
50
+ Simd<T, 1> rint(Simd<T, 1> in) {
51
+ if constexpr (is_complex<T>) {
52
+ return Simd<T, 1>{
53
+ T{std::rint(in.value.real()), std::rint(in.value.imag())}};
54
+ } else {
55
+ return Simd<T, 1>{std::rint(in.value)};
56
+ }
57
+ }
58
+
59
+ template <typename T>
60
+ Simd<T, 1> rsqrt(Simd<T, 1> in) {
61
+ return T(1.0) / sqrt(in);
62
+ }
63
+
64
+ template <typename T>
65
+ Simd<T, 1> recip(Simd<T, 1> in) {
66
+ return T(1.0) / in;
67
+ }
68
+
69
+ #define DEFAULT_UNARY(name, op) \
70
+ template <typename T> \
71
+ Simd<T, 1> name(Simd<T, 1> in) { \
72
+ return op(in.value); \
73
+ }
74
+
75
+ DEFAULT_UNARY(operator-, std::negate{})
76
+ DEFAULT_UNARY(operator!, std::logical_not{})
77
+ DEFAULT_UNARY(abs, std::abs)
78
+ DEFAULT_UNARY(acos, std::acos)
79
+ DEFAULT_UNARY(acosh, std::acosh)
80
+ DEFAULT_UNARY(asin, std::asin)
81
+ DEFAULT_UNARY(asinh, std::asinh)
82
+ DEFAULT_UNARY(atan, std::atan)
83
+ DEFAULT_UNARY(atanh, std::atanh)
84
+ DEFAULT_UNARY(ceil, std::ceil)
85
+ DEFAULT_UNARY(conj, std::conj)
86
+ DEFAULT_UNARY(cosh, std::cosh)
87
+ DEFAULT_UNARY(expm1, std::expm1)
88
+ DEFAULT_UNARY(floor, std::floor)
89
+ DEFAULT_UNARY(log, std::log)
90
+ DEFAULT_UNARY(log10, std::log10)
91
+ DEFAULT_UNARY(sinh, std::sinh)
92
+ DEFAULT_UNARY(sqrt, std::sqrt)
93
+ DEFAULT_UNARY(tan, std::tan)
94
+ DEFAULT_UNARY(tanh, std::tanh)
95
+
96
+ template <typename T>
97
+ Simd<T, 1> log1p(Simd<T, 1> in) {
98
+ if constexpr (is_complex<T>) {
99
+ auto x = in.value.real();
100
+ auto y = in.value.imag();
101
+ auto zabs = std::abs(in.value);
102
+ auto theta = std::atan2(y, x + 1);
103
+ if (zabs < 0.5) {
104
+ auto r = x * (2 + x) + y * y;
105
+ if (r == 0) { // handle underflow
106
+ return Simd<T, 1>{T{x, theta}};
107
+ }
108
+ return Simd<T, 1>{T{((typeof(x))(0.5)) * std::log1p(r), theta}};
109
+ } else {
110
+ auto z0 = std::hypot(x + 1, y);
111
+ return Simd<T, 1>{T{std::log(z0), theta}};
112
+ }
113
+ } else {
114
+ return Simd<T, 1>{std::log1p(in.value)};
115
+ }
116
+ }
117
+
118
+ template <typename T>
119
+ Simd<T, 1> log2(Simd<T, 1> in) {
120
+ if constexpr (is_complex<T>) {
121
+ auto out = std::log(in.value);
122
+ auto scale = decltype(out.real())(M_LN2);
123
+ return Simd<T, 1>{T{out.real() / scale, out.imag() / scale}};
124
+ } else {
125
+ return Simd<T, 1>{std::log2(in.value)};
126
+ }
127
+ }
128
+
129
+ template <typename T>
130
+ Simd<T, 1> operator~(Simd<T, 1> in) {
131
+ return ~in.value;
132
+ }
133
+
134
+ template <typename T>
135
+ auto real(Simd<T, 1> in) -> Simd<decltype(std::real(in.value)), 1> {
136
+ return std::real(in.value);
137
+ }
138
+ template <typename T>
139
+ auto imag(Simd<T, 1> in) -> Simd<decltype(std::imag(in.value)), 1> {
140
+ return std::imag(in.value);
141
+ }
142
+ template <typename T>
143
+ Simd<bool, 1> isnan(Simd<T, 1> in) {
144
+ return std::isnan(in.value);
145
+ }
146
+
147
+ #define DEFAULT_BINARY(OP) \
148
+ template <typename T1, typename T2> \
149
+ auto operator OP(Simd<T1, 1> a, Simd<T2, 1> b) \
150
+ ->Simd<decltype(a.value OP b.value), 1> { \
151
+ return a.value OP b.value; \
152
+ } \
153
+ template <typename T1, typename T2> \
154
+ auto operator OP(T1 a, Simd<T2, 1> b)->Simd<decltype(a OP b.value), 1> { \
155
+ return a OP b.value; \
156
+ } \
157
+ template <typename T1, typename T2> \
158
+ auto operator OP(Simd<T1, 1> a, T2 b)->Simd<decltype(a.value OP b), 1> { \
159
+ return a.value OP b; \
160
+ }
161
+
162
+ DEFAULT_BINARY(+)
163
+ DEFAULT_BINARY(-)
164
+ DEFAULT_BINARY(*)
165
+ DEFAULT_BINARY(/)
166
+ DEFAULT_BINARY(<<)
167
+ DEFAULT_BINARY(>>)
168
+ DEFAULT_BINARY(|)
169
+ DEFAULT_BINARY(^)
170
+ DEFAULT_BINARY(&)
171
+ DEFAULT_BINARY(&&)
172
+ DEFAULT_BINARY(||)
173
+
174
+ template <typename T>
175
+ Simd<T, 1> clz(Simd<T, 1> x_) {
176
+ return __builtin_clz(x_.value);
177
+ }
178
+
179
+ template <typename T>
180
+ Simd<T, 1> remainder(Simd<T, 1> a_, Simd<T, 1> b_) {
181
+ T a = a_.value;
182
+ T b = b_.value;
183
+ T r;
184
+ if constexpr (std::is_integral_v<T>) {
185
+ r = a % b;
186
+ } else {
187
+ r = std::remainder(a, b);
188
+ }
189
+ if constexpr (std::is_signed_v<T>) {
190
+ if (r != 0 && (r < 0 != b < 0)) {
191
+ r += b;
192
+ }
193
+ }
194
+ return r;
195
+ }
196
+
197
+ template <typename T>
198
+ Simd<T, 1> maximum(Simd<T, 1> a_, Simd<T, 1> b_) {
199
+ T a = a_.value;
200
+ T b = b_.value;
201
+ if constexpr (!std::is_integral_v<T>) {
202
+ if (std::isnan(a)) {
203
+ return a;
204
+ }
205
+ }
206
+ return (a > b) ? a : b;
207
+ }
208
+
209
+ template <typename T>
210
+ Simd<T, 1> minimum(Simd<T, 1> a_, Simd<T, 1> b_) {
211
+ T a = a_.value;
212
+ T b = b_.value;
213
+ if constexpr (!std::is_integral_v<T>) {
214
+ if (std::isnan(a)) {
215
+ return a;
216
+ }
217
+ }
218
+ return (a < b) ? a : b;
219
+ }
220
+
221
+ template <typename T>
222
+ Simd<T, 1> pow(Simd<T, 1> a, Simd<T, 1> b) {
223
+ T base = a.value;
224
+ T exp = b.value;
225
+ if constexpr (!std::is_integral_v<T>) {
226
+ return std::pow(base, exp);
227
+ } else {
228
+ T res = 1;
229
+ while (exp) {
230
+ if (exp & 1) {
231
+ res *= base;
232
+ }
233
+ exp >>= 1;
234
+ base *= base;
235
+ }
236
+ return res;
237
+ }
238
+ }
239
+
240
+ template <typename T>
241
+ Simd<T, 1> atan2(Simd<T, 1> a, Simd<T, 1> b) {
242
+ return std::atan2(a.value, b.value);
243
+ }
244
+
245
+ #define DEFAULT_COMPARISONS(OP) \
246
+ template <typename T1, typename T2> \
247
+ Simd<bool, 1> operator OP(Simd<T1, 1> a, Simd<T2, 1> b) { \
248
+ return a.value OP b.value; \
249
+ } \
250
+ template <typename T1, typename T2> \
251
+ Simd<bool, 1> operator OP(T1 a, Simd<T2, 1> b) { \
252
+ return a OP b.value; \
253
+ } \
254
+ template <typename T1, typename T2> \
255
+ Simd<bool, 1> operator OP(Simd<T1, 1> a, T2 b) { \
256
+ return a.value OP b; \
257
+ }
258
+
259
+ DEFAULT_COMPARISONS(>)
260
+ DEFAULT_COMPARISONS(<)
261
+ DEFAULT_COMPARISONS(>=)
262
+ DEFAULT_COMPARISONS(<=)
263
+ DEFAULT_COMPARISONS(==)
264
+ DEFAULT_COMPARISONS(!=)
265
+
266
+ template <typename MaskT, typename T>
267
+ Simd<T, 1> select(Simd<MaskT, 1> mask, Simd<T, 1> x, Simd<T, 1> y) {
268
+ return mask.value ? x.value : y.value;
269
+ }
270
+
271
+ template <typename T>
272
+ Simd<T, 1> clamp(Simd<T, 1> v, Simd<T, 1> min, Simd<T, 1> max) {
273
+ return std::clamp(v.value, min.value, max.value);
274
+ }
275
+
276
+ template <typename T, typename U>
277
+ Simd<T, 1> fma(Simd<T, 1> x, Simd<T, 1> y, U z) {
278
+ return std::fma(x.value, y.value, Simd<T, 1>(z).value);
279
+ }
280
+
281
+ // Reductions
282
+ #define DEFAULT_REDUCTION(name, type) \
283
+ template <typename T> \
284
+ type name(Simd<T, 1> x) { \
285
+ return x.value; \
286
+ }
287
+
288
+ DEFAULT_REDUCTION(max, T)
289
+ DEFAULT_REDUCTION(min, T)
290
+ DEFAULT_REDUCTION(sum, T)
291
+ DEFAULT_REDUCTION(prod, T)
292
+ DEFAULT_REDUCTION(any, bool)
293
+ DEFAULT_REDUCTION(all, bool)
294
+
295
+ } // namespace mlx::core::simd