mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,234 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <algorithm>
6
+ #include <cmath>
7
+ #include <cstdint>
8
+ #include <vector>
9
+
10
// Half-precision NaN pattern: exponent bits all set (0x7C00) plus a non-zero
// mantissa bit.
#define __MLX_HALF_NAN__ 0x7D00

namespace mlx::core {

namespace {
// Overlays a float32 with its raw 32-bit pattern so the conversions below can
// read and edit the encoding directly.
union float_bits_fp16 {
  float f;
  uint32_t u;
};
} // namespace

/**
 * Software 16-bit floating point value stored as its raw bit pattern.
 *
 * Values are converted from float32 on construction/assignment and back to
 * float32 through the conversion operator; the arithmetic and comparison
 * operators defined after the struct compute in float (or double) and convert
 * the result back where the return type is _MLX_Float16.
 */
struct _MLX_Float16 {
  uint16_t bits_;

  // Default constructor
  _MLX_Float16() = default;

  // Default copy constructor
  _MLX_Float16(_MLX_Float16 const&) = default;

  // Appease std::vector<bool> for being special: its elements are proxy
  // objects rather than bools. Convert by value (true -> 1.0, false -> 0.0);
  // writing the bool straight into bits_ would yield the bit pattern 0x0001,
  // which is not 1.0.
  _MLX_Float16& operator=(std::vector<bool>::reference x) {
    return (*this = _MLX_Float16(static_cast<bool>(x) ? 1.0f : 0.0f));
  }

  _MLX_Float16& operator=(const float& x) {
    return (*this = _MLX_Float16(x));
  }

  // From float32
  _MLX_Float16(const float& x) : bits_(0) {
    // Conversion following
    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h

    // Union
    float_bits_fp16 in;

    // Take fp32 bits
    in.f = x;

    // Find and take sign bit
    uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
    uint16_t x_sign_16 = (x_sign_32 >> 16);

    if (std::isnan(x)) {
      bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
    } else {
      // Union
      float_bits_fp16 inf_scale, zero_scale, magic_bits;

      // Find exponent bits and take the max supported by half
      uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
      uint32_t max_expo_32 = uint32_t(0x38800000);
      x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
      x_expo_32 += uint32_t(15) << 23;

      // Handle scaling to inf as needed
      inf_scale.u = uint32_t(0x77800000);
      zero_scale.u = uint32_t(0x08800000);

      // Combine with magic and let addition do rounding
      magic_bits.u = x_expo_32;
      magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;

      // Take the lower 5 bits of the exponent
      uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));

      // Collect the lower 12 bits which have the mantissa
      uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);

      // Combine sign, exp and mantissa
      bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
    }
  }

  // To float32
  operator float() const {
    // Conversion following
    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h

    // Union
    float_bits_fp16 out;

    // Widen before shifting so the shift is done on an unsigned 32-bit value
    // instead of a promoted (signed) int.
    uint32_t base = static_cast<uint32_t>(bits_) << 16;
    uint32_t x_sign_32 = base & uint32_t(0x80000000);
    uint32_t two_base = base + base; // drops the sign bit

    uint32_t denorm_max = 1u << 27;
    if (two_base < denorm_max) {
      out.u = uint32_t(126) << 23; // magic mask
      out.u |= (two_base >> 17); // Bits from fp16
      out.f -= 0.5f; // magic bias
    } else {
      out.u = uint32_t(0xE0) << 23; // exponent offset
      out.u += (two_base >> 4); // Bits from fp16
      float out_unscaled = out.f; // Store value
      out.u = uint32_t(0x7800000); // exponent scale
      out.f *= out_unscaled;
    }

    // Add sign
    out.u |= x_sign_32;

    return out.f;
  }
};

// Defines `otype __operator__(atype, btype)`: both operands are cast to ctype
// and combined with __op__.
#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
  inline otype __operator__(atype lhs, btype rhs) {                       \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \
  }

// Defines both argument orders of a mixed _MLX_Float16 / itype operator,
// computing in ctype and returning otype.
#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \
  inline otype __operator__(_MLX_Float16 lhs, itype rhs) {           \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \
  }                                                                  \
  inline otype __operator__(itype lhs, _MLX_Float16 rhs) {           \
    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \
  }

// Operators
// Mixing with float/double yields float/double; mixing with bool or an
// integer type yields _MLX_Float16 (computed in float).
#define half_binop(__op__, __operator__)                                      \
  half_binop_base(                                                            \
      __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
  half_binop_helper(__op__, __operator__, float, float, float);               \
  half_binop_helper(__op__, __operator__, double, double, double);            \
  half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float);         \
  half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float);      \
  half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float);     \
  half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float);      \
  half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);

half_binop(+, operator+);
half_binop(-, operator-);
half_binop(*, operator*);
half_binop(/, operator/);

#undef half_binop

// Comparison ops
// Same operand combinations as above (without bool); every overload returns
// bool.
#define half_compop(__op__, __operator__)                             \
  half_binop_base(                                                    \
      __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
  half_binop_helper(__op__, __operator__, bool, float, float);        \
  half_binop_helper(__op__, __operator__, bool, double, double);      \
  half_binop_helper(__op__, __operator__, bool, int32_t, float);      \
  half_binop_helper(__op__, __operator__, bool, uint32_t, float);     \
  half_binop_helper(__op__, __operator__, bool, int64_t, float);      \
  half_binop_helper(__op__, __operator__, bool, uint64_t, float);

half_compop(>, operator>);
half_compop(<, operator<);
half_compop(>=, operator>=);
half_compop(<=, operator<=);
half_compop(==, operator==);
half_compop(!=, operator!=);

#undef half_compop

// Negative
inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
  return -static_cast<float>(lhs);
}

// Inplace ops
// Compound assignment between _MLX_Float16 and float, in both directions.
#define half_inplace_op(__op__, __operator__)                              \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
    lhs = lhs __op__ rhs;                                                  \
    return lhs;                                                            \
  }                                                                        \
  inline float& __operator__(float& lhs, _MLX_Float16 rhs) {               \
    lhs = lhs __op__ rhs;                                                  \
    return lhs;                                                            \
  }

half_inplace_op(+, operator+=);
half_inplace_op(-, operator-=);
half_inplace_op(*, operator*=);
half_inplace_op(/, operator/=);

#undef half_inplace_op

// Bitwise ops
// These act on the raw bits_ pattern, not on the numeric value.

#define half_bitop(__op__, __operator__)                                 \
  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
    _MLX_Float16 out;                                                    \
    out.bits_ = lhs.bits_ __op__ rhs.bits_;                              \
    return out;                                                          \
  }                                                                      \
  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) {     \
    _MLX_Float16 out;                                                    \
    out.bits_ = lhs.bits_ __op__ rhs;                                    \
    return out;                                                          \
  }                                                                      \
  inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) {     \
    _MLX_Float16 out;                                                    \
    out.bits_ = lhs __op__ rhs.bits_;                                    \
    return out;                                                          \
  }

half_bitop(|, operator|);
half_bitop(&, operator&);
half_bitop(^, operator^);

#undef half_bitop

// In-place variants of the bitwise operators above.
#define half_inplace_bitop(__op__, __operator__)                          \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
    lhs.bits_ = lhs.bits_ __op__ rhs.bits_;                               \
    return lhs;                                                           \
  }                                                                       \
  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) {    \
    lhs.bits_ = lhs.bits_ __op__ rhs;                                     \
    return lhs;                                                           \
  }

half_inplace_bitop(|, operator|=);
half_inplace_bitop(&, operator&=);
half_inplace_bitop(^, operator^=);

#undef half_inplace_bitop

} // namespace mlx::core
@@ -0,0 +1,58 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
6
+
7
+ #include <arm_fp16.h>
8
+ namespace mlx::core {
9
+ using ::float16_t;
10
+ } // namespace mlx::core
11
+
12
+ #else
13
+
14
+ #define ADD_HALF_BINOPS
15
+ #include "mlx/types/fp16.h"
16
+ namespace mlx::core {
17
+ typedef struct _MLX_Float16 float16_t;
18
+ } // namespace mlx::core
19
+
20
+ #endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
21
+
22
+ #ifdef __ARM_FEATURE_BF16
23
+
24
+ #include <arm_bf16.h>
25
+ namespace mlx::core {
26
+ using ::bfloat16_t;
27
+ } // namespace mlx::core
28
+
29
+ #else
30
+
31
+ #define ADD_HALF_BINOPS
32
+ #include "mlx/types/bf16.h"
33
+ namespace mlx::core {
34
+ typedef struct _MLX_BFloat16 bfloat16_t;
35
+ } // namespace mlx::core
36
+
37
+ #endif // __ARM_FEATURE_BF16
38
+
39
+ #ifdef ADD_HALF_BINOPS
40
+ namespace mlx::core {
41
+
42
+ // clang-format off
43
+ #define fp16_bf16_binop_helper(__op__, __operator__) \
44
+ inline float __operator__(float16_t lhs, bfloat16_t rhs) { \
45
+ return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
46
+ } \
47
+ inline float __operator__(bfloat16_t lhs, float16_t rhs) { \
48
+ return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
49
+ }
50
+
51
+ fp16_bf16_binop_helper(+, operator+)
52
+ fp16_bf16_binop_helper(-, operator-)
53
+ fp16_bf16_binop_helper(*, operator*)
54
+ fp16_bf16_binop_helper(/, operator/)
55
+ // clang-format on
56
+
57
+ } // namespace mlx::core
58
+ #endif
@@ -0,0 +1,70 @@
1
+ // Copyright © 2024 Apple Inc.
2
+ #pragma once
3
+
4
+ #include <limits>
5
+ #include "mlx/types/half_types.h"
6
+
7
+ namespace mlx::core {
8
+
9
+ template <typename T>
10
+ struct numeric_limits;
11
+
12
+ template <>
13
+ struct numeric_limits<float> : public std::numeric_limits<float> {};
14
+
15
+ template <>
16
+ struct numeric_limits<double> : public std::numeric_limits<double> {};
17
+
18
+ template <>
19
+ struct numeric_limits<float16_t> {
20
+ private:
21
+ union half_or_bits {
22
+ uint16_t bits;
23
+ float16_t value;
24
+ };
25
+ constexpr static float16_t bits_to_half(uint16_t v) {
26
+ return half_or_bits{v}.value;
27
+ }
28
+
29
+ public:
30
+ constexpr static float16_t lowest() {
31
+ return bits_to_half(0xFBFF);
32
+ }
33
+ static constexpr float16_t max() {
34
+ return bits_to_half(0x7BFF);
35
+ }
36
+ static constexpr float16_t epsilon() {
37
+ return bits_to_half(0x1400);
38
+ }
39
+ static constexpr float16_t infinity() {
40
+ return bits_to_half(0x7C00);
41
+ }
42
+ };
43
+
44
+ template <>
45
+ struct numeric_limits<bfloat16_t> {
46
+ private:
47
+ union bfloat_or_bits {
48
+ uint16_t bits;
49
+ bfloat16_t value;
50
+ };
51
+ constexpr static bfloat16_t bits_to_bfloat(uint16_t v) {
52
+ return bfloat_or_bits{v}.value;
53
+ }
54
+
55
+ public:
56
+ constexpr static bfloat16_t lowest() {
57
+ return bits_to_bfloat(0xFF7F);
58
+ }
59
+ static constexpr bfloat16_t max() {
60
+ return bits_to_bfloat(0x7F7F);
61
+ }
62
+ static constexpr bfloat16_t epsilon() {
63
+ return bits_to_bfloat(0x3C00);
64
+ }
65
+ static constexpr bfloat16_t infinity() {
66
+ return bits_to_bfloat(0x7F80);
67
+ }
68
+ };
69
+
70
+ } // namespace mlx::core
@@ -0,0 +1,175 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <exception>
6
+ #include <variant>
7
+
8
+ #include "mlx/array.h"
9
+ #include "mlx/device.h"
10
+ #include "mlx/dtype.h"
11
+ #include "mlx/stream.h"
12
+
13
+ namespace mlx::core {
14
+
15
+ using StreamOrDevice = std::variant<std::monostate, Stream, Device>;
16
+ Stream to_stream(StreamOrDevice s);
17
+ Stream to_stream(StreamOrDevice s, Device default_);
18
+
19
+ struct StreamContext {
20
+ public:
21
+ StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) {
22
+ if (std::holds_alternative<std::monostate>(s)) {
23
+ throw std::runtime_error(
24
+ "[StreamContext] Invalid argument, please specify a stream or device.");
25
+ }
26
+ auto _s = to_stream(s);
27
+ set_default_device(_s.device);
28
+ set_default_stream(_s);
29
+ }
30
+
31
+ ~StreamContext() {
32
+ set_default_device(_stream.device);
33
+ set_default_stream(_stream);
34
+ }
35
+
36
+ private:
37
+ Stream _stream;
38
+ };
39
+
40
+ struct PrintFormatter { // one print() overload per supported scalar type; definitions are not part of this declaration
41
+ inline void print(std::ostream& os, bool val); // NOTE(review): presumably the overload that consults capitalize_bool — confirm in the definition
42
+ inline void print(std::ostream& os, int16_t val);
43
+ inline void print(std::ostream& os, uint16_t val);
44
+ inline void print(std::ostream& os, int32_t val);
45
+ inline void print(std::ostream& os, uint32_t val);
46
+ inline void print(std::ostream& os, int64_t val);
47
+ inline void print(std::ostream& os, uint64_t val);
48
+ inline void print(std::ostream& os, float16_t val);
49
+ inline void print(std::ostream& os, bfloat16_t val);
50
+ inline void print(std::ostream& os, float val);
51
+ inline void print(std::ostream& os, double val);
52
+ inline void print(std::ostream& os, complex64_t val);
53
+
54
+ bool capitalize_bool{false}; // formatting option, off by default; NOTE(review): exact effect on bool output is defined elsewhere — confirm
55
+ };
56
+
57
+ PrintFormatter& get_global_formatter();
58
+
59
+ /** Print the exception and then abort. */
60
+ void abort_with_exception(const std::exception& error);
61
+
62
+ /** Holds information about floating-point types: the dtype plus its min, max and eps, all stored as double. */
63
+ struct finfo {
64
+ explicit finfo(Dtype dtype); // explicit: a Dtype never converts to finfo implicitly; the definition (not shown) presumably fills the fields below
65
+ Dtype dtype;
66
+ double min;
67
+ double max;
68
+ double eps;
69
+ };
70
+
71
+ /** Holds information about integral types: the dtype plus its min and max. */
72
+ struct iinfo {
73
+ explicit iinfo(Dtype dtype); // explicit: a Dtype never converts to iinfo implicitly
74
+ Dtype dtype;
75
+ int64_t min; // signed
76
+ uint64_t max; // unsigned, unlike min; NOTE(review): presumably so the uint64 maximum fits — confirm
77
+ };
78
+
79
+ /** The type from promoting the arrays' types with one another. */
80
+ inline Dtype result_type(const array& a, const array& b) {
81
+ return promote_types(a.dtype(), b.dtype());
82
+ }
83
+ inline Dtype result_type(const array& a, const array& b, const array& c) {
84
+ return promote_types(result_type(a, b), c.dtype());
85
+ }
86
+ Dtype result_type(const std::vector<array>& arrays);
87
+
88
+ Shape broadcast_shapes(const Shape& s1, const Shape& s2);
89
+
90
+ /**
91
+ * Returns the axis normalized to be in the range [0, ndim).
92
+ */
93
+ int normalize_axis_index(
94
+ int axis,
95
+ int ndim,
96
+ const std::string& msg_prefix = "");
97
+
98
+ std::ostream& operator<<(std::ostream& os, const Device& d);
99
+ std::ostream& operator<<(std::ostream& os, const Stream& s);
100
+ std::ostream& operator<<(std::ostream& os, const Dtype& d);
101
+ std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
102
+ std::ostream& operator<<(std::ostream& os, array a);
103
+ inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
104
+ return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
105
+ }
106
+ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
107
+ return os << static_cast<float>(v);
108
+ }
109
+ inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
110
+ return os << static_cast<float>(v);
111
+ }
112
+
113
+ template <typename Vec, typename = std::enable_if_t<is_vector_v<Vec>>>
114
+ inline std::ostream& operator<<(std::ostream& os, const Vec& v) {
115
+ os << "(";
116
+ for (auto it = v.begin(); it != v.end(); ++it) {
117
+ os << *it;
118
+ if (it != std::prev(v.end())) {
119
+ os << ",";
120
+ }
121
+ }
122
+ os << ")";
123
+ return os;
124
+ }
125
+
126
/** True when `n` is a positive power of two (1, 2, 4, ...). */
inline bool is_power_of_2(int n) {
  // Check the sign first so that `n - 1` is never evaluated for INT_MIN.
  return n > 0 && (n & (n - 1)) == 0;
}

/**
 * Smallest power of two that is >= `n`.
 *
 * Returns `n` itself when it already is a power of two and 0 when `n <= 0`
 * (0 is what the previous floating-point formulation produced for n == 0).
 * For `n` above 2^30 the answer does not fit in an int; the value returned in
 * that case is not meaningful.
 */
inline int next_power_of_2(int n) {
  if (n <= 0) {
    return 0;
  }
  // Integer doubling instead of pow/log2: exact, and needs no <cmath>.
  // Unsigned arithmetic keeps the final doubling free of signed overflow.
  const unsigned int target = static_cast<unsigned int>(n);
  unsigned int result = 1;
  while (result < target) {
    result <<= 1;
  }
  return static_cast<int>(result);
}
136
+
137
// Cached configuration knobs. Each accessor evaluates get_var() once, on its
// first call, and returns that stored value afterwards; for the accessors that
// take `default_value`, only the argument of the first call is ever used.
namespace env {

// Declared here, defined elsewhere. NOTE(review): presumably reads the named
// environment variable and falls back to default_value — confirm.
int get_var(const char* name, int default_value);

inline int bfs_max_width() {
  static int cached = get_var("MLX_BFS_MAX_WIDTH", 20);
  return cached;
}

inline int max_ops_per_buffer(int default_value) {
  static int cached = get_var("MLX_MAX_OPS_PER_BUFFER", default_value);
  return cached;
}

inline int max_mb_per_buffer(int default_value) {
  static int cached = get_var("MLX_MAX_MB_PER_BUFFER", default_value);
  return cached;
}

// The int from get_var is converted to bool (non-zero -> true).
inline bool metal_fast_synch() {
  static bool cached = get_var("MLX_METAL_FAST_SYNCH", 0);
  return cached;
}

// The int from get_var is converted to bool; the default here is 1.
inline bool enable_tf32() {
  static bool cached = get_var("MLX_ENABLE_TF32", 1);
  return cached;
}

inline int nccl_timeout(int default_value) {
  static int cached = get_var("MLX_NCCL_TIMEOUT", default_value);
  return cached;
}

} // namespace env
174
+
175
+ } // namespace mlx::core
@@ -0,0 +1,20 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #define MLX_VERSION_MAJOR 0
6
+ #define MLX_VERSION_MINOR 30
7
+ #define MLX_VERSION_PATCH 1
8
+ #define MLX_VERSION_NUMERIC \
9
+ (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)
10
+
11
+ namespace mlx::core {
12
+
13
+ /* A string representation of the MLX version in the format
14
+ * "major.minor.patch".
15
+ *
16
+ * For dev builds, the version will include the suffix ".devYYYYMMDD+hash"
17
+ */
18
+ const char* version();
19
+
20
+ } // namespace mlx::core
mlx/lib/libmlx.so ADDED
Binary file
mlx/py.typed ADDED
@@ -0,0 +1 @@
1
+
@@ -0,0 +1,54 @@
1
# FindNCCL.cmake
#
# Finds the NVIDIA NCCL library and its include directories.
#
# Optional inputs:
#   NCCL_ROOT_DIR          cache variable; defaults to the environment variable
#                          of the same name
#   NCCL_INCLUDE_DIR       extra hint for the header search
#   NCCL_LIB_DIR           extra hint for the library search
#   CUDA_TOOLKIT_ROOT_DIR  extra hint for both searches
#   ENV{USE_STATIC_NCCL}   when true, search for libnccl_static.a instead of nccl
#
# Results:
#   NCCL_FOUND, NCCL_INCLUDE_DIRS, NCCL_LIBRARIES
#   NCCL_MAJOR_VERSION     set when nccl.h contains a "#define NCCL_MAJOR <n>" line

# Quoted so that an unset environment variable becomes an explicit empty value.
set(NCCL_ROOT_DIR
    "$ENV{NCCL_ROOT_DIR}"
    CACHE PATH "Folder contains NVIDIA NCCL")

# Hints are left unquoted on purpose: empty hint variables simply disappear.
find_path(
  NCCL_INCLUDE_DIRS
  NAMES nccl.h
  HINTS ${NCCL_INCLUDE_DIR} ${NCCL_ROOT_DIR} ${NCCL_ROOT_DIR}/include
        ${CUDA_TOOLKIT_ROOT_DIR}/include)

# Quoted so the environment value is tested as a boolean constant and is never
# mistaken for the name of a CMake variable.
if("$ENV{USE_STATIC_NCCL}")
  message(
    STATUS "USE_STATIC_NCCL detected. Linking against static NCCL library")
  set(NCCL_LIBNAME "libnccl_static.a")
else()
  set(NCCL_LIBNAME "nccl")
endif()

find_library(
  NCCL_LIBRARIES
  NAMES ${NCCL_LIBNAME}
  HINTS ${NCCL_LIB_DIR}
        ${NCCL_ROOT_DIR}
        ${NCCL_ROOT_DIR}/lib
        ${NCCL_ROOT_DIR}/lib/x86_64-linux-gnu
        ${NCCL_ROOT_DIR}/lib64
        ${CUDA_TOOLKIT_ROOT_DIR}/lib
        ${CUDA_TOOLKIT_ROOT_DIR}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS
                                  NCCL_LIBRARIES)

if(NCCL_FOUND)
  set(NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
  message(
    STATUS "Determining NCCL version from the header file: ${NCCL_HEADER_FILE}")
  # Grab the first "#define NCCL_MAJOR <digits>" line of the header.
  file(
    STRINGS "${NCCL_HEADER_FILE}" NCCL_MAJOR_VERSION_DEFINED
    REGEX "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+[0-9]+.*$"
    LIMIT_COUNT 1)
  if(NCCL_MAJOR_VERSION_DEFINED)
    # Keep only the digits; anything that follows them on the line (trailing
    # comment, whitespace) is dropped instead of ending up in the version.
    string(REGEX REPLACE "^[ \t]*#define[ \t]+NCCL_MAJOR[ \t]+([0-9]+).*$" "\\1"
                         NCCL_MAJOR_VERSION "${NCCL_MAJOR_VERSION_DEFINED}")
    message(STATUS "NCCL_MAJOR_VERSION: ${NCCL_MAJOR_VERSION}")
  endif()
  message(
    STATUS
      "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
  mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
endif()
@@ -0,0 +1,3 @@
1
+ # This file intentionally does nothing; it exists only to suppress the CMake
2
+ # warning "By not providing Findnvpl.cmake in CMAKE_MODULE_PATH...", which is
3
+ # triggered by the find_package(nvpl) call in CMake's builtin FindLAPACK.cmake module.
@@ -0,0 +1,66 @@
1
+ # Find MLX
2
+ #
3
+ # Defines the following variables:
4
+ #
5
+ # MLX_FOUND : True if MLX is found
6
+ # MLX_INCLUDE_DIRS : Include directory
7
+ # MLX_LIBRARIES : Libraries to link against
8
+ # MLX_CXX_FLAGS : Additional compiler flags
9
+ # MLX_BUILD_ACCELERATE : True if MLX was built with accelerate
10
+ # MLX_BUILD_METAL : True if MLX was built with metal
11
+
12
+
13
+ ####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() #######
14
+ ####### Any changes to this file will be overwritten by the next CMake run ####
15
+ ####### The input file was mlx.pc.in ########
16
+
17
+ get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE) # package prefix = three directories above the directory holding this file
18
+
19
+ macro(set_and_check _var _file) # store _file in the variable named by _var, then verify that the path exists
20
+ set(${_var} "${_file}")
21
+ if(NOT EXISTS "${_file}")
22
+ message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !") # FATAL_ERROR stops configuration when the path is missing
23
+ endif()
24
+ endmacro()
25
+
26
+ ####################################################################################
27
+
28
+ include(${PACKAGE_PREFIX_DIR}/share/cmake/MLX/MLXTargets.cmake) # NOTE(review): presumably defines the imported target "mlx" modified further below — confirm
29
+ include(${PACKAGE_PREFIX_DIR}/share/cmake/MLX/extension.cmake)
30
+
31
+ set_and_check(MLX_LIBRARY_DIRS ${PACKAGE_PREFIX_DIR}/lib) # fails configuration if <prefix>/lib is missing
32
+ set_and_check(MLX_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include) # fails configuration if <prefix>/include is missing
33
+ set(MLX_LIBRARIES mlx)
34
+
35
+ find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS}) # PATHS adds the package lib dir to the default search locations
36
+
37
+ if (OFF) # constant false in this package: the Accelerate settings below are never applied and MLX_BUILD_ACCELERATE stays unset
38
+ set(MLX_BUILD_ACCELERATE OFF)
39
+ set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK)
40
+ endif()
41
+
42
+ if (OFF) # constant false in this package: the Metal settings below are never applied and MLX_BUILD_METAL stays unset
43
+ set(MLX_BUILD_METAL OFF)
44
+ set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
45
+ set(MLX_INCLUDE_DIRS
46
+ "${MLX_INCLUDE_DIRS};"
47
+ ${PACKAGE_PREFIX_DIR}/include/metal_cpp
48
+ )
49
+ if( GREATER_EQUAL 310) # NOTE(review): the left operand is empty here; harmless only because the enclosing if (OFF) is never entered
50
+ set(MLX_INCLUDE_DIRS
51
+ "${MLX_INCLUDE_DIRS};"
52
+ ${PACKAGE_PREFIX_DIR}/include/mlx/backend/metal/kernels/metal_3_1)
53
+ else()
54
+ set(MLX_INCLUDE_DIRS
55
+ "${MLX_INCLUDE_DIRS};"
56
+ ${PACKAGE_PREFIX_DIR}/include/mlx/backend/metal/kernels/metal_3_0)
57
+ endif()
58
+ endif()
59
+
60
+ set_target_properties(mlx PROPERTIES # attach the C++ standard and the accumulated MLX_CXX_FLAGS to the mlx target
61
+ CXX_STANDARD 17
62
+ INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}"
63
+ )
64
+
65
+ include(FindPackageHandleStandardArgs)
66
+ find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS) # sets MLX_FOUND when both MLX_LIBRARY and MLX_INCLUDE_DIRS are set