mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,114 @@
1
+ // Copyright © 2025 Apple Inc.
2
+ #pragma once
3
+
4
+ #include "mlx/array.h"
5
+ #include "mlx/backend/cuda/cublas_utils.h"
6
+ #include "mlx/backend/cuda/device.h"
7
+
8
+ #include <cublasLt.h>
9
+
10
+ namespace mlx::core {
11
+
12
// Plans and launches matrix multiplications through cublasLt.
//
// The first constructor plans `out = A @ B`; the second additionally
// describes a C operand (with its own leading dimension and batch stride)
// so `out = alpha * (A @ B) + beta * C` can be computed.
class CublasGemm : public CublasMatmulBase {
 public:
  // Plan `out = A @ B` for a (possibly batched) multiply. Row/column
  // counts and leading dimensions describe each operand's layout; the
  // transpose flags select the cublas op applied to it.
  CublasGemm(
      cu::Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride);

  // Same as above, plus the layout of a third operand C that is added to
  // the product (used by the run/execute overloads taking alpha and beta).
  CublasGemm(
      cu::Device& device,
      Dtype dtype,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int64_t ldc,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride,
      int64_t c_batch_stride);

  // The output's descriptor is inferred from inputs by default, use this
  // method for unusual output.
  void set_out(
      Dtype dtype,
      bool transposed,
      uint64_t rows,
      uint64_t cols,
      int64_t ld,
      int32_t batch_count,
      int64_t batch_stride);

  // Launch `out = alpha * (a @ b)` on the encoder's stream.
  void run(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const Shape& batch_shape,
      const Strides& a_batch_strides,
      const Strides& b_batch_strides,
      float alpha = 1.0f);

  // Launch `out = alpha * (a @ b) + beta * c`.
  void run(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const array& c,
      const Shape& batch_shape,
      const Strides& a_batch_strides,
      const Strides& b_batch_strides,
      const Strides& c_batch_strides,
      float alpha,
      float beta);

 private:
  // NOTE(review): given the explicit batch_shape/batch_strides parameters,
  // these presumably dispatch one call per batch element when the batch
  // cannot be expressed with the single strides the plan was built with —
  // confirm against the implementation.
  void run_batched(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const Shape& batch_shape,
      const Strides& a_batch_strides,
      const Strides& b_batch_strides,
      float alpha);

  void run_batched(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const array& c,
      const Shape& batch_shape,
      const Strides& a_batch_strides,
      const Strides& b_batch_strides,
      const Strides& c_batch_strides,
      float alpha,
      float beta);

  // Low-level launch taking raw device pointers.
  void execute(
      cu::CommandEncoder& encoder,
      void* out,
      const void* a,
      const void* b,
      const void* c,
      float alpha = 1,
      float beta = 0);
};
113
+
114
+ } // namespace mlx::core
@@ -0,0 +1,24 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/cuda/device.h"
6
+
7
+ namespace mlx::core::cu {
8
+
9
// Whether the custom gemv path supports this problem: shape (M, N, K) and
// the transpose state of each operand.
bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed);

// Matrix-vector multiply `out = a @ b`, optionally batched; work is
// recorded on |encoder|'s stream.
void gemv(
    const array& a,
    const array& b,
    array& out,
    int M,
    int N,
    int K,
    uint32_t batch_count,
    const mlx::core::Shape& batch_shape,
    const mlx::core::Strides& a_batch_strides,
    const mlx::core::Strides& b_batch_strides,
    CommandEncoder& encoder);
23
+
24
+ } // namespace mlx::core::cu
@@ -0,0 +1,119 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+ #include "mlx/backend/common/utils.h"
7
+ #include "mlx/backend/cuda/device.h"
8
+ #include "mlx/backend/cuda/device/config.h"
9
+
10
+ #include <deque>
11
+ #include <unordered_map>
12
+ #include <utility>
13
+ #include <variant>
14
+
15
+ #include <cuda.h>
16
+ #include <fmt/format.h>
17
+
18
+ namespace mlx::core::cu {
19
+
20
+ class Device;
21
+
22
// Result of a KernelBuilder: whether the source is already precompiled,
// the source code itself, and the names of the kernels it defines.
using KernelBuilderResult = std::tuple<
    /* precompiled */ bool,
    /* source code */ std::string,
    /* kernel names */ std::vector<std::string>>;
// Callback that lazily produces the source for a JIT module.
using KernelBuilder = std::function<KernelBuilderResult()>;
27
+
28
// Accumulates kernel arguments in the form the CUDA driver API expects:
// an array of void*, each entry pointing at an argument's value. The
// values themselves are kept alive in |storage_| until the kernel node is
// created.
struct KernelArgs {
  // Pointer array to hand to cuLaunchKernel / cuGraphAddKernelNode.
  void** args() {
    return args_.data();
  }

  // Append an mlx array argument by its device pointer.
  void append(const array& a) {
    append(reinterpret_cast<CUdeviceptr>(gpu_ptr<void>(a)));
  }

  // Append a scalar argument, storing a copy whose address is recorded.
  template <typename T>
  void append(T val) {
    storage_.emplace_back(val);
    // NOTE(review): this records the address of the variant object itself,
    // not of the contained T; it relies on the active alternative living at
    // offset 0 of std::variant's storage. That holds on common
    // implementations but is not guaranteed by the standard — confirm.
    append_ptr(&storage_.back());
  }

  // Append a vector argument; the recorded pointer is the vector's data().
  template <typename T>
  void append(SmallVector<T> vec) {
    storage_.emplace_back(std::move(vec));
    append_ptr(std::get<SmallVector<T>>(storage_.back()).data());
  }

  template <typename T>
  void append(const std::vector<T>& vec) {
    append(SmallVector<T>(vec.begin(), vec.end()));
  }

  // Make sure the arg is copied to an array with size of NDIM.
  template <size_t NDIM = MAX_NDIM, typename T>
  void append_ndim(SmallVector<T> vec) {
    if (vec.size() > NDIM) {
      throw std::runtime_error(
          fmt::format("ndim can not be larger than {}.", NDIM));
    }
    vec.resize(NDIM);
    append(std::move(vec));
  }

  // Record a raw argument address. The caller is responsible for keeping
  // the pointee alive until launch.
  void append_ptr(const void* v) {
    args_.push_back(const_cast<void*>(v));
  }

 private:
  std::vector<void*> args_;

  // The cuGraphAddKernelNode API requires passing pointers to arguments so
  // store temporary values until the node is created.
  // (std::deque: growth never relocates existing elements, so the recorded
  // addresses stay valid.)
  using Arg = std::variant<
      std::monostate,
      CUdeviceptr,
      bool,
      int32_t,
      uint32_t,
      int64_t,
      float,
      SmallVector<const void*>,
      SmallVector<int32_t>,
      SmallVector<int64_t>>;
  std::deque<Arg> storage_;
};
87
+
88
// A CUDA module produced at runtime from the source returned by |builder|.
class JitModule {
 public:
  JitModule(
      Device& device,
      const std::string& module_name,
      const KernelBuilder& builder,
      bool cache);
  ~JitModule();

  // Non-copyable: owns the CUmodule handle.
  JitModule(const JitModule&) = delete;
  JitModule& operator=(const JitModule&) = delete;
  // Look up a kernel by name. |configure_kernel|, when provided, is given
  // the CUfunction — presumably to set function attributes before first
  // launch (TODO confirm against the implementation).
  CUfunction get_kernel(
      const std::string& kernel_name,
      std::function<void(CUfunction)> configure_kernel = nullptr);
  // Same lookup, additionally returning the uint cached alongside the
  // kernel in |kernels_| (meaning not visible from this header — confirm).
  std::pair<CUfunction, uint> get_kernel_and_dims(
      const std::string& kernel_name,
      std::function<void(CUfunction)> configure_kernel = nullptr);

 private:
  CUmodule module_{nullptr};
  // kernel name -> (function handle, flag, dims); the flag/dims semantics
  // are defined in the .cpp.
  std::unordered_map<std::string, std::tuple<CUfunction, bool, uint>> kernels_;
};
110
+
111
// Process-wide cache of JIT modules, keyed by module name.
std::unordered_map<std::string, JitModule>& get_jit_module_cache();

// Get (building and caching if needed) the JIT module |name| for |device|.
// NOTE(review): use_disk_cache presumably toggles persisting compiled
// binaries on disk — confirm in the implementation.
JitModule& get_jit_module(
    const mlx::core::Device& device,
    const std::string& name,
    const KernelBuilder& builder,
    bool use_disk_cache = true);
118
+
119
+ } // namespace mlx::core::cu
@@ -0,0 +1,189 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/utils.h"
6
+
7
+ #include <cstring>
8
+ #include <list>
9
+ #include <unordered_map>
10
+ #include <utility>
11
+
12
+ #include <fmt/format.h>
13
+
14
+ namespace mlx::core {
15
+
16
+ template <
17
+ typename K,
18
+ typename V,
19
+ template <typename...> typename M = std::unordered_map>
20
+ class LRUCache {
21
+ public:
22
+ using value_type = std::pair<K, V>;
23
+ using list_type = std::list<value_type>;
24
+ using iterator = typename list_type::iterator;
25
+ using const_iterator = typename list_type::const_iterator;
26
+ using map_type = M<K, iterator>;
27
+
28
+ explicit LRUCache(size_t capacity) : capacity_(capacity) {
29
+ if (capacity == 0) {
30
+ throw std::runtime_error("LRUCache requires capacity > 0.");
31
+ }
32
+ }
33
+
34
+ // Initialize with capacity read from |env_name|.
35
+ LRUCache(const char* env_name, int default_capacity)
36
+ : LRUCache(env::get_var(env_name, default_capacity)) {
37
+ if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) {
38
+ env_name_ = env_name;
39
+ }
40
+ }
41
+
42
+ size_t size() const {
43
+ return map_.size();
44
+ }
45
+ size_t capacity() const {
46
+ return capacity_;
47
+ }
48
+ bool empty() const {
49
+ return vlist_.empty();
50
+ }
51
+
52
+ void resize(size_t new_capacity) {
53
+ capacity_ = new_capacity;
54
+ trim();
55
+ }
56
+
57
+ iterator begin() {
58
+ return vlist_.begin();
59
+ }
60
+ const_iterator begin() const {
61
+ return vlist_.begin();
62
+ }
63
+ iterator end() {
64
+ return vlist_.end();
65
+ }
66
+ const_iterator end() const {
67
+ return vlist_.end();
68
+ }
69
+
70
+ void clear() {
71
+ map_.clear();
72
+ vlist_.clear();
73
+ }
74
+
75
+ iterator find(const K& key) {
76
+ auto it = map_.find(key);
77
+ if (it == map_.end())
78
+ return end();
79
+ vlist_.splice(vlist_.begin(), vlist_, it->second);
80
+ return it->second;
81
+ }
82
+
83
+ template <typename U>
84
+ std::pair<iterator, bool> emplace(const K& key, U&& value) {
85
+ auto it = map_.find(key);
86
+ if (it != map_.end()) {
87
+ vlist_.splice(vlist_.begin(), vlist_, it->second);
88
+ return {it->second, false};
89
+ }
90
+
91
+ if (env_name_ && ++cache_misses_ > 2 * capacity_) {
92
+ throw std::runtime_error(fmt::format(
93
+ "Cache thrashing is happening, please set the environment variable "
94
+ "{} to a larger value than {} to fix degraded performance.",
95
+ env_name_,
96
+ capacity_));
97
+ }
98
+
99
+ vlist_.emplace_front(key, std::forward<U>(value));
100
+ map_[key] = vlist_.begin();
101
+
102
+ trim();
103
+
104
+ return {vlist_.begin(), true};
105
+ }
106
+
107
+ iterator erase(iterator pos) {
108
+ map_.erase(pos->first);
109
+ return vlist_.erase(pos);
110
+ }
111
+
112
+ V& operator[](const K& key) {
113
+ auto it = find(key);
114
+ if (it == end()) {
115
+ it = emplace(key, V{}).first;
116
+ }
117
+ return it->second;
118
+ }
119
+
120
+ private:
121
+ void trim() {
122
+ while (map_.size() > capacity_) {
123
+ auto last = std::prev(vlist_.end());
124
+ map_.erase(last->first);
125
+ vlist_.pop_back();
126
+ }
127
+ }
128
+
129
+ const char* env_name_{nullptr};
130
+ size_t cache_misses_{0};
131
+
132
+ list_type vlist_;
133
+ map_type map_;
134
+ size_t capacity_;
135
+ };
136
+
137
// Turn a POD struct into a container key by doing bytes compare.
//
// Usage:
//   BytesKey<MyKey> key;
//   key.pod = { ... };
template <typename T>
struct BytesKey {
  T pod;
  static_assert(std::is_standard_layout_v<T>, "T is not POD");

  BytesKey() {
    // Make sure the paddings between members are filled with 0.
    memset(&pod, 0, sizeof(T));
  }

  // Copy/move every byte — including padding — so that the byte-wise
  // operator== (and any byte-wise hash) sees a fully-defined value.
  BytesKey(const BytesKey& other) {
    memcpy(&pod, &other.pod, sizeof(T));
  }

  BytesKey(BytesKey&& other) {
    memcpy(&pod, &other.pod, sizeof(T));
  }

  // Rule of five: the implicitly-generated assignment operators would copy
  // |pod| member-wise and could leave stale padding bytes in the
  // destination, so assign with memcpy as well.
  BytesKey& operator=(const BytesKey& other) {
    memcpy(&pod, &other.pod, sizeof(T));
    return *this;
  }

  BytesKey& operator=(BytesKey&& other) {
    memcpy(&pod, &other.pod, sizeof(T));
    return *this;
  }

  bool operator==(const BytesKey& other) const {
    auto* ptr1 = reinterpret_cast<const uint8_t*>(&pod);
    auto* ptr2 = reinterpret_cast<const uint8_t*>(&other.pod);
    return memcmp(ptr1, ptr2, sizeof(T)) == 0;
  }
};
166
+
167
// Compute hash according to the bytes value of T (32-bit FNV-1a).
template <typename T>
struct BytesHash {
  static_assert(std::is_standard_layout_v<T>, "T is not POD");

  size_t operator()(const T& pod) const noexcept {
    auto* ptr = reinterpret_cast<const uint8_t*>(&pod);
    // FNV-1a: offset basis 0x811C9DC5, prime 0x01000193; uint32_t
    // arithmetic wraps as the algorithm requires.
    uint32_t value = 0x811C9DC5;
    // size_t index avoids the signed/unsigned comparison against sizeof(T)
    // the original `int i` produced.
    for (size_t i = 0; i < sizeof(T); ++i) {
      value ^= ptr[i];
      value *= 0x01000193;
    }
    return value;
  }
};
182
+
183
// Hash map whose keys are compared and hashed by their raw bytes.
template <typename K, typename V>
using BytesKeyHashMap = std::unordered_map<K, V, BytesHash<K>>;

// LRU cache keyed by the raw bytes of a POD struct K.
template <typename K, typename V>
using LRUBytesKeyCache = LRUCache<BytesKey<K>, V, BytesKeyHashMap>;
188
+
189
+ } // namespace mlx::core
@@ -0,0 +1,88 @@
1
+ // Copyright © 2025 Apple Inc.
2
+ #pragma once
3
+
4
+ #include "mlx/array.h"
5
+ #include "mlx/backend/cuda/cublas_utils.h"
6
+ #include "mlx/backend/cuda/device.h"
7
+
8
+ #include <cublasLt.h>
9
+
10
+ namespace mlx::core {
11
+
12
// Plans and launches quantized-times-quantized matmuls through cublasLt,
// where both operands carry per-block scale tensors.
class CublasQQMM : public CublasMatmulBase {
 public:
  // Plan `out = A @ B` for quantized operands. |quantization_mode| selects
  // the scale layout/modes stored in the *_scale_mode_ members below.
  CublasQQMM(
      cu::Device& device,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride,
      Dtype out_dtype,
      std::string quantization_mode);

  // Same as above, plus the layout of a C operand (ldc / c_batch_stride).
  CublasQQMM(
      cu::Device& device,
      bool a_transposed,
      uint64_t a_rows,
      uint64_t a_cols,
      int64_t lda,
      bool b_transposed,
      uint64_t b_rows,
      uint64_t b_cols,
      int64_t ldb,
      int64_t ldc,
      int32_t batch_count,
      int64_t a_batch_stride,
      int64_t b_batch_stride,
      int64_t c_batch_stride,
      Dtype out_dtype,
      std::string quantization_mode);

  // Launch the planned matmul with the operands' scale tensors.
  void run(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const array& a_scale,
      const array& b_scale,
      float alpha = 1.0f);

 private:
  void run_batched(
      cu::CommandEncoder& encoder,
      array& out,
      const array& a,
      const array& b,
      const array& a_scale,
      const array& b_scale,
      const Shape& batch_shape,
      const Strides& a_batch_strides,
      const Strides& b_batch_strides,
      float alpha);

  // Low-level launch taking raw device pointers.
  void execute(
      cu::CommandEncoder& encoder,
      void* out,
      const void* a,
      const void* b,
      const void* a_scale,
      const void* b_scale,
      const void* c,
      float alpha = 1,
      float beta = 0);

  std::string quantization_mode_;
  // cublasLt scale modes for each operand/output, derived from
  // |quantization_mode_| (see the .cpp).
  cublasLtMatmulMatrixScale_t a_scale_mode_;
  cublasLtMatmulMatrixScale_t b_scale_mode_;
  cublasLtMatmulMatrixScale_t c_scale_mode_;
  cublasLtMatmulMatrixScale_t out_scale_mode_;
};
87
+
88
+ } // namespace mlx::core
@@ -0,0 +1,83 @@
1
+ #pragma once
2
+
3
// Stand-in for CUDA's __nv_fp8_e8m0: an 8-bit, exponent-only (power-of-two)
// float encoding. Byte 0xFF decodes to NaN; any other byte e decodes to
// 2^(e - 127).
struct __nv_fp8_e8m0 {
  __device__ __nv_fp8_e8m0(float x) {
    // NaN and +/-inf have no finite exponent; encode as NaN (0xFF).
    if (!std::isfinite(x)) {
      __x = 0xFF;
      return;
    }
    // The format is unsigned: negative inputs map to the smallest
    // encodable value, 2^-127 (byte 0x00).
    if (x < 0.0f) {
      __x = 0x00;
      return;
    }
    // Round log2(x) to the nearest integer exponent.
    // NOTE(review): x == 0 makes log2f return -inf, and converting -inf to
    // int is undefined behavior in ISO C++ — presumably the device compiler
    // saturates so the clamp below yields -127. Confirm.
    float le = std::log2f(x);
    int n = static_cast<int>(std::nearbyintf(le));

    // Clamp the exponent to the representable range [-127, 127].
    n = n < -127 ? -127 : n;
    n = n > 127 ? 127 : n;
    __x = static_cast<uint8_t>(n + 127);
  }

  // Decode: 0xFF -> NaN, otherwise 2^(__x - 127).
  __device__ operator float() {
    if (__x == 0xFF) {
      return std::numeric_limits<float>::quiet_NaN();
    }
    return std::ldexp(1.0f, static_cast<int>(__x) - 127);
  }

  uint8_t __x{0};
};
30
+
31
// Stand-in for CUDA's __nv_fp4_e2m1: a 4-bit float with a sign bit and a
// 3-bit magnitude encoding {0, 0.5, 1, 1.5, 2, 3, 4, 6} (see the LUT in
// operator float).
struct __nv_fp4_e2m1 {
  __device__ __nv_fp4_e2m1(float x) {
    // e2m1 has no NaN encoding; map NaN to the largest magnitude (6.0).
    if (std::isnan(x)) {
      __x = 0x7;
      return;
    }

    const uint8_t sign_bit = (std::signbit(x)) ? 0x8 : 0x0;
    x = std::abs(x);

    // Round to the nearest representable magnitude. The alternation of
    // strict and non-strict comparisons at the midpoints implements
    // ties-to-even (e.g. 2.5 -> 2.0, 3.5 -> 4.0); anything above the
    // 4.0/6.0 midpoint (5.0) saturates to 6.0.
    if (x > 5.0f) {
      __x = 0x7; // 6.0
    } else if (x >= 3.5f) {
      __x = 0x6; // 4.0
    } else if (x > 2.5f) {
      __x = 0x5; // 3.0
    } else if (x >= 1.75f) {
      __x = 0x4; // 2.0
    } else if (x > 1.25f) {
      __x = 0x3; // 1.5
    } else if (x >= 0.75f) {
      __x = 0x2; // 1.0
    } else if (x > 0.25f) {
      __x = 0x1; // 0.5
    } else {
      __x = 0x0; // 0.0
    }
    __x |= sign_bit;
  }

  // Decode via lookup: the low 3 bits select the magnitude, bit 3 the sign.
  __device__ operator float() {
    static const float LUT[16] = {
        0.0f,
        0.5f,
        1.0f,
        1.5f,
        2.0f,
        3.0f,
        4.0f,
        6.0f,
        -0.0f,
        -0.5f,
        -1.0f,
        -1.5f,
        -2.0f,
        -3.0f,
        -4.0f,
        -6.0f};

    return LUT[__x];
  }
  uint8_t __x{0};
};
@@ -0,0 +1,30 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+ #include "mlx/backend/cuda/device.h"
7
+
8
+ namespace mlx::core {
9
+
10
// Compute padded dimensions for the tiled scale layout.
// Tiles are 128 rows × 4 columns; whole tiles must be allocated, so each
// dimension is rounded up to a full multiple of the tile extent.
inline std::pair<int, int> get_padded_scale_dims(int num_rows, int num_cols) {
  constexpr int rows_per_tile = 128;
  constexpr int cols_per_tile = 4;

  // Round a count up to the next multiple of the tile extent.
  auto round_up = [](int count, int tile) {
    return ((count + tile - 1) / tile) * tile;
  };

  return {round_up(num_rows, rows_per_tile), round_up(num_cols, cols_per_tile)};
}
23
+
24
// Repack |scales| into |scales_tiled| using the padded tiled layout
// described by get_padded_scale_dims above; work is recorded on |enc|.
// NOTE(review): the 128x4 tiling requirement is inferred from that helper's
// comment — confirm against the kernel implementation.
void repack_scales(
    const array& scales,
    array& scales_tiled,
    cu::CommandEncoder& enc,
    const Stream& s);
29
+
30
+ } // namespace mlx::core
@@ -0,0 +1,45 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/device.h"
4
+
5
+ namespace mlx::core {
6
+
7
// Quantize |w| into |wq| with per-group |scales| and |biases|
// (affine quantization: group_size_ elements per group, bits_ per value).
void affine_quantize(
    const array& w,
    array& wq,
    array& scales,
    array& biases,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

// Inverse of affine_quantize: reconstruct |w| from |wq|, |scales|, |biases|.
void affine_dequantize(
    const array& wq,
    const array& scales,
    const array& biases,
    array& w,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

// Quantize |w| into |wq| with per-group |scales| only (no bias term;
// floating-point quantization mode).
void fp_quantize(
    const array& w,
    array& wq,
    array& scales,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s);

// Inverse of fp_quantize: reconstruct |w| from |wq| and |scales|.
void fp_dequantize(
    const array& wq,
    const array& scales,
    array& w,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s);
44
+
45
+ } // namespace mlx::core