mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/cpu/binary_two.h
@@ -0,0 +1,166 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/common/utils.h"
+ #include "mlx/backend/cpu/binary.h"
+
+ namespace mlx::core {
+
+ namespace {
+
+ template <typename T, typename U, typename Op, int D>
+ void binary_op_dims(
+     const T* a,
+     const T* b,
+     U* out_a,
+     U* out_b,
+     Op op,
+     const Shape& shape,
+     const Strides& a_strides,
+     const Strides& b_strides,
+     const Strides& out_strides,
+     int axis) {
+   auto stride_a = a_strides[axis];
+   auto stride_b = b_strides[axis];
+   auto stride_out = out_strides[axis];
+   auto N = shape[axis];
+
+   for (int i = 0; i < N; i++) {
+     if constexpr (D > 1) {
+       binary_op_dims<T, U, Op, D - 1>(
+           a,
+           b,
+           out_a,
+           out_b,
+           op,
+           shape,
+           a_strides,
+           b_strides,
+           out_strides,
+           axis + 1);
+     } else {
+       std::tie(*out_a, *out_b) = op(*a, *b);
+     }
+     a += stride_a;
+     b += stride_b;
+     out_a += stride_out;
+     out_b += stride_out;
+   }
+ }
+
+ template <typename T, typename U, typename Op>
+ void binary_op_dispatch_dims(
+     const array& a,
+     const array& b,
+     array& out_a,
+     array& out_b,
+     Op op) {
+   auto [shape, strides] = collapse_contiguous_dims(
+       a.shape(), {a.strides(), b.strides(), out_a.strides()});
+   const T* a_ptr = a.data<T>();
+   const T* b_ptr = b.data<T>();
+   U* out_a_ptr = out_a.data<U>();
+   U* out_b_ptr = out_b.data<U>();
+
+   const auto& a_strides = strides[0];
+   const auto& b_strides = strides[1];
+   const auto& out_strides = strides[2];
+   int ndim = shape.size();
+   switch (ndim) {
+     case 1:
+       binary_op_dims<T, U, Op, 1>(
+           a_ptr,
+           b_ptr,
+           out_a_ptr,
+           out_b_ptr,
+           op,
+           shape,
+           a_strides,
+           b_strides,
+           out_strides,
+           0);
+       return;
+     case 2:
+       binary_op_dims<T, U, Op, 2>(
+           a_ptr,
+           b_ptr,
+           out_a_ptr,
+           out_b_ptr,
+           op,
+           shape,
+           a_strides,
+           b_strides,
+           out_strides,
+           0);
+       return;
+   }
+
+   ContiguousIterator a_it(shape, a_strides, ndim - 2);
+   ContiguousIterator b_it(shape, b_strides, ndim - 2);
+   auto stride = out_strides[ndim - 3];
+   for (size_t elem = 0; elem < a.size(); elem += stride) {
+     binary_op_dims<T, U, Op, 2>(
+         a_ptr + a_it.loc,
+         b_ptr + b_it.loc,
+         out_a_ptr + elem,
+         out_b_ptr + elem,
+         op,
+         shape,
+         a_strides,
+         b_strides,
+         out_strides,
+         ndim - 2);
+     a_it.step();
+     b_it.step();
+   }
+ }
+
+ template <typename T, typename U = T, typename Op>
+ void binary_op(
+     const array& a,
+     const array& b,
+     array& out_a,
+     array& out_b,
+     Op op,
+     BinaryOpType bopt) {
+   // The full computation is scalar scalar so call the base op once
+   if (bopt == BinaryOpType::General) {
+     binary_op_dispatch_dims<T, U, Op>(a, b, out_a, out_b, op);
+     return;
+   }
+
+   auto a_ptr = a.data<T>();
+   auto b_ptr = b.data<T>();
+   auto out_a_ptr = out_a.data<U>();
+   auto out_b_ptr = out_b.data<U>();
+   if (bopt == BinaryOpType::ScalarScalar) {
+     std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+   } else if (bopt == BinaryOpType::ScalarVector) {
+     for (size_t i = 0; i < b.data_size(); ++i) {
+       std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+       out_a_ptr++;
+       out_b_ptr++;
+       b_ptr++;
+     }
+   } else if (bopt == BinaryOpType::VectorScalar) {
+     for (size_t i = 0; i < a.data_size(); ++i) {
+       std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+       out_a_ptr++;
+       out_b_ptr++;
+       a_ptr++;
+     }
+   } else { // VectorVector
+     for (size_t i = 0; i < a.size(); ++i) {
+       std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr);
+       out_a_ptr++;
+       out_b_ptr++;
+       a_ptr++;
+       b_ptr++;
+     }
+   }
+ }
+
+ } // namespace
+
+ } // namespace mlx::core
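Not part of the wheel: a minimal sketch of how the two-output binary_op above can be driven, assuming a divmod-style element op and int32 arrays whose data is already allocated. The divmod_elementwise name is invented for illustration, and the BinaryOpType would normally come from the stride analysis in mlx/backend/cpu/binary.h rather than being hard-coded.

    // Illustrative sketch, not part of the package.
    #include <cstdint>
    #include <utility>

    #include "mlx/backend/cpu/binary_two.h"

    namespace mlx::core {

    // Write x / y into quot and x % y into rem, element by element.
    void divmod_elementwise(const array& a, const array& b, array& quot, array& rem) {
      auto op = [](int32_t x, int32_t y) { return std::make_pair(x / y, x % y); };
      // Assumes fully dense, same-shape inputs; otherwise use BinaryOpType::General.
      binary_op<int32_t, int32_t>(a, b, quot, rem, op, BinaryOpType::VectorVector);
    }

    } // namespace mlx::core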
mlx/include/mlx/backend/cpu/compiled_preamble.h
@@ -0,0 +1,12 @@
+ // Copyright © 2023-24 Apple Inc.
+
+ #pragma once
+
+ // clang-format off
+ #include "mlx/types/half_types.h"
+ #include "mlx/types/complex.h"
+ #include "mlx/backend/cpu/unary_ops.h"
+ #include "mlx/backend/cpu/binary_ops.h"
+ // clang-format on
+
+ const char* get_kernel_preamble();
mlx/include/mlx/backend/cpu/copy.h
@@ -0,0 +1,36 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ #pragma once
+
+ #include <optional>
+
+ #include "mlx/array.h"
+ #include "mlx/backend/common/copy.h"
+ #include "mlx/backend/common/utils.h"
+
+ namespace mlx::core {
+
+ void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream);
+ void copy_cpu_inplace(
+     const array& src,
+     array& dst,
+     CopyType ctype,
+     Stream stream);
+
+ void copy_cpu_inplace(
+     const array& src,
+     array& dst,
+     const Shape& data_shape,
+     const Strides& i_strides,
+     const Strides& o_strides,
+     int64_t i_offset,
+     int64_t o_offset,
+     CopyType ctype,
+     Stream stream,
+     const std::optional<array>& dynamic_i_offset = std::nullopt,
+     const std::optional<array>& dynamic_o_offset = std::nullopt);
+
+ // Return a contiguous array with same shape that copies the data of |arr|.
+ array contiguous_copy_cpu(const array& arr, Stream stream);
+
+ } // namespace mlx::core
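For orientation, a hypothetical helper (not in the wheel) built on the contiguous_copy_cpu declaration above: it hands back dense float data, copying only when the input is not already row-contiguous. The dense_float_ptr name is invented, and the flags().row_contiguous check comes from the wider mlx::core::array API rather than from this header.

    #include "mlx/backend/cpu/copy.h"

    namespace mlx::core {

    // Return a pointer to dense row-major float data; |storage| keeps any copy alive.
    const float* dense_float_ptr(const array& x, Stream s, array& storage) {
      if (x.flags().row_contiguous) {
        return x.data<float>();
      }
      storage = contiguous_copy_cpu(x, s); // same shape, contiguous data
      return storage.data<float>();
    }

    } // namespace mlx::core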
mlx/include/mlx/backend/cpu/encoder.h
@@ -0,0 +1,67 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include <unordered_map>
+
+ #include "mlx/array.h"
+ #include "mlx/scheduler.h"
+
+ namespace mlx::core::cpu {
+
+ // Number of dispatches per scheduler task
+ constexpr int DISPATCHES_PER_TASK = 10;
+
+ struct CommandEncoder {
+   CommandEncoder(Stream stream) : stream_(stream) {}
+
+   CommandEncoder(const CommandEncoder&) = delete;
+   CommandEncoder& operator=(const CommandEncoder&) = delete;
+   CommandEncoder(CommandEncoder&&) = delete;
+   CommandEncoder& operator=(CommandEncoder&&) = delete;
+
+   void set_input_array(const array& a) {}
+   void set_output_array(array& a) {}
+
+   // Hold onto a temporary until any already scheduled tasks which use it as
+   // an input are complete.
+   void add_temporary(array arr) {
+     temporaries_.push_back(std::move(arr));
+   }
+
+   void add_temporaries(std::vector<array> arrays) {
+     temporaries_.insert(
+         temporaries_.end(),
+         std::make_move_iterator(arrays.begin()),
+         std::make_move_iterator(arrays.end()));
+   }
+
+   std::vector<array>& temporaries() {
+     return temporaries_;
+   }
+
+   template <class F, class... Args>
+   void dispatch(F&& f, Args&&... args) {
+     num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK;
+     auto task = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
+     if (num_ops_ == 0) {
+       scheduler::notify_new_task(stream_);
+       auto task_wrap = [s = stream_, task = std::move(task)]() mutable {
+         task();
+         scheduler::notify_task_completion(s);
+       };
+       scheduler::enqueue(stream_, std::move(task_wrap));
+     } else {
+       scheduler::enqueue(stream_, std::move(task));
+     }
+   }
+
+  private:
+   Stream stream_;
+   std::vector<array> temporaries_;
+   int num_ops_{0};
+ };
+
+ CommandEncoder& get_command_encoder(Stream stream);
+
+ } // namespace mlx::core::cpu
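A hedged sketch (not shipped in the package) of how a CPU kernel is typically queued on the CommandEncoder above: fetch the stream's encoder, register the output, then dispatch a task that only captures raw pointers. The fill_with_cpu kernel itself is invented for illustration; get_command_encoder, set_output_array and dispatch come from the header.

    #include <algorithm>

    #include "mlx/backend/cpu/encoder.h"

    namespace mlx::core {

    // Queue a trivial fill kernel on the stream's CPU encoder.
    void fill_with_cpu(array& out, float value, Stream stream) {
      auto& enc = cpu::get_command_encoder(stream);
      enc.set_output_array(out);
      // Capture the raw pointer and size so the task does not hold the array itself.
      enc.dispatch([ptr = out.data<float>(), n = out.size(), value]() {
        std::fill(ptr, ptr + n, value);
      });
    }

    } // namespace mlx::core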
mlx/include/mlx/backend/cpu/eval.h
@@ -0,0 +1,12 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/array.h"
+ #include "mlx/stream.h"
+
+ namespace mlx::core::cpu {
+
+ void eval(array& arr);
+
+ } // namespace mlx::core::cpu
mlx/include/mlx/backend/cpu/gemm.h
@@ -0,0 +1,26 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+ #include "mlx/array.h"
+
+ namespace mlx::core {
+
+ template <typename T>
+ void matmul(
+     const T* a,
+     const T* b,
+     T* out,
+     bool a_transposed,
+     bool b_transposed,
+     size_t lda,
+     size_t ldb,
+     size_t ldc,
+     float alpha,
+     float beta,
+     size_t batch_size,
+     const Shape& a_shape,
+     const Strides& a_strides,
+     const Shape& b_shape,
+     const Strides& b_strides);
+
+ } // namespace mlx::core
mlx/include/mlx/backend/cpu/gemms/simd_gemm.h
@@ -0,0 +1,139 @@
+ // Copyright © 2025 Apple Inc.
+ #pragma once
+
+ #include "mlx/backend/cpu/simd/simd.h"
+
+ namespace mlx::core {
+
+ inline int ceildiv(int a, int b) {
+   return (a + b - 1) / b;
+ }
+
+ template <int block_size, typename T, typename AccT>
+ void load_block(
+     const T* in,
+     AccT* out,
+     int M,
+     int N,
+     int i,
+     int j,
+     bool transpose) {
+   if (transpose) {
+     for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+       for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+         out[jj * block_size + ii] =
+             in[(i * block_size + ii) * N + j * block_size + jj];
+       }
+     }
+   } else {
+     for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+       for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+         out[ii * block_size + jj] =
+             in[(i * block_size + ii) * N + j * block_size + jj];
+       }
+     }
+   }
+ }
+
+ template <typename T, typename AccT>
+ void simd_gemm(
+     const T* a,
+     const T* b,
+     T* c,
+     bool a_trans,
+     bool b_trans,
+     int M,
+     int N,
+     int K,
+     float alpha,
+     float beta) {
+   constexpr int block_size = 16;
+   constexpr int simd_size = simd::max_size<AccT>;
+   static_assert(
+       (block_size % simd_size) == 0,
+       "Block size must be divisible by SIMD size");
+
+   int last_k_block_size = K - block_size * (K / block_size);
+   int last_k_simd_block = (last_k_block_size / simd_size) * simd_size;
+   for (int i = 0; i < ceildiv(M, block_size); i++) {
+     for (int j = 0; j < ceildiv(N, block_size); j++) {
+       AccT c_block[block_size * block_size] = {0.0};
+       AccT a_block[block_size * block_size];
+       AccT b_block[block_size * block_size];
+
+       int k = 0;
+       for (; k < K / block_size; k++) {
+         // Load a and b blocks
+         if (a_trans) {
+           load_block<block_size>(a, a_block, K, M, k, i, true);
+         } else {
+           load_block<block_size>(a, a_block, M, K, i, k, false);
+         }
+         if (b_trans) {
+           load_block<block_size>(b, b_block, N, K, j, k, false);
+         } else {
+           load_block<block_size>(b, b_block, K, N, k, j, true);
+         }
+
+         // Multiply and accumulate
+         for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+           for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+             for (int kk = 0; kk < block_size; kk += simd_size) {
+               auto av =
+                   simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
+               auto bv =
+                   simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
+               c_block[ii * block_size + jj] += simd::sum(av * bv);
+             }
+           }
+         }
+       }
+       if (last_k_block_size) {
+         // Load a and b blocks
+         if (a_trans) {
+           load_block<block_size>(a, a_block, K, M, k, i, true);
+         } else {
+           load_block<block_size>(a, a_block, M, K, i, k, false);
+         }
+         if (b_trans) {
+           load_block<block_size>(b, b_block, N, K, j, k, false);
+         } else {
+           load_block<block_size>(b, b_block, K, N, k, j, true);
+         }
+
+         // Multiply and accumulate
+         for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+           for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+             int kk = 0;
+             for (; kk < last_k_simd_block; kk += simd_size) {
+               auto av =
+                   simd::load<AccT, simd_size>(a_block + ii * block_size + kk);
+               auto bv =
+                   simd::load<AccT, simd_size>(b_block + jj * block_size + kk);
+               c_block[ii * block_size + jj] += simd::sum(av * bv);
+             }
+             for (; kk < last_k_block_size; ++kk) {
+               c_block[ii * block_size + jj] +=
+                   a_block[ii * block_size + kk] * b_block[jj * block_size + kk];
+             }
+           }
+         }
+       }
+
+       // Store
+       for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) {
+         for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) {
+           auto c_idx = (i * block_size + ii) * N + j * block_size + jj;
+           if (beta != 0) {
+             c[c_idx] = static_cast<T>(
+                 alpha * c_block[ii * block_size + jj] + beta * c[c_idx]);
+           } else {
+             c[c_idx] = static_cast<T>(alpha * c_block[ii * block_size + jj]);
+           }
+         }
+       }
+     }
+   }
+ }
+
+ } // namespace mlx::core
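A standalone sketch (illustrative, not part of the wheel) of calling simd_gemm above on plain row-major float buffers, with T and AccT both float, no transposes, alpha = 1 and beta = 0; every entry of C comes out as K * 1 * 2 = 10.

    #include <cstdio>
    #include <vector>

    #include "mlx/backend/cpu/gemms/simd_gemm.h"

    int main() {
      const int M = 3, N = 4, K = 5;
      std::vector<float> A(M * K, 1.0f), B(K * N, 2.0f), C(M * N, 0.0f);
      // T = float is the storage type, AccT = float is the accumulator type.
      mlx::core::simd_gemm<float, float>(
          A.data(), B.data(), C.data(),
          /*a_trans=*/false, /*b_trans=*/false,
          M, N, K, /*alpha=*/1.0f, /*beta=*/0.0f);
      std::printf("C[0] = %.1f\n", C[0]); // prints 10.0
      return 0;
    }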
mlx/include/mlx/backend/cpu/jit_compiler.h
@@ -0,0 +1,20 @@
+ // Copyright © 2024 Apple Inc.
+ #pragma once
+
+ #include <filesystem>
+
+ namespace mlx::core {
+
+ class JitCompiler {
+  public:
+   // Build a shell command that compiles a source code file to a shared library.
+   static std::string build_command(
+       const std::filesystem::path& dir,
+       const std::string& source_file_name,
+       const std::string& shared_lib_name);
+
+   // Run a command and get its output.
+   static std::string exec(const std::string& cmd);
+ };
+
+ } // namespace mlx::core
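A hypothetical snippet (not in the wheel) of driving the two JitCompiler helpers above; the directory and the kernel.cpp / libkernel.so names are placeholders, and generating the source file is left out.

    #include <filesystem>
    #include <string>

    #include "mlx/backend/cpu/jit_compiler.h"

    // Compile an already generated kernel.cpp in |dir| into libkernel.so and
    // return whatever the compiler printed.
    std::string compile_generated_kernel(const std::filesystem::path& dir) {
      auto cmd =
          mlx::core::JitCompiler::build_command(dir, "kernel.cpp", "libkernel.so");
      return mlx::core::JitCompiler::exec(cmd);
    }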
mlx/include/mlx/backend/cpu/lapack.h
@@ -0,0 +1,80 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ #pragma once
+
+ #include <complex>
+ #define LAPACK_COMPLEX_CUSTOM
+ #define lapack_complex_float std::complex<float>
+ #define lapack_complex_double std::complex<double>
+ #define lapack_complex_float_real(z) ((z).real())
+ #define lapack_complex_float_imag(z) ((z).imag())
+ #define lapack_complex_double_real(z) ((z).real())
+ #define lapack_complex_double_imag(z) ((z).imag())
+
+ #ifdef MLX_USE_ACCELERATE
+ #include <Accelerate/Accelerate.h>
+ #else
+ #include <cblas.h>
+ #include <lapack.h>
+ #endif
+
+ #if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME)
+
+ // This is to work around a change in the function signatures of lapack >= 3.9.1
+ // where functions taking char* also include a strlen argument, see a similar
+ // change in OpenCV:
+ // https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57
+ #define MLX_LAPACK_FUNC(f) LAPACK_##f
+
+ #else
+
+ #define MLX_LAPACK_FUNC(f) f##_
+
+ #endif
+
+ #define INSTANTIATE_LAPACK_REAL(FUNC) \
+   template <typename T, typename... Args> \
+   void FUNC(Args... args) { \
+     if constexpr (std::is_same_v<T, float>) { \
+       MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
+     } else if constexpr (std::is_same_v<T, double>) { \
+       MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
+     } \
+   }
+
+ INSTANTIATE_LAPACK_REAL(geqrf)
+ INSTANTIATE_LAPACK_REAL(orgqr)
+ INSTANTIATE_LAPACK_REAL(syevd)
+ INSTANTIATE_LAPACK_REAL(potrf)
+ INSTANTIATE_LAPACK_REAL(getrf)
+ INSTANTIATE_LAPACK_REAL(getri)
+ INSTANTIATE_LAPACK_REAL(trtri)
+
+ #define INSTANTIATE_LAPACK_COMPLEX(FUNC) \
+   template <typename T, typename... Args> \
+   void FUNC(Args... args) { \
+     if constexpr (std::is_same_v<T, std::complex<float>>) { \
+       MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
+     } else if constexpr (std::is_same_v<T, std::complex<double>>) { \
+       MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
+     } \
+   }
+
+ INSTANTIATE_LAPACK_COMPLEX(heevd)
+
+ #define INSTANTIATE_LAPACK_ALL(FUNC) \
+   template <typename T, typename... Args> \
+   void FUNC(Args... args) { \
+     if constexpr (std::is_same_v<T, float>) { \
+       MLX_LAPACK_FUNC(s##FUNC)(std::forward<Args>(args)...); \
+     } else if constexpr (std::is_same_v<T, double>) { \
+       MLX_LAPACK_FUNC(d##FUNC)(std::forward<Args>(args)...); \
+     } else if constexpr (std::is_same_v<T, std::complex<float>>) { \
+       MLX_LAPACK_FUNC(c##FUNC)(std::forward<Args>(args)...); \
+     } else if constexpr (std::is_same_v<T, std::complex<double>>) { \
+       MLX_LAPACK_FUNC(z##FUNC)(std::forward<Args>(args)...); \
+     } \
+   }
+
+ INSTANTIATE_LAPACK_ALL(geev)
+ INSTANTIATE_LAPACK_ALL(gesdd)
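For reference, spelling out by hand what INSTANTIATE_LAPACK_REAL(geqrf) above expands to (this is not extra code in the wheel): a single templated wrapper whose T parameter picks the s- or d-prefixed LAPACK symbol through MLX_LAPACK_FUNC.

    template <typename T, typename... Args>
    void geqrf(Args... args) {
      if constexpr (std::is_same_v<T, float>) {
        MLX_LAPACK_FUNC(sgeqrf)(std::forward<Args>(args)...); // sgeqrf_ or LAPACK_sgeqrf
      } else if constexpr (std::is_same_v<T, double>) {
        MLX_LAPACK_FUNC(dgeqrf)(std::forward<Args>(args)...); // dgeqrf_ or LAPACK_dgeqrf
      }
    }

A caller can then write, for example, geqrf<float>(&m, &n, a, &lda, tau, work, &lwork, &info); with the usual Fortran-style pointer arguments.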
mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h
@@ -0,0 +1,56 @@
+ #pragma once
+
+ #include "mlx/backend/cpu/simd/base_simd.h"
+
+ #if MLX_SIMD_LIBRARY_VERSION < 6
+ #include "mlx/backend/cpu/simd/neon_fp16_simd.h"
+ #endif
+
+ namespace mlx::core::simd {
+
+ #if MLX_SIMD_LIBRARY_VERSION >= 6
+ constexpr int N = 8;
+ template <int N>
+ struct ScalarT<float16_t, N> {
+   using v = _Float16;
+ };
+ #endif
+
+ template <>
+ inline constexpr int max_size<float16_t> = N;
+
+ #define SIMD_FP16_DEFAULT_UNARY(op) \
+   template <> \
+   inline Simd<float16_t, N> op(Simd<float16_t, N> v) { \
+     Simd<float, N> in = v; \
+     return op(in); \
+   }
+
+ SIMD_FP16_DEFAULT_UNARY(acos)
+ SIMD_FP16_DEFAULT_UNARY(acosh)
+ SIMD_FP16_DEFAULT_UNARY(asin)
+ SIMD_FP16_DEFAULT_UNARY(asinh)
+ SIMD_FP16_DEFAULT_UNARY(atan)
+ SIMD_FP16_DEFAULT_UNARY(atanh)
+ SIMD_FP16_DEFAULT_UNARY(cosh)
+ SIMD_FP16_DEFAULT_UNARY(expm1)
+ SIMD_FP16_DEFAULT_UNARY(log)
+ SIMD_FP16_DEFAULT_UNARY(log2)
+ SIMD_FP16_DEFAULT_UNARY(log10)
+ SIMD_FP16_DEFAULT_UNARY(log1p)
+ SIMD_FP16_DEFAULT_UNARY(sinh)
+ SIMD_FP16_DEFAULT_UNARY(tan)
+ SIMD_FP16_DEFAULT_UNARY(tanh)
+
+ #define SIMD_FP16_DEFAULT_BINARY(op) \
+   template <> \
+   inline Simd<float16_t, N> op(Simd<float16_t, N> x, Simd<float16_t, N> y) { \
+     Simd<float, N> a = x; \
+     Simd<float, N> b = y; \
+     return op(a, b); \
+   }
+ SIMD_FP16_DEFAULT_BINARY(atan2)
+ SIMD_FP16_DEFAULT_BINARY(remainder)
+ SIMD_FP16_DEFAULT_BINARY(pow)
+
+ } // namespace mlx::core::simd
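A loosely sketched example (not in the wheel) of what the fall-back macros above buy: half-precision lanes get their transcendental ops by promotion to Simd<float, N> and implicit narrowing on return. It assumes generic simd::load and simd::store helpers in base_simd.h, which this diff does not show.

    #include "mlx/backend/cpu/simd/simd.h"

    namespace mlx::core::simd {

    // tanh over N half-precision lanes; computed internally in float.
    void tanh_fp16_lanes(const float16_t* src, float16_t* dst) {
      auto x = load<float16_t, N>(src); // assumed helper from base_simd.h
      store(dst, tanh(x));              // assumed helper from base_simd.h
    }

    } // namespace mlx::core::simd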