mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,156 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/distributed/distributed.h"
6
+ #include "mlx/distributed/distributed_impl.h"
7
+ #include "mlx/primitives.h"
8
+
9
+ namespace mlx::core::distributed {
10
+
11
+ class DistPrimitive : public Primitive {
12
+ public:
13
+ DistPrimitive(Stream stream, Group group)
14
+ : Primitive(stream), group_(group) {}
15
+
16
+ const Group& group() const {
17
+ return group_;
18
+ }
19
+
20
+ private:
21
+ Group group_;
22
+ };
23
+
24
+ class AllReduce : public DistPrimitive {
25
+ public:
26
+ enum ReduceType { And, Or, Sum, Prod, Min, Max };
27
+
28
+ AllReduce(Stream stream, Group group, ReduceType reduce_type)
29
+ : DistPrimitive(stream, group), reduce_type_(reduce_type) {}
30
+
31
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
32
+ override;
33
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
34
+ override;
35
+ std::pair<std::vector<array>, std::vector<int>> vmap(
36
+ const std::vector<array>& inputs,
37
+ const std::vector<int>& axes) override;
38
+ std::vector<array> jvp(
39
+ const std::vector<array>& primals,
40
+ const std::vector<array>& tangents,
41
+ const std::vector<int>& argnums) override;
42
+ std::vector<array> vjp(
43
+ const std::vector<array>& primals,
44
+ const std::vector<array>& cotangents,
45
+ const std::vector<int>& argnums,
46
+ const std::vector<array>& outputs) override;
47
+
48
+ const char* name() const override {
49
+ switch (reduce_type_) {
50
+ case And:
51
+ return "And AllReduce";
52
+ case Or:
53
+ return "Or AllReduce";
54
+ case Sum:
55
+ return "Sum AllReduce";
56
+ case Prod:
57
+ return "Prod AllReduce";
58
+ case Min:
59
+ return "Min AllReduce";
60
+ case Max:
61
+ return "Max AllReduce";
62
+ }
63
+ return "<unknwon AllReduce>";
64
+ }
65
+
66
+ private:
67
+ ReduceType reduce_type_;
68
+ };
69
+
70
+ class AllGather : public DistPrimitive {
71
+ public:
72
+ AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {}
73
+
74
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
75
+ override;
76
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
77
+ override;
78
+
79
+ std::pair<std::vector<array>, std::vector<int>> vmap(
80
+ const std::vector<array>& inputs,
81
+ const std::vector<int>& axes) override;
82
+ std::vector<array> jvp(
83
+ const std::vector<array>& primals,
84
+ const std::vector<array>& tangents,
85
+ const std::vector<int>& argnums) override;
86
+ std::vector<array> vjp(
87
+ const std::vector<array>& primals,
88
+ const std::vector<array>& cotangents,
89
+ const std::vector<int>& argnums,
90
+ const std::vector<array>& outputs) override;
91
+
92
+ DEFINE_NAME(AllGather);
93
+ };
94
+
95
+ class Send : public DistPrimitive {
96
+ public:
97
+ Send(Stream stream, Group group, int dst)
98
+ : DistPrimitive(stream, group), dst_(dst) {}
99
+
100
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
101
+ override;
102
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
103
+ override;
104
+ std::pair<std::vector<array>, std::vector<int>> vmap(
105
+ const std::vector<array>& inputs,
106
+ const std::vector<int>& axes) override;
107
+
108
+ DEFINE_NAME(Send);
109
+
110
+ private:
111
+ int dst_;
112
+ };
113
+
114
+ class Recv : public DistPrimitive {
115
+ public:
116
+ Recv(Stream stream, Group group, int src)
117
+ : DistPrimitive(stream, group), src_(src) {}
118
+
119
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
120
+ override;
121
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
122
+ override;
123
+
124
+ DEFINE_NAME(Recv);
125
+
126
+ private:
127
+ int src_;
128
+ };
129
+
130
+ class ReduceScatter : public DistPrimitive {
131
+ public:
132
+ enum ReduceType { Sum, Min, Max };
133
+ ReduceScatter(Stream stream, Group group, ReduceType reduce_type)
134
+ : DistPrimitive(stream, group), reduce_type_(reduce_type) {}
135
+
136
+ void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
137
+ override;
138
+ void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
139
+ override;
140
+
141
+ const char* name() const override {
142
+ switch (reduce_type_) {
143
+ case Sum:
144
+ return "Sum ReduceScatter";
145
+ case Min:
146
+ return "Min ReduceScatter";
147
+ case Max:
148
+ return "Max ReduceScatter";
149
+ }
150
+ return "<unknwon ReduceScatter>";
151
+ }
152
+
153
+ private:
154
+ ReduceType reduce_type_;
155
+ };
156
+ } // namespace mlx::core::distributed
@@ -0,0 +1,38 @@
1
// Copyright © 2025 Apple Inc.

// Include guard: without it a second inclusion redefines the templates below.
#pragma once

#include <algorithm> // std::max, std::min
#include <cstddef> // size_t

namespace mlx::core::distributed::detail {

/**
 * Element-wise accumulation functor: for each i in [0, N) performs
 * output[i] += input[i]. Does nothing when N is 0.
 */
template <typename T>
struct SumOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output += *input;
      input++;
      output++;
    }
  }
};

/**
 * Element-wise maximum functor: for each i in [0, N) stores
 * std::max(output[i], input[i]) into output[i]. Does nothing when N is 0.
 */
template <typename T>
struct MaxOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output = std::max(*output, *input);
      input++;
      output++;
    }
  }
};

/**
 * Element-wise minimum functor: for each i in [0, N) stores
 * std::min(output[i], input[i]) into output[i]. Does nothing when N is 0.
 */
template <typename T>
struct MinOp {
  void operator()(const T* input, T* output, size_t N) const {
    while (N-- > 0) {
      *output = std::min(*output, *input);
      input++;
      output++;
    }
  }
};

} // namespace mlx::core::distributed::detail
@@ -0,0 +1,12 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #include "mlx/distributed/distributed.h"
4
+
5
+ namespace mlx::core::distributed::ring {
6
+
7
+ using GroupImpl = mlx::core::distributed::detail::GroupImpl;
8
+
9
+ bool is_available();
10
+ std::shared_ptr<GroupImpl> init(bool strict = false);
11
+
12
+ } // namespace mlx::core::distributed::ring
@@ -0,0 +1,67 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <sys/socket.h>
6
+ #include <functional>
7
+ #include <string>
8
+
9
+ namespace mlx::core::distributed::detail {
10
+
11
/**
 * A socket address held in a sockaddr_storage, paired with a socklen_t
 * length field.
 */
struct address_t {
  sockaddr_storage addr;
  socklen_t len;

  // Views the storage through the generic sockaddr type. A named cast to
  // pointer-to-const is used so the const qualification of `addr` inside
  // this const member function is kept rather than silently cast away.
  const sockaddr* get() const {
    return reinterpret_cast<const sockaddr*>(&addr);
  }
};
19
+
20
+ /**
21
+ * Parse a sockaddr from an ip and port provided as strings.
22
+ */
23
+ address_t parse_address(const std::string& ip, const std::string& port);
24
+
25
+ /**
26
+ * Parse a sockaddr provided as an <ip>:<port> string.
27
+ */
28
+ address_t parse_address(const std::string& ip_port);
29
+
30
+ /**
31
+ * Small wrapper over a TCP socket to simplify initiating connections.
32
+ */
33
+ class TCPSocket {
34
+ public:
35
+ TCPSocket(const char* tag);
36
+ TCPSocket(const TCPSocket&) = delete; // copy construction is disabled
37
+ TCPSocket& operator=(const TCPSocket&) = delete; // copy assignment is disabled
38
+ TCPSocket(TCPSocket&& s); // move construction is supported (defined out of line)
39
+ TCPSocket& operator=(TCPSocket&&); // move assignment is supported (defined out of line)
40
+ ~TCPSocket();
41
+
42
+ void listen(const char* tag, const address_t& addr);
43
+ TCPSocket accept(const char* tag); // returns a new TCPSocket by value
44
+
45
+ void send(const char* tag, const void* data, size_t len);
46
+ void recv(const char* tag, void* data, size_t len);
47
+
48
+ int detach(); // NOTE(review): presumably hands the raw socket value to the caller and stops managing it -- confirm in the implementation
49
+
50
+ operator int() const { // implicit conversion exposing the stored sock_ value
51
+ return sock_;
52
+ }
53
+
54
+ static TCPSocket connect( // factory; defaults: num_retries = 1, wait = 0, cb = nullptr
55
+ const char* tag,
56
+ const address_t& addr,
57
+ int num_retries = 1,
58
+ int wait = 0, // NOTE(review): unit of `wait` is not visible here -- confirm
59
+ std::function<void(int, int)> cb = nullptr); // NOTE(review): meaning of the two int arguments is not visible here -- confirm
60
+
61
+ private:
62
+ TCPSocket(int sock); // private constructor taking a raw int; presumably used by accept()/connect() -- confirm
63
+
64
+ int sock_; // value returned by operator int()
65
+ };
66
+
67
+ } // namespace mlx::core::distributed::detail
@@ -0,0 +1,115 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <complex>
6
+ #include <cstdint>
7
+
8
+ #include "mlx/types/complex.h"
9
+ #include "mlx/types/half_types.h"
10
+
11
+ namespace mlx::core {
12
+
13
/**
 * Compact descriptor of an element type: an enumerated tag (Val) paired with
 * the size value supplied to the constructor.
 */
struct Dtype {
  // Concrete element types. Enumerator order fixes the underlying values, so
  // it is kept exactly as declared.
  enum class Val {
    bool_,
    uint8,
    uint16,
    uint32,
    uint64,
    int8,
    int16,
    int32,
    int64,
    float16,
    float32,
    float64,
    bfloat16,
    complex64,
  };

  // Single-letter kind codes.
  enum class Kind {
    b, /* bool */
    u, /* unsigned int */
    i, /* signed int */
    f, /* float */
    c, /* complex */
    V, /* void - used for brain float */
  };

  // Type categories.
  enum class Category {
    complexfloating,
    floating,
    inexact,
    signedinteger,
    unsignedinteger,
    integer,
    number,
    generic
  };

  constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {}

  // The enumerated tag of this dtype.
  constexpr Val val() const { return val_; }

  // Implicit conversion to the tag, which lets a Dtype be used directly as a
  // switch operand or compared against Val enumerators.
  constexpr operator Val() const { return val_; }

  // The size value given at construction.
  constexpr uint8_t size() const { return size_; }

 private:
  Val val_;
  uint8_t size_;
};
67
+
68
+ inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)};
69
+
70
+ inline constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)};
71
+ inline constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)};
72
+ inline constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)};
73
+ inline constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)};
74
+
75
+ inline constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)};
76
+ inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)};
77
+ inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)};
78
+ inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)};
79
+
80
+ inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)};
81
+ inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)};
82
+ inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)};
83
+ inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)};
84
+ inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)};
85
+
86
+ inline constexpr Dtype::Category complexfloating =
87
+ Dtype::Category::complexfloating;
88
+ inline constexpr Dtype::Category floating = Dtype::Category::floating;
89
+ inline constexpr Dtype::Category inexact = Dtype::Category::inexact;
90
+ inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger;
91
+ inline constexpr Dtype::Category unsignedinteger =
92
+ Dtype::Category::unsignedinteger;
93
+ inline constexpr Dtype::Category integer = Dtype::Category::integer;
94
+ inline constexpr Dtype::Category number = Dtype::Category::number;
95
+ inline constexpr Dtype::Category generic = Dtype::Category::generic;
96
+
97
+ bool issubdtype(const Dtype& a, const Dtype& b);
98
+ bool issubdtype(const Dtype::Category& a, const Dtype& b);
99
+ bool issubdtype(const Dtype& a, const Dtype::Category& b);
100
+ bool issubdtype(const Dtype::Category& a, const Dtype::Category& b);
101
+
102
+ Dtype promote_types(const Dtype& t1, const Dtype& t2);
103
+
104
+ inline uint8_t size_of(const Dtype& t) {
105
+ return t.size();
106
+ }
107
+
108
+ Dtype::Kind kindof(const Dtype& t);
109
+
110
+ template <typename T>
111
+ struct TypeToDtype {
112
+ operator Dtype();
113
+ };
114
+
115
+ } // namespace mlx::core
@@ -0,0 +1,119 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <sstream>
6
+
7
+ #include "mlx/dtype.h"
8
+ #include "mlx/utils.h"
9
+
10
+ namespace mlx::core {
11
+
12
+ // Return string representation of dtype.
13
+ const char* dtype_to_string(Dtype arg);
14
+
15
+ #define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \
16
+ case DTYPE: \
17
+ f(type_identity<TYPE>{}); \
18
+ break
19
+
20
+ #define MLX_INTERNAL_DTYPE_SWITCH_INTS() \
21
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t); \
22
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \
23
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \
24
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \
25
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \
26
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \
27
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \
28
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t)
29
+
30
+ #define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \
31
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \
32
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \
33
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \
34
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double)
35
+
36
+ // This already exists in C++20 but in C++20 we can also just use templated
37
+ // lambdas which will make this so much nicer.
38
// Empty tag type that carries a C++ type as its nested `type` alias, so a
// type can be passed to a callable as an ordinary value.
template <typename Carried>
struct type_identity {
  using type = Carried;
};
42
+
43
+ #define MLX_GET_TYPE(x) typename decltype(x)::type
44
+ #define MLX_GET_VALUE(x) decltype(x)::value
45
+
46
+ template <typename F>
47
+ void dispatch_all_types(Dtype dt, F&& f) {
48
+ switch (dt) {
49
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
50
+ MLX_INTERNAL_DTYPE_SWITCH_INTS();
51
+ MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
52
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
53
+ }
54
+ }
55
+
56
+ template <typename F>
57
+ void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) {
58
+ switch (dt) {
59
+ MLX_INTERNAL_DTYPE_SWITCH_INTS();
60
+ default:
61
+ std::ostringstream msg;
62
+ msg << tag << " Only integer types supported but " << dt
63
+ << " was provided";
64
+ throw std::invalid_argument(msg.str());
65
+ }
66
+ }
67
+
68
+ template <typename F>
69
+ void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) {
70
+ switch (dt) {
71
+ MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
72
+ default:
73
+ std::ostringstream msg;
74
+ msg << tag << " Only float types supported but " << dt << " was provided";
75
+ throw std::invalid_argument(msg.str());
76
+ }
77
+ }
78
+
79
+ template <typename F>
80
+ void dispatch_inexact_types(Dtype dt, std::string_view tag, F&& f) {
81
+ switch (dt) {
82
+ MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
83
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t);
84
+ default:
85
+ std::ostringstream msg;
86
+ msg << tag << " Only inexact (float/complex) types supported but " << dt
87
+ << " was provided";
88
+ throw std::invalid_argument(msg.str());
89
+ }
90
+ }
91
+
92
+ template <typename F>
93
+ void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) {
94
+ switch (dt) {
95
+ MLX_INTERNAL_DTYPE_SWITCH_INTS();
96
+ MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
97
+ default:
98
+ std::ostringstream msg;
99
+ msg << tag << " Only integer and float types supported but " << dt
100
+ << " was provided";
101
+ throw std::invalid_argument(msg.str());
102
+ }
103
+ }
104
+
105
+ template <typename F>
106
+ void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) {
107
+ switch (dt) {
108
+ MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool);
109
+ MLX_INTERNAL_DTYPE_SWITCH_INTS();
110
+ MLX_INTERNAL_DTYPE_SWITCH_FLOATS();
111
+ default:
112
+ std::ostringstream msg;
113
+ msg << tag << " Only real numbers supported but " << dt
114
+ << " was provided";
115
+ throw std::invalid_argument(msg.str());
116
+ }
117
+ }
118
+
119
+ } // namespace mlx::core
@@ -0,0 +1,22 @@
1
+ // Copyright © 2024 Apple Inc.
2
+ #pragma once
3
+
4
+ #include <string>
5
+ #include <tuple>
6
+ #include <vector>
7
+
8
+ #include "mlx/array.h"
9
+ #include "mlx/utils.h"
10
+
11
+ namespace mlx::core {
12
+
13
+ std::pair<std::vector<std::vector<int>>, std::string> einsum_path(
14
+ const std::string& subscripts,
15
+ const std::vector<array>& operands);
16
+
17
+ array einsum(
18
+ const std::string& subscripts,
19
+ const std::vector<array>& operands,
20
+ StreamOrDevice s = {});
21
+
22
+ } // namespace mlx::core
@@ -0,0 +1,58 @@
1
+ // Copyright © 2024 Apple Inc.
2
+ #pragma once
3
+
4
+ #include <cstdint>
5
+ #include <memory>
6
+ #include <stdexcept>
7
+
8
+ #include "mlx/stream.h"
9
+
10
+ namespace mlx::core {
11
+
12
+ class Event {
13
+ public:
14
+ Event() {};
15
+ explicit Event(Stream stream);
16
+
17
+ // Wait for the event to be signaled at its current value
18
+ void wait();
19
+
20
+ // Wait in the given stream for the event to be signaled at its current value
21
+ void wait(Stream stream);
22
+
23
+ // Signal the event at its current value in the given stream
24
+ void signal(Stream stream);
25
+
26
+ // Check if the event has been signaled at its current value
27
+ bool is_signaled() const;
28
+
29
+ // Check if the event is valid
30
+ bool valid() const {
31
+ return event_ != nullptr;
32
+ }
33
+
34
+ uint64_t value() const {
35
+ return value_;
36
+ }
37
+
38
+ void set_value(uint64_t v) {
39
+ value_ = v;
40
+ }
41
+
42
+ const Stream& stream() const {
43
+ if (!valid()) {
44
+ throw std::runtime_error(
45
+ "[Event::stream] Cannot access stream on invalid event.");
46
+ }
47
+ return stream_;
48
+ }
49
+
50
+ private:
51
+ // Default constructed stream should never be used
52
+ // since the event is not yet valid
53
+ Stream stream_{0, Device::cpu};
54
+ std::shared_ptr<void> event_{nullptr};
55
+ uint64_t value_{0};
56
+ };
57
+
58
+ } // namespace mlx::core