mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,136 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <optional>
6
+ #include <set>
7
+ #include <unordered_map>
8
+ #include <variant>
9
+ #include "mlx/array.h"
10
+
11
+ namespace mlx::core {
12
+
13
+ using Args = std::vector<array>;
14
+ using Kwargs = std::unordered_map<std::string, array>;
15
+
16
+ // Possible types for a Primitive's state
17
+ using StateT = std::variant<
18
+ bool,
19
+ int,
20
+ size_t,
21
+ float,
22
+ double,
23
+ Dtype,
24
+ Shape,
25
+ Strides,
26
+ std::vector<int>,
27
+ std::vector<size_t>,
28
+ std::vector<std::tuple<bool, bool, bool>>,
29
+ std::vector<std::variant<bool, int, float>>,
30
+ std::optional<float>,
31
+ std::string>;
32
+
33
+ using ExportCallbackInput = std::unordered_map<
34
+ std::string,
35
+ std::variant<
36
+ std::vector<std::tuple<std::string, Shape, Dtype>>,
37
+ std::vector<std::pair<std::string, array>>,
38
+ std::vector<std::pair<std::string, std::string>>,
39
+ std::vector<StateT>,
40
+ std::string>>;
41
+ using ExportCallback = std::function<void(const ExportCallbackInput&)>;
42
+
43
+ struct FunctionExporter;
44
+
45
+ /**
46
+ * Make an exporter to save multiple traces of a given function to
47
+ * the same file.
48
+ */
49
+ FunctionExporter exporter(
50
+ const std::string& file,
51
+ const std::function<std::vector<array>(const Args&)>& fun,
52
+ bool shapeless = false);
53
+
54
+ FunctionExporter exporter(
55
+ const std::string& file,
56
+ const std::function<std::vector<array>(const Kwargs&)>& fun,
57
+ bool shapeless = false);
58
+
59
+ FunctionExporter exporter(
60
+ const std::string& path,
61
+ const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
62
+ bool shapeless = false);
63
+
64
+ /**
65
+ * Export a function to a file.
66
+ */
67
+ void export_function(
68
+ const std::string& file,
69
+ const std::function<std::vector<array>(const Args&)>& fun,
70
+ const Args& args,
71
+ bool shapeless = false);
72
+
73
+ void export_function(
74
+ const std::string& file,
75
+ const std::function<std::vector<array>(const Kwargs&)>& fun,
76
+ const Kwargs& kwargs,
77
+ bool shapeless = false);
78
+
79
+ void export_function(
80
+ const std::string& file,
81
+ const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
82
+ const Args& args,
83
+ const Kwargs& kwargs,
84
+ bool shapeless = false);
85
+
86
+ struct ImportedFunction;
87
+
88
+ /**
89
+ * Import a function from a file.
90
+ */
91
+ ImportedFunction import_function(const std::string& file);
92
+
93
+ /**
94
+ * Make an exporter to export multiple traces of a given function with the same
95
+ * callback.
96
+ */
97
+ FunctionExporter exporter(
98
+ const ExportCallback& callback,
99
+ const std::function<std::vector<array>(const Args&)>& fun,
100
+ bool shapeless = false);
101
+
102
+ FunctionExporter exporter(
103
+ const ExportCallback& callback,
104
+ const std::function<std::vector<array>(const Kwargs&)>& fun,
105
+ bool shapeless = false);
106
+
107
+ FunctionExporter exporter(
108
+ const ExportCallback& callback,
109
+ const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
110
+ bool shapeless = false);
111
+
112
+ /**
113
+ * Export a function with a callback.
114
+ */
115
+ void export_function(
116
+ const ExportCallback& callback,
117
+ const std::function<std::vector<array>(const Args&)>& fun,
118
+ const Args& args,
119
+ bool shapeless = false);
120
+
121
+ void export_function(
122
+ const ExportCallback& callback,
123
+ const std::function<std::vector<array>(const Kwargs&)>& fun,
124
+ const Kwargs& kwargs,
125
+ bool shapeless = false);
126
+
127
+ void export_function(
128
+ const ExportCallback& callback,
129
+ const std::function<std::vector<array>(const Args&, const Kwargs&)>& fun,
130
+ const Args& args,
131
+ const Kwargs& kwargs,
132
+ bool shapeless = false);
133
+
134
+ } // namespace mlx::core
135
+
136
+ #include "mlx/export_impl.h"
@@ -0,0 +1,98 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #include "mlx/io/load.h"
4
+
5
+ #pragma once
6
+
7
+ namespace mlx::core {
8
+
9
+ struct FunctionTable;
10
+
11
+ struct FunctionExporter {
12
+ void operator()(const std::initializer_list<array>& args) {
13
+ this->operator()(Args(args));
14
+ }
15
+ void operator()(const Args& args);
16
+ void operator()(const Kwargs& kwargs);
17
+ void operator()(const Args& args, const Kwargs& kwargs);
18
+
19
+ void close();
20
+
21
+ FunctionExporter(const FunctionExporter&) = delete;
22
+ FunctionExporter& operator=(const FunctionExporter&) = delete;
23
+ FunctionExporter(FunctionExporter&& other) = default;
24
+
25
+ private:
26
+ friend FunctionExporter exporter(
27
+ const std::string&,
28
+ const std::function<std::vector<array>(const Args&)>&,
29
+ bool shapeless);
30
+
31
+ friend FunctionExporter exporter(
32
+ const std::string&,
33
+ const std::function<std::vector<array>(const Kwargs&)>&,
34
+ bool shapeless);
35
+
36
+ friend FunctionExporter exporter(
37
+ const std::string&,
38
+ const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
39
+ bool shapeless);
40
+
41
+ friend FunctionExporter exporter(
42
+ const ExportCallback&,
43
+ const std::function<std::vector<array>(const Args&)>&,
44
+ bool shapeless);
45
+
46
+ friend FunctionExporter exporter(
47
+ const ExportCallback&,
48
+ const std::function<std::vector<array>(const Kwargs&)>&,
49
+ bool shapeless);
50
+
51
+ friend FunctionExporter exporter(
52
+ const ExportCallback&,
53
+ const std::function<std::vector<array>(const Args&, const Kwargs&)>&,
54
+ bool shapeless);
55
+
56
+ FunctionExporter(
57
+ const std::string& file,
58
+ std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
59
+ bool shapeless);
60
+
61
+ FunctionExporter(
62
+ const ExportCallback& callback,
63
+ std::function<std::vector<array>(const Args&, const Kwargs&)> fun,
64
+ bool shapeless);
65
+
66
+ io::FileWriter os;
67
+ ExportCallback callback;
68
+ std::function<std::vector<array>(const Args&, const Kwargs& kwargs)> fun;
69
+ void export_function(const Args& args, const Kwargs& kwargs);
70
+ void export_with_callback(
71
+ const std::vector<array>& inputs,
72
+ const std::vector<array>& outputs,
73
+ const std::vector<array>& tape,
74
+ const std::vector<std::string>& kwarg_keys);
75
+ std::unordered_map<std::uintptr_t, array> constants;
76
+ int count{0};
77
+ bool closed{false};
78
+ std::shared_ptr<FunctionTable> ftable;
79
+ };
80
+
81
+ struct ImportedFunction {
82
+ std::vector<array> operator()(
83
+ const std::initializer_list<array>& args) const {
84
+ return this->operator()(Args(args));
85
+ }
86
+ std::vector<array> operator()(const Args& args) const;
87
+ std::vector<array> operator()(const Kwargs& kwargs) const;
88
+ std::vector<array> operator()(const Args& args, const Kwargs& kwargs) const;
89
+
90
+ private:
91
+ ImportedFunction(const std::string& file);
92
+ friend ImportedFunction import_function(const std::string&);
93
+ ImportedFunction();
94
+
95
+ std::shared_ptr<FunctionTable> ftable;
96
+ };
97
+
98
+ } // namespace mlx::core
mlx/include/mlx/fast.h ADDED
@@ -0,0 +1,102 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <optional>
6
+ #include <variant>
7
+
8
+ #include "mlx/utils.h"
9
+
10
+ namespace mlx::core::fast {
11
+
12
+ array rms_norm(
13
+ const array& x,
14
+ const std::optional<array>& weight,
15
+ float eps,
16
+ StreamOrDevice s = {});
17
+
18
+ array layer_norm(
19
+ const array& x,
20
+ const std::optional<array>& weight,
21
+ const std::optional<array>& bias,
22
+ float eps,
23
+ StreamOrDevice s = {});
24
+
25
+ array rope(
26
+ const array& x,
27
+ int dims,
28
+ bool traditional,
29
+ std::optional<float> base,
30
+ float scale,
31
+ int offset,
32
+ const std::optional<array>& freqs = std::nullopt,
33
+ StreamOrDevice s = {});
34
+
35
+ array rope(
36
+ const array& x,
37
+ int dims,
38
+ bool traditional,
39
+ std::optional<float> base,
40
+ float scale,
41
+ const array& offset,
42
+ const std::optional<array>& freqs = std::nullopt,
43
+ StreamOrDevice s = {});
44
+
45
+ /** Computes: O = softmax(Q @ K.T) @ V **/
46
+ array scaled_dot_product_attention(
47
+ const array& queries,
48
+ const array& keys,
49
+ const array& values,
50
+ const float scale,
51
+ const std::string& mask_mode = "",
52
+ std::optional<array> mask_arr = {},
53
+ const std::optional<array>& sinks = {},
54
+ StreamOrDevice s = {});
55
+
56
+ using TemplateArg = std::variant<int, bool, Dtype>;
57
+ using ScalarArg = std::variant<bool, int, float>;
58
+
59
+ using CustomKernelFunction = std::function<std::vector<array>(
60
+ const std::vector<array>&,
61
+ const std::vector<Shape>&,
62
+ const std::vector<Dtype>&,
63
+ std::tuple<int, int, int>,
64
+ std::tuple<int, int, int>,
65
+ std::vector<std::pair<std::string, TemplateArg>>,
66
+ std::optional<float>,
67
+ bool,
68
+ StreamOrDevice)>;
69
+
70
+ CustomKernelFunction metal_kernel(
71
+ const std::string& name,
72
+ const std::vector<std::string>& input_names,
73
+ const std::vector<std::string>& output_names,
74
+ const std::string& source,
75
+ const std::string& header = "",
76
+ bool ensure_row_contiguous = true,
77
+ bool atomic_outputs = false);
78
+
79
+ CustomKernelFunction cuda_kernel(
80
+ const std::string& name,
81
+ const std::vector<std::string>& input_names,
82
+ const std::vector<std::string>& output_names,
83
+ const std::string& source,
84
+ const std::string& header = "",
85
+ bool ensure_row_contiguous = true,
86
+ int shared_memory = 0);
87
+
88
+ std::vector<array> precompiled_cuda_kernel(
89
+ const std::string& name,
90
+ const std::string& compiled_source,
91
+ const std::vector<array>& inputs,
92
+ const std::vector<Shape>& output_shapes,
93
+ const std::vector<Dtype>& output_dtypes,
94
+ const std::vector<ScalarArg>& scalars,
95
+ std::tuple<int, int, int> grid,
96
+ std::tuple<int, int, int> threadgroup,
97
+ int shared_memory = 0,
98
+ std::optional<float> init_value = std::nullopt,
99
+ bool ensure_row_contiguous = false,
100
+ StreamOrDevice s = {});
101
+
102
+ } // namespace mlx::core::fast