mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in those registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/common/ternary.h
@@ -0,0 +1,85 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+ #include "mlx/allocator.h"
+ #include "mlx/array.h"
+ #include "mlx/backend/common/utils.h"
+
+ namespace mlx::core {
+
+ // TODO: Add support for more combinations of input types.
+ enum class TernaryOpType {
+   ScalarScalarScalar,
+   VectorVectorVector,
+   VectorVectorScalar,
+   VectorScalarVector,
+   General,
+ };
+
+ inline TernaryOpType
+ get_ternary_op_type(const array& a, const array& b, const array& c) {
+   TernaryOpType topt;
+   if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
+     topt = TernaryOpType::ScalarScalarScalar;
+   } else if (
+       (a.flags().row_contiguous && b.flags().row_contiguous &&
+        c.flags().row_contiguous) ||
+       (a.flags().col_contiguous && b.flags().col_contiguous &&
+        c.flags().col_contiguous)) {
+     topt = TernaryOpType::VectorVectorVector;
+   } else if (
+       b.data_size() == 1 && a.flags().row_contiguous &&
+       c.flags().row_contiguous) {
+     topt = TernaryOpType::VectorScalarVector;
+   } else if (
+       c.data_size() == 1 && a.flags().row_contiguous &&
+       b.flags().row_contiguous) {
+     topt = TernaryOpType::VectorVectorScalar;
+   } else {
+     topt = TernaryOpType::General;
+   }
+   return topt;
+ }
+
+ inline void set_ternary_op_output_data(
+     const array& a,
+     const array& b,
+     const array& c,
+     array& out,
+     TernaryOpType topt,
+     std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
+   auto maybe_donate = [&out](const array& x) {
+     if (is_donatable(x, out)) {
+       out.copy_shared_buffer(x);
+       return true;
+     }
+     return false;
+   };
+
+   switch (topt) {
+     case TernaryOpType::ScalarScalarScalar:
+       out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
+       break;
+     case TernaryOpType::VectorVectorVector:
+       if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
+         out.set_data(
+             mallocfn(out.itemsize() * b.data_size()),
+             b.data_size(),
+             b.strides(),
+             b.flags());
+       }
+       break;
+     case TernaryOpType::VectorVectorScalar:
+     case TernaryOpType::VectorScalarVector:
+     case TernaryOpType::General:
+       // Try to donate an input which is row_contiguous
+       if (!((a.flags().row_contiguous && maybe_donate(a)) ||
+             (b.flags().row_contiguous && maybe_donate(b)) ||
+             (c.flags().row_contiguous && maybe_donate(c)))) {
+         out.set_data(mallocfn(out.nbytes()));
+       }
+       break;
+   }
+ }
+
+ } // namespace mlx::core
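For orientation, here is a minimal sketch of how a CPU ternary kernel (for example, a select/where primitive) might consume the two helpers above. The function name select_eval_cpu_sketch and the elided element-wise kernel are illustrative assumptions, not MLX code.

    // Hypothetical sketch: classify the inputs, then allocate or donate the
    // output buffer before an element-wise kernel runs.
    #include "mlx/backend/common/ternary.h"

    namespace mlx::core {

    void select_eval_cpu_sketch(
        const array& condition,
        const array& a,
        const array& b,
        array& out) {
      // All-scalar, all-contiguous, mixed, or fully general layout?
      TernaryOpType topt = get_ternary_op_type(condition, a, b);
      // Reuse a donatable input buffer when possible, otherwise allocate.
      set_ternary_op_output_data(condition, a, b, out, topt);
      // ... an element-wise kernel specialized on topt would run here ...
    }

    } // namespace mlx::core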
mlx/include/mlx/backend/common/unary.h
@@ -0,0 +1,29 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/allocator.h"
+ #include "mlx/backend/common/utils.h"
+
+ namespace mlx::core {
+
+ inline void set_unary_output_data(
+     const array& in,
+     array& out,
+     std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
+   if (in.flags().contiguous) {
+     if (is_donatable(in, out)) {
+       out.copy_shared_buffer(in);
+     } else {
+       out.set_data(
+           mallocfn(in.data_size() * out.itemsize()),
+           in.data_size(),
+           in.strides(),
+           in.flags());
+     }
+   } else {
+     out.set_data(mallocfn(out.nbytes()));
+   }
+ }
+
+ } // namespace mlx::core
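Likewise, a hedged sketch of the allocation step in a CPU unary op; abs_eval_cpu_sketch is a hypothetical name and the per-element loop is omitted.

    // Hypothetical sketch: set up the output of a unary op. Contiguous inputs
    // keep their data_size/strides/flags, and donatable buffers are reused via
    // copy_shared_buffer inside the helper.
    #include "mlx/backend/common/unary.h"

    namespace mlx::core {

    void abs_eval_cpu_sketch(const array& in, array& out) {
      set_unary_output_data(in, out);
      // ... a loop reading in.data<T>() and writing out.data<T>() would follow ...
    }

    } // namespace mlx::core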
mlx/include/mlx/backend/common/utils.h
@@ -0,0 +1,205 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ #pragma once
+
+ #include <filesystem>
+ #include <tuple>
+ #include <vector>
+
+ #include "mlx/array.h"
+
+ namespace mlx::core {
+
+ // Return the directory that contains current shared library.
+ std::filesystem::path current_binary_dir();
+
+ inline int64_t
+ elem_to_loc(int elem, const Shape& shape, const Strides& strides) {
+   int64_t loc = 0;
+   for (int i = shape.size() - 1; i >= 0; --i) {
+     auto q_and_r = ldiv(elem, shape[i]);
+     loc += q_and_r.rem * strides[i];
+     elem = q_and_r.quot;
+   }
+   return loc;
+ }
+
+ inline int64_t elem_to_loc(int elem, const array& a) {
+   if (a.flags().row_contiguous) {
+     return elem;
+   }
+   return elem_to_loc(elem, a.shape(), a.strides());
+ }
+
+ inline Strides make_contiguous_strides(const Shape& shape) {
+   Strides strides(shape.size(), 1);
+   for (int i = shape.size() - 1; i > 0; i--) {
+     strides[i - 1] = strides[i] * shape[i];
+   }
+   return strides;
+ }
+
+ // Collapse dims that are contiguous to possibly route to a better kernel
+ // e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1})
+ // should return {{2, 4}, {{1, 2}}}.
+ //
+ // When multiple arrays are passed they should all have the same shape. The
+ // collapsed axes are also the same so one shape is returned.
+ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
+     const Shape& shape,
+     const std::vector<Strides>& strides,
+     int64_t size_cap = std::numeric_limits<int32_t>::max());
+
+ inline std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
+     const std::vector<array>& xs,
+     size_t size_cap = std::numeric_limits<int32_t>::max()) {
+   std::vector<Strides> strides;
+   for (auto& x : xs) {
+     strides.emplace_back(x.strides());
+   }
+   return collapse_contiguous_dims(xs[0].shape(), strides, size_cap);
+ }
+
+ template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
+ inline auto collapse_contiguous_dims(Arrays&&... xs) {
+   return collapse_contiguous_dims(
+       std::vector<array>{std::forward<Arrays>(xs)...});
+ }
+
+ // The single array version of the above.
+ std::pair<Shape, Strides> collapse_contiguous_dims(
+     const Shape& shape,
+     const Strides& strides,
+     int64_t size_cap = std::numeric_limits<int32_t>::max());
+ std::pair<Shape, Strides> collapse_contiguous_dims(
+     const array& a,
+     int64_t size_cap = std::numeric_limits<int32_t>::max());
+
+ // Compute the thread block dimensions which fit the given
+ // input dimensions.
+ // - The thread block dimensions will be powers of two
+ // - The thread block size will be less than 2^pow2
+ using Dims = std::tuple<uint32_t, uint32_t, uint32_t>;
+ Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10);
+
+ // Computes a 2D grid where each element is < UINT_MAX
+ // Assumes:
+ // - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2
+ // - shape and strides correspond to a contiguous (no holes) but
+ //   possibly broadcasted array
+ Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides);
+
+ // Same as above but we do an implicit division with divisor.
+ // Basically, equivalent to factorizing
+ // Prod(s \forall s in shape if strides[s] > 0) / divisor.
+ Dims get_2d_grid_dims_common(
+     const Shape& shape,
+     const Strides& strides,
+     size_t divisor);
+
+ // Get both the block and a grid of blocks that covers dim0, dim1 and dim2.
+ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2);
+
+ struct ContiguousIterator {
+   inline void step() {
+     int dims = shape_.size();
+     if (dims == 0) {
+       return;
+     }
+     int i = dims - 1;
+     while (pos_[i] == (shape_[i] - 1) && i > 0) {
+       pos_[i] = 0;
+       loc -= (shape_[i] - 1) * strides_[i];
+       i--;
+     }
+     pos_[i]++;
+     loc += strides_[i];
+   }
+
+   void seek(int64_t n) {
+     loc = 0;
+     for (int i = shape_.size() - 1; i >= 0; --i) {
+       auto q_and_r = ldiv(n, shape_[i]);
+       loc += q_and_r.rem * strides_[i];
+       pos_[i] = q_and_r.rem;
+       n = q_and_r.quot;
+     }
+   }
+
+   void reset() {
+     loc = 0;
+     std::fill(pos_.begin(), pos_.end(), 0);
+   }
+
+   ContiguousIterator() {};
+
+   explicit ContiguousIterator(const array& a)
+       : shape_(a.shape()), strides_(a.strides()) {
+     if (!shape_.empty()) {
+       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
+       pos_ = Shape(shape_.size(), 0);
+     }
+   }
+
+   explicit ContiguousIterator(
+       const Shape& shape,
+       const Strides& strides,
+       int dims)
+       : shape_(shape.begin(), shape.begin() + dims),
+         strides_(strides.begin(), strides.begin() + dims) {
+     if (!shape_.empty()) {
+       std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_);
+       pos_ = Shape(shape_.size(), 0);
+     }
+   }
+
+   int64_t loc{0};
+
+  private:
+   Shape shape_;
+   Strides strides_;
+   Shape pos_;
+ };
+
+ inline auto check_contiguity(const Shape& shape, const Strides& strides) {
+   size_t no_broadcast_data_size = 1;
+   int64_t f_stride = 1;
+   int64_t b_stride = 1;
+   bool is_row_contiguous = true;
+   bool is_col_contiguous = true;
+
+   for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) {
+     is_col_contiguous &= strides[i] == f_stride || shape[i] == 1;
+     is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1;
+     f_stride *= shape[i];
+     b_stride *= shape[ri];
+     if (strides[i] > 0) {
+       no_broadcast_data_size *= shape[i];
+     }
+   }
+
+   return std::make_tuple(
+       no_broadcast_data_size, is_row_contiguous, is_col_contiguous);
+ }
+
+ inline bool is_donatable(const array& in, const array& out) {
+   constexpr size_t donation_extra = 16384;
+
+   return in.is_donatable() && in.itemsize() == out.itemsize() &&
+       in.buffer_size() <= out.nbytes() + donation_extra;
+ }
+
+ std::pair<bool, Strides> prepare_reshape(const array& in, const array& out);
+
+ void shared_buffer_reshape(
+     const array& in,
+     const Strides& out_strides,
+     array& out);
+
+ template <typename T>
+ inline SmallVector<T> remove_index(SmallVector<T> vec, size_t index) {
+   vec.erase(std::next(vec.begin(), index));
+   return vec;
+ }
+
+ } // namespace mlx::core
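To make the iterator semantics concrete, the following usage sketch (not MLX code) packs a possibly strided float32 array into a contiguous buffer; it relies only on the members shown above plus array::size() and array::data<T>().

    // Hypothetical sketch: walk an arbitrarily strided array in logical
    // row-major order; it.loc is the strided offset of the i-th element.
    #include <vector>

    #include "mlx/backend/common/utils.h"

    namespace mlx::core {

    std::vector<float> pack_to_contiguous_sketch(const array& a) {
      std::vector<float> packed(a.size());
      ContiguousIterator it(a); // collapses contiguous dims internally
      const float* src = a.data<float>();
      for (size_t i = 0; i < a.size(); ++i) {
        packed[i] = src[it.loc];
        it.step();
      }
      return packed;
    }

    } // namespace mlx::core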
mlx/include/mlx/backend/cpu/arange.h
@@ -0,0 +1,28 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/array.h"
+ #include "mlx/backend/cpu/encoder.h"
+
+ namespace mlx::core {
+
+ namespace {
+
+ template <typename T>
+ void arange(T start, T next, array& out, size_t size, Stream stream) {
+   auto ptr = out.data<T>();
+   auto step_size = next - start;
+   auto& encoder = cpu::get_command_encoder(stream);
+   encoder.set_output_array(out);
+   encoder.dispatch([ptr, start, step_size, size]() mutable {
+     for (int i = 0; i < size; ++i) {
+       ptr[i] = start;
+       start += step_size;
+     }
+   });
+ }
+
+ } // namespace
+
+ } // namespace mlx::core
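A hedged sketch of how a primitive's CPU path might drive the templated helper above: allocate the output, then dispatch on dtype. The wrapper name, the if/else dispatch, and the restriction to two dtypes are assumptions for illustration.

    // Hypothetical sketch: allocate the output, then call the anonymous-namespace
    // arange<T> helper, which enqueues the fill on the stream's CPU encoder.
    #include "mlx/allocator.h"
    #include "mlx/backend/cpu/arange.h"

    namespace mlx::core {

    void arange_eval_cpu_sketch(double start, double step, array& out, Stream s) {
      out.set_data(allocator::malloc(out.nbytes()));
      if (out.dtype() == float32) {
        arange<float>(start, start + step, out, out.size(), s);
      } else if (out.dtype() == int32) {
        arange<int32_t>(start, start + step, out, out.size(), s);
      }
      // Remaining dtypes are omitted in this sketch.
    }

    } // namespace mlx::core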
mlx/include/mlx/backend/cpu/available.h
@@ -0,0 +1,9 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ namespace mlx::core::cpu {
+
+ bool is_available();
+
+ } // namespace mlx::core::cpu
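A minimal usage sketch for the declaration above; the expectation that the CPU-only wheel reports true is an assumption, not something the header states.

    // Hypothetical sketch: query CPU backend availability at runtime.
    #include <cstdio>

    #include "mlx/backend/cpu/available.h"

    int main() {
      std::printf(
          "MLX CPU backend available: %s\n",
          mlx::core::cpu::is_available() ? "yes" : "no");
      return 0;
    }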