mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,17 @@
1
+ #pragma once
2
+
3
+ #include "mlx/array.h"
4
+ #include "mlx/primitives.h"
5
+
6
+ namespace mlx::core {
7
+
8
+ void scan_gpu_inplace(
9
+ array in,
10
+ array& out,
11
+ Scan::ReduceType reduce_type,
12
+ int axis,
13
+ bool reverse,
14
+ bool inclusive,
15
+ const Stream& s);
16
+
17
+ } // namespace mlx::core
@@ -0,0 +1,21 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+
7
+ namespace mlx::core {
8
+
9
+ void ternary_op_gpu(
10
+ const std::vector<array>& inputs,
11
+ array& out,
12
+ const char* op,
13
+ const Stream& s);
14
+
15
+ void ternary_op_gpu_inplace(
16
+ const std::vector<array>& inputs,
17
+ array& out,
18
+ const char* op,
19
+ const Stream& s);
20
+
21
+ } // namespace mlx::core
@@ -0,0 +1,21 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+
7
+ namespace mlx::core {
8
+
9
+ void unary_op_gpu(
10
+ const std::vector<array>& inputs,
11
+ array& out,
12
+ const char* op,
13
+ const Stream& s);
14
+
15
+ void unary_op_gpu_inplace(
16
+ const std::vector<array>& inputs,
17
+ array& out,
18
+ const char* op,
19
+ const Stream& s);
20
+
21
+ } // namespace mlx::core
@@ -0,0 +1,84 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <type_traits>
6
+
7
+ #include "mlx/array.h"
8
+ #include "mlx/backend/metal/device.h"
9
+ #include "mlx/primitives.h"
10
+
11
+ namespace mlx::core {
12
+
13
+ std::string type_to_name(const Dtype& t);
14
+ std::string type_to_name(const array& a);
15
+
16
+ // Compute the grid and block dimensions, check backend/common/utils.h for docs.
17
+ MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);
18
+ MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides);
19
+ MTL::Size
20
+ get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor);
21
+
22
+ inline NS::String* make_string(std::ostringstream& os) {
23
+ std::string string = os.str();
24
+ return NS::String::string(string.c_str(), NS::UTF8StringEncoding);
25
+ }
26
+
27
+ inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) {
28
+ #ifdef MLX_METAL_DEBUG
29
+ std::ostringstream label;
30
+ label << "Stream " << index;
31
+ queue->setLabel(make_string(label));
32
+ #endif
33
+ }
34
+
35
+ inline void debug_set_primitive_buffer_label(
36
+ MTL::CommandBuffer* command_buffer,
37
+ Primitive& primitive) {
38
+ #ifdef MLX_METAL_DEBUG
39
+ std::ostringstream label;
40
+ if (auto cbuf_label = command_buffer->label(); cbuf_label) {
41
+ label << cbuf_label->utf8String();
42
+ }
43
+ label << primitive.name();
44
+ command_buffer->setLabel(make_string(label));
45
+ #endif
46
+ }
47
+
48
// True for arithmetic types other than the character types char, signed char,
// unsigned char and wchar_t.
template <typename T>
constexpr bool is_numeric_except_char = std::is_arithmetic_v<T> &&
    !(std::is_same_v<T, char> || std::is_same_v<T, signed char> ||
      std::is_same_v<T, unsigned char> || std::is_same_v<T, wchar_t>);

// Append a single value to `acc`. Numeric values are formatted with
// std::to_string; everything else is appended with std::string::operator+=.
template <typename T>
void concatenate(std::string& acc, T first) {
  if constexpr (!is_numeric_except_char<T>) {
    acc += first;
  } else {
    acc += std::to_string(first);
  }
}

// Append `first` and then every remaining argument, left to right, to `acc`.
template <typename T, typename... Args>
void concatenate(std::string& acc, T first, Args... rest) {
  if constexpr (!is_numeric_except_char<T>) {
    acc += first;
  } else {
    acc += std::to_string(first);
  }
  // Recurse on the tail; the single-value overload ends the recursion.
  concatenate(acc, rest...);
}
71
+
72
+ inline int get_work_per_thread(Dtype dtype) {
73
+ return std::max(1, 8 / dtype.size());
74
+ }
75
+ inline int get_work_per_thread(Dtype dtype, size_t size) {
76
+ constexpr size_t wpt_threshold = 1 << 16;
77
+ return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size());
78
+ }
79
+
80
// Integer division of `n` by `m`, rounded up.
//
// Computed from the quotient and remainder rather than `(n + m - 1) / m`, so
// the result stays correct when `n` is close to the maximum size_t value
// (the addition would otherwise wrap around). `m` must be non-zero.
inline size_t ceildiv(size_t n, size_t m) {
  return n / m + (n % m != 0 ? 1 : 0);
}
83
+
84
+ } // namespace mlx::core
@@ -0,0 +1,16 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <sys/sysctl.h>
6
+
7
+ namespace {
8
+
9
// Query the "hw.memsize" sysctl and return its value.
//
// Returns 0 when the query fails, so callers see a well-defined value instead
// of whatever the failed call left in the output buffer.
size_t get_memory_size() {
  size_t memsize = 0;
  size_t length = sizeof(memsize);
  // sysctlbyname returns 0 on success; no new value is being set, hence the
  // null "new value" pointer and zero length.
  if (sysctlbyname("hw.memsize", &memsize, &length, nullptr, 0) != 0) {
    return 0;
  }
  return memsize;
}
15
+
16
+ } // namespace
@@ -0,0 +1,22 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <sys/sysinfo.h>
6
+
7
+ namespace {
8
+
9
// Return sysinfo's `totalram` multiplied by its `mem_unit`, or 0 when the
// sysinfo() call fails.
size_t get_memory_size() {
  struct sysinfo stats;
  const bool query_ok = sysinfo(&stats) == 0;
  if (!query_ok) {
    return 0;
  }

  // Widen to size_t before scaling so the product is computed in size_t.
  const size_t units = stats.totalram;
  return units * stats.mem_unit;
}
21
+
22
+ } // namespace
@@ -0,0 +1,44 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+
7
+ namespace mlx::core {
8
+
9
// Modes accepted by set_compile_mode(). The enumerator order is part of the
// interface and must not change.
enum class CompileMode {
  disabled,
  no_simplify,
  no_fuse,
  enabled,
};
10
+
11
+ /** Compile takes a function and returns a compiled function. */
12
+ std::function<std::vector<array>(const std::vector<array>&)> compile(
13
+ std::function<std::vector<array>(const std::vector<array>&)> fun,
14
+ bool shapeless = false);
15
+
16
+ std::function<std::vector<array>(const std::vector<array>&)> compile(
17
+ std::vector<array> (*fun)(const std::vector<array>&),
18
+ bool shapeless = false);
19
+
20
+ // Convert capture-less lambdas to function pointers.
21
+ template <
22
+ typename F,
23
+ typename = std::enable_if_t<
24
+ std::is_convertible_v<F, decltype(+std::declval<F>())>>>
25
+ std::function<std::vector<array>(const std::vector<array>&)> compile(
26
+ F&& f,
27
+ bool shapeless = false) {
28
+ return compile(+f, shapeless);
29
+ }
30
+
31
+ /** Globally disable compilation.
32
+ * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
33
+ * be used to disable compilation.
34
+ */
35
+ void disable_compile();
36
+
37
+ /** Globally enable compilation.
38
+ * This will override the environment variable ``MLX_DISABLE_COMPILE``.
39
+ */
40
+ void enable_compile();
41
+
42
+ /** Set the compiler mode to the given value. */
43
+ void set_compile_mode(CompileMode mode);
44
+ } // namespace mlx::core
@@ -0,0 +1,69 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <unordered_map>
6
+
7
+ #include "mlx/array.h"
8
+
9
+ namespace mlx::core::detail {
10
+
11
+ using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
12
+ using ArrayFnWithExtra =
13
+ std::function<ArraysAndExtra(const std::vector<array>&)>;
14
+
15
+ // This is not part of the general C++ API as calling with a bad id is a bad
16
+ // idea.
17
+ std::function<std::vector<array>(const std::vector<array>&)> compile(
18
+ std::function<std::vector<array>(const std::vector<array>&)> fun,
19
+ std::uintptr_t fun_id,
20
+ bool shapeless = false,
21
+ std::vector<uint64_t> constants = {});
22
+
23
+ ArrayFnWithExtra compile(
24
+ ArrayFnWithExtra fun,
25
+ std::uintptr_t fun_id,
26
+ bool shapeless,
27
+ std::vector<uint64_t> constants);
28
+
29
+ // Erase cached compile functions
30
+ void compile_erase(std::uintptr_t fun_id);
31
+
32
+ // Clear the compiler cache causing a recompilation of all compiled functions
33
+ // when called again.
34
+ void compile_clear_cache();
35
+
36
+ bool compile_available_for_device(const Device& device);
37
+
38
+ std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
39
+ compile_trace(
40
+ const ArrayFnWithExtra& fun,
41
+ const std::vector<array>& inputs,
42
+ bool shapeless);
43
+
44
+ using ParentsMap =
45
+ std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
46
+
47
+ // Traverses the graph to build a tape and a map of array ids to their parents
48
+ std::pair<std::vector<array>, ParentsMap> compile_dfs(
49
+ const std::vector<array>& inputs,
50
+ std::vector<array>& outputs,
51
+ const std::vector<array>& original_inputs);
52
+
53
+ // Simplify the tape.
54
+ void compile_simplify(
55
+ std::vector<array>& tape,
56
+ ParentsMap& parents_map,
57
+ std::vector<array>& outputs,
58
+ int passes);
59
+
60
+ std::vector<array> compile_replace(
61
+ const std::vector<array>& tape,
62
+ const std::vector<array>& trace_inputs,
63
+ const std::vector<array>& trace_outputs,
64
+ const std::vector<array>& inputs,
65
+ bool shapeless);
66
+
67
+ void compile_validate_shapeless(const std::vector<array>& tape);
68
+
69
+ } // namespace mlx::core::detail
@@ -0,0 +1,31 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ namespace mlx::core {
6
+
7
// A device is identified by its type (cpu or gpu) and an integer index.
struct Device {
  enum class DeviceType {
    cpu,
    gpu,
  };

  // Shorthands so callers can write Device::cpu / Device::gpu.
  static constexpr DeviceType cpu = DeviceType::cpu;
  static constexpr DeviceType gpu = DeviceType::gpu;

  // Not explicit: a DeviceType converts implicitly to a Device with index 0.
  Device(DeviceType type, int index = 0) : type{type}, index{index} {}

  DeviceType type;
  int index;
};
21
+
22
+ const Device& default_device();
23
+
24
+ void set_default_device(const Device& d);
25
+
26
+ bool operator==(const Device& lhs, const Device& rhs);
27
+ bool operator!=(const Device& lhs, const Device& rhs);
28
+
29
+ bool is_available(const Device& d);
30
+
31
+ } // namespace mlx::core
@@ -0,0 +1,60 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <memory>
6
+
7
+ #include "mlx/array.h"
8
+ #include "mlx/utils.h"
9
+
10
+ namespace mlx::core::distributed {
11
+
12
+ // Forward declaration of the base group implementation.
13
// Only the name is needed here: Group stores a shared_ptr to the
// implementation, which works with an incomplete type.
namespace detail {
class GroupImpl;
} // namespace detail
16
+
17
+ /* Check if a communication backend is available */
18
+ bool is_available();
19
+ bool is_available(const std::string& bk);
20
+
21
+ /**
22
+ * A distributed::Group represents a group of independent mlx processes that
23
+ * can communicate. We must also be able to create sub-groups from a group in
24
+ * order to define more granular communication.
25
+ */
26
+ struct Group {
27
+ Group(std::shared_ptr<detail::GroupImpl> group) : group_(std::move(group)) {}
28
+
29
+ int rank() const;
30
+ int size() const;
31
+
32
+ /**
33
+ * Split the group according to the provided color. Namely processes that use
34
+ * the same color will go to the same group.
35
+ *
36
+ * The key defines the rank of the processes in the new group. The smaller
37
+ * the key the smaller the rank. If the provided key is negative, then the
38
+ * rank in the current group is used.
39
+ */
40
+ Group split(int color, int key = -1) const;
41
+
42
+ const std::shared_ptr<detail::GroupImpl>& raw_group() const {
43
+ return group_;
44
+ }
45
+
46
+ private:
47
+ std::shared_ptr<detail::GroupImpl> group_{nullptr};
48
+ };
49
+
50
+ /**
51
+ * Initialize the distributed backend and return the group containing all
52
+ * discoverable processes.
53
+ *
54
+ * If strict is true then throw an error if we couldn't initialize the
55
+ * distributed subsystem. Otherwise simply return a singleton group which will
56
+ * render communication operations as no-op.
57
+ */
58
+ Group init(bool strict = false, const std::string& bk = "any");
59
+
60
+ } // namespace mlx::core::distributed
@@ -0,0 +1,59 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/distributed/distributed.h"
6
+
7
+ namespace mlx::core::distributed::detail {
8
+
9
+ /**
10
+ * Abstract base class of a distributed group implementation.
11
+ */
12
+ class GroupImpl {
13
+ public:
14
+ virtual ~GroupImpl() {}
15
+
16
+ // Choose the stream this communication group can operate on
17
+ virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
18
+
19
+ // Group operations
20
+ virtual int rank() = 0;
21
+ virtual int size() = 0;
22
+ virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
23
+
24
+ // Actual communication operations
25
+ virtual void all_sum(const array& input, array& output, Stream stream) = 0;
26
+ virtual void all_gather(const array& input, array& output, Stream stream) = 0;
27
+ virtual void send(const array& input, int dst, Stream stream) = 0;
28
+ virtual void recv(array& out, int src, Stream stream) = 0;
29
+ virtual void all_max(const array& input, array& output, Stream stream) = 0;
30
+ virtual void all_min(const array& input, array& output, Stream stream) = 0;
31
+ virtual void
32
+ sum_scatter(const array& input, array& output, Stream stream) = 0;
33
+ };
34
+
35
+ /* Define the MLX stream that the communication should happen in. */
36
+ Stream communication_stream(Group group, StreamOrDevice s = {});
37
+
38
+ /* Perform an all reduce sum operation */
39
+ void all_sum(Group group, const array& input, array& output, Stream stream);
40
+
41
+ /* Perform an all gather operation */
42
+ void all_gather(Group group, const array& input, array& output, Stream stream);
43
+
44
+ /** Send an array to the dst rank */
45
+ void send(Group group, const array& input, int dst, Stream stream);
46
+
47
+ /** Recv an array from the src rank */
48
+ void recv(Group group, array& out, int src, Stream stream);
49
+
50
+ /** Max reduction */
51
+ void all_max(Group group, const array& input, array& output, Stream stream);
52
+
53
+ /** Min reduction */
54
+ void all_min(Group group, const array& input, array& output, Stream stream);
55
+
56
+ /** Reduce scatter with sum operation */
57
+ void sum_scatter(Group group, const array& input, array& output, Stream stream);
58
+
59
+ } // namespace mlx::core::distributed::detail
@@ -0,0 +1,12 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/distributed/distributed.h"
4
+
5
+ namespace mlx::core::distributed::jaccl {
6
+
7
+ using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8
+
9
+ bool is_available();
10
+ std::shared_ptr<GroupImpl> init(bool strict = false);
11
+
12
+ } // namespace mlx::core::distributed::jaccl
@@ -0,0 +1,12 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #include "mlx/distributed/distributed.h"
4
+
5
+ namespace mlx::core::distributed::mpi {
6
+
7
+ using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8
+
9
+ bool is_available();
10
+ std::shared_ptr<GroupImpl> init(bool strict = false);
11
+
12
+ } // namespace mlx::core::distributed::mpi
@@ -0,0 +1,28 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ // Constants
4
+
5
// Guard against repeated inclusion: the status struct below may only be
// defined once per translation unit.
#pragma once

#include <cstddef>

// Negative values are parenthesized so the macros expand safely inside any
// surrounding expression.
#define MPI_SUCCESS 0
#define MPI_ANY_SOURCE (-1)
#define MPI_ANY_TAG (-1)
#define MPI_IN_PLACE ((void*)1)
#define MPI_MAX_LIBRARY_VERSION_STRING 256

// Define all the types that we use so that we don't include <mpi.h> which
// causes linker errors on some platforms.
//
// NOTE: We define everything for openmpi.

// Handles are declared as opaque pointers.
typedef void* MPI_Comm;
typedef void* MPI_Datatype;
typedef void* MPI_Op;

// Signature of a user-supplied reduction function.
typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*);

// Status record; the field order and types must stay exactly as declared
// because the struct is meant to match the openmpi definition.
typedef struct ompi_status_public_t {
  int MPI_SOURCE;
  int MPI_TAG;
  int MPI_ERROR;
  int _cancelled;
  std::size_t _ucount;
} MPI_Status;
@@ -0,0 +1,12 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #include "mlx/distributed/distributed.h"
4
+
5
+ namespace mlx::core::distributed::nccl {
6
+
7
+ using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8
+
9
+ bool is_available();
10
+ std::shared_ptr<GroupImpl> init(bool strict = false);
11
+
12
+ } // namespace mlx::core::distributed::nccl
@@ -0,0 +1,56 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <optional>
6
+
7
+ #include "mlx/distributed/distributed.h"
8
+ #include "mlx/utils.h"
9
+
10
+ namespace mlx::core::distributed {
11
+
12
+ array all_sum(
13
+ const array& x,
14
+ std::optional<Group> group = std::nullopt,
15
+ StreamOrDevice s = {});
16
+
17
+ array all_gather(
18
+ const array& x,
19
+ std::optional<Group> group = std::nullopt,
20
+ StreamOrDevice S = {});
21
+
22
+ array send(
23
+ const array& x,
24
+ int dst,
25
+ std::optional<Group> group = std::nullopt,
26
+ StreamOrDevice s = {});
27
+
28
+ array recv(
29
+ Shape shape,
30
+ Dtype dtype,
31
+ int src,
32
+ std::optional<Group> group = std::nullopt,
33
+ StreamOrDevice s = {});
34
+
35
+ array recv_like(
36
+ const array& x,
37
+ int src,
38
+ std::optional<Group> group = std::nullopt,
39
+ StreamOrDevice s = {});
40
+
41
+ array all_max(
42
+ const array& x,
43
+ std::optional<Group> group = std::nullopt,
44
+ StreamOrDevice s = {});
45
+
46
+ array all_min(
47
+ const array& x,
48
+ std::optional<Group> group = std::nullopt,
49
+ StreamOrDevice s = {});
50
+
51
+ array sum_scatter(
52
+ const array& x,
53
+ std::optional<Group> group = std::nullopt,
54
+ StreamOrDevice s = {});
55
+
56
+ } // namespace mlx::core::distributed