mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,229 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <optional>
6
+
7
+ #include "mlx/array.h"
8
+
9
+ namespace mlx::core {
10
+
11
+ void async_eval(std::vector<array> outputs);
12
+
13
+ template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
14
+ void async_eval(Arrays&&... outputs) {
15
+ async_eval(std::vector<array>{std::forward<Arrays>(outputs)...});
16
+ }
17
+
18
+ void eval(std::vector<array> outputs);
19
+
20
+ template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
21
+ void eval(Arrays&&... outputs) {
22
+ eval(std::vector<array>{std::forward<Arrays>(outputs)...});
23
+ }
24
+
25
+ /**
26
+ * Computes the output and vector-Jacobian product (VJP) of a function.
27
+ *
28
+ * Computes the vector-Jacobian product of the vector of cotangents with the
29
+ * Jacobian of the function evaluated at the primals. Returns a pair of
30
+ * vectors of output arrays and VJP arrays.
31
+ **/
32
+ std::pair<std::vector<array>, std::vector<array>> vjp(
33
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
34
+ const std::vector<array>& primals,
35
+ const std::vector<array>& cotangents);
36
+
37
+ /**
38
+ * Computes the output and vector-Jacobian product (VJP) of a unary function.
39
+ */
40
+ std::pair<array, array> vjp(
41
+ const std::function<array(const array&)>& fun,
42
+ const array& primal,
43
+ const array& cotangent);
44
+
45
+ /**
46
+ * Computes the output and Jacobian-vector product (JVP) of a function.
47
+ *
48
+ * Computes the Jacobian-vector product of the Jacobian of the function
49
+ * evaluated at the primals with the vector of tangents. Returns a pair of
50
+ * vectors of output arrays and JVP arrays.
51
+ **/
52
+ std::pair<std::vector<array>, std::vector<array>> jvp(
53
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
54
+ const std::vector<array>& primals,
55
+ const std::vector<array>& tangents);
56
+
57
+ /**
58
+ * Computes the output and Jacobian-vector product (JVP) of a unary function.
59
+ */
60
+ std::pair<array, array> jvp(
61
+ const std::function<array(const array&)>& fun,
62
+ const array& primal,
63
+ const array& tangent);
64
+
65
+ // Return type of general value_and_grad: a function which takes an input
66
+ // vector of arrays and returns a pair of vectors of arrays one for the
67
+ // values and one for the gradients wrt the first value.
68
+ using ValueAndGradFn =
69
+ std::function<std::pair<std::vector<array>, std::vector<array>>(
70
+ const std::vector<array>&)>;
71
+ using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
72
+ const std::vector<array>&)>;
73
+
74
+ /**
75
+ * Returns a function which computes the value and gradient of the input
76
+ * function with respect to a vector of input arrays.
77
+ **/
78
+ ValueAndGradFn value_and_grad(
79
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
80
+ const std::vector<int>& argnums);
81
+
82
+ /**
83
+ * Returns a function which computes the value and gradient of the input
84
+ * function with respect to a single input array.
85
+ **/
86
+ ValueAndGradFn inline value_and_grad(
87
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
88
+ int argnum = 0) {
89
+ return value_and_grad(fun, std::vector<int>{argnum});
90
+ }
91
+
92
+ /**
93
+ * Returns a function which computes the value and gradient of the unary
94
+ * input function.
95
+ **/
96
+ std::function<std::pair<array, array>(const array&)> inline value_and_grad(
97
+ const std::function<array(const array&)>& fun) {
98
+ return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); };
99
+ }
100
+
101
+ SimpleValueAndGradFn inline value_and_grad(
102
+ const std::function<array(const std::vector<array>&)>& fun,
103
+ const std::vector<int>& argnums) {
104
+ return [fun, argnums](auto inputs) {
105
+ auto result = value_and_grad(
106
+ [fun](auto inputs) { return std::vector<array>{fun(inputs)}; },
107
+ argnums)(inputs);
108
+
109
+ return std::make_pair(result.first[0], result.second);
110
+ };
111
+ }
112
+
113
+ SimpleValueAndGradFn inline value_and_grad(
114
+ const std::function<array(const std::vector<array>&)>& fun,
115
+ int argnum = 0) {
116
+ return value_and_grad(fun, std::vector<int>{argnum});
117
+ }
118
+
119
+ /**
120
+ * Returns a function which computes the gradient of the input function with
121
+ * respect to a vector of input arrays.
122
+ *
123
+ * The function being differentiated takes a vector of arrays and returns an
124
+ * array. The vector of `argnums` specifies which the arguments to compute
125
+ * the gradient with respect to. At least one argument must be specified.
126
+ **/
127
+ std::function<std::vector<array>(const std::vector<array>&)> inline grad(
128
+ const std::function<array(const std::vector<array>&)>& fun,
129
+ const std::vector<int>& argnums) {
130
+ auto fn = value_and_grad(fun, argnums);
131
+ return [fn](const std::vector<array>& inputs) { return fn(inputs).second; };
132
+ }
133
+
134
+ /**
135
+ * Returns a function which computes the gradient of the input function with
136
+ * respect to a single input array.
137
+ *
138
+ * The function being differentiated takes a vector of arrays and returns an
139
+ * array. The optional `argnum` index specifies which the argument to compute
140
+ * the gradient with respect to and defaults to 0.
141
+ **/
142
+ std::function<std::vector<array>(const std::vector<array>&)> inline grad(
143
+ const std::function<array(const std::vector<array>&)>& fun,
144
+ int argnum = 0) {
145
+ return grad(fun, std::vector<int>{argnum});
146
+ }
147
+
148
+ /**
149
+ * Returns a function which computes the gradient of the unary input function.
150
+ **/
151
+ std::function<array(const array&)> inline grad(
152
+ const std::function<array(const array&)>& fun) {
153
+ auto fn = value_and_grad(fun);
154
+ return [fn](const array& input) { return fn(input).second; };
155
+ }
156
+
157
+ /**
158
+ * Automatically vectorize a unary function over the requested axes.
159
+ */
160
+ std::function<array(const array&)> vmap(
161
+ const std::function<array(const array&)>& fun,
162
+ int in_axis = 0,
163
+ int out_axis = 0);
164
+
165
+ /**
166
+ * Automatically vectorize a binary function over the requested axes.
167
+ */
168
+ std::function<array(const array&, const array&)> vmap(
169
+ const std::function<array(const array&, const array&)>& fun,
170
+ int in_axis_a = 0,
171
+ int in_axis_b = 0,
172
+ int out_axis = 0);
173
+
174
+ /**
175
+ * Automatically vectorize a function over the requested axes.
176
+ *
177
+ * The input function to `vmap` takes as an argument a vector of arrays and
178
+ * returns a vector of arrays. Optionally specify the axes to vectorize over
179
+ * with `in_axes` and `out_axes`, otherwise a default of 0 is used.
180
+ * Returns a vectorized function with the same signature as the input
181
+ * function.
182
+ */
183
+ std::function<std::vector<array>(const std::vector<array>&)> vmap(
184
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
185
+ const std::vector<int>& in_axes = {},
186
+ const std::vector<int>& out_axes = {});
187
+
188
+ /**
189
+ * Redefine the transformations of `fun` according to the provided functions.
190
+ *
191
+ * Namely when calling the vjp of `fun` then `fun_vjp` will be called,
192
+ * `fun_jvp` for the jvp and `fun_vmap` for vmap.
193
+ *
194
+ * If any transformation is not provided, then a default one is created by
195
+ * calling `vjp`, `jvp` and `vmap` on the function directly.
196
+ */
197
+ std::function<std::vector<array>(const std::vector<array>&)> custom_function(
198
+ std::function<std::vector<array>(const std::vector<array>&)> fun,
199
+ std::optional<std::function<std::vector<array>(
200
+ const std::vector<array>&,
201
+ const std::vector<array>&,
202
+ const std::vector<array>&)>> fun_vjp = std::nullopt,
203
+ std::optional<std::function<std::vector<array>(
204
+ const std::vector<array>&,
205
+ const std::vector<array>&,
206
+ const std::vector<int>&)>> fun_jvp = std::nullopt,
207
+ std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
208
+ const std::vector<array>&,
209
+ const std::vector<int>&)>> fun_vmap = std::nullopt);
210
+
211
+ /**
212
+ * Return a function that behaves exactly like `fun` but if the vjp of the
213
+ * results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` .
214
+ */
215
+ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
216
+ std::function<std::vector<array>(const std::vector<array>&)> fun,
217
+ std::function<std::vector<array>(
218
+ const std::vector<array>&,
219
+ const std::vector<array>&,
220
+ const std::vector<array>&)> fun_vjp);
221
+
222
+ /**
223
+ * Checkpoint the gradient of a function. Namely, discard all intermediate
224
+ * state and recalculate it when we need to compute the gradient.
225
+ */
226
+ std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
227
+ std::function<std::vector<array>(const std::vector<array>&)> fun);
228
+
229
+ } // namespace mlx::core
@@ -0,0 +1,86 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ namespace mlx::core::detail {
6
+
7
+ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
8
+ const std::function<std::vector<array>(const std::vector<array>&)>& fun,
9
+ const std::vector<array>& inputs,
10
+ const std::vector<int>& in_axes);
11
+
12
+ std::vector<array> vmap_replace(
13
+ const std::vector<array>& inputs,
14
+ const std::vector<array>& s_inputs,
15
+ const std::vector<array>& s_outputs,
16
+ const std::vector<int>& in_axes,
17
+ const std::vector<int>& out_axes);
18
+
19
// Create an InTracing object during tracing operations to signal to the rest
// of the codebase that tracing is in progress, so evals should not throw away
// the graph.
//
// The constructor pushes a (dynamic, grad) entry onto a static trace stack and
// bumps `grad_counter` when `grad` is set; the destructor undoes both using
// the top entry. The type is therefore non-copyable: an implicit copy would
// not push anything, yet its destructor would pop the stack (and adjust the
// counter) a second time.
struct InTracing {
  explicit InTracing(bool dynamic = false, bool grad = false) {
    grad_counter += grad;
    trace_stack().push_back({dynamic, grad});
  }
  ~InTracing() {
    grad_counter -= trace_stack().back().second;
    trace_stack().pop_back();
  }

  InTracing(const InTracing&) = delete;
  InTracing& operator=(const InTracing&) = delete;

  // True while at least one InTracing object is alive.
  static bool in_tracing() {
    return !trace_stack().empty();
  }
  static bool in_dynamic_tracing() {
    // compile is always and only the outer-most transform
    return in_tracing() && trace_stack().front().first;
  }

  // True while at least one live InTracing was created with grad == true.
  static bool in_grad_tracing() {
    return grad_counter > 0;
  }

 private:
  static int grad_counter;
  static std::vector<std::pair<char, char>>& trace_stack();
};
48
+
49
// Scope guard over a static counter: construction increments it, destruction
// decrements it, and retain_graph() reports whether any guard is alive.
// Non-copyable: an implicit copy would not increment the counter, yet its
// destructor would still decrement it, leaving the count unbalanced.
struct RetainGraph {
  RetainGraph() {
    tracing_counter++;
  }
  ~RetainGraph() {
    tracing_counter--;
  }

  RetainGraph(const RetainGraph&) = delete;
  RetainGraph& operator=(const RetainGraph&) = delete;

  static bool retain_graph() {
    return tracing_counter > 0;
  }

 private:
  static int tracing_counter;
};
64
+
65
+ /** Return true if we are currently performing a function transformation in
66
+ * order to keep the graph when evaluating tracer arrays. */
67
+ inline bool in_tracing() {
68
+ return detail::InTracing::in_tracing();
69
+ }
70
+
71
+ /** Return true if we are in a dynamic (shapeless) trace used for compiling or
72
+ * exporting graphs with dynamic shapes. */
73
+ inline bool in_dynamic_tracing() {
74
+ return detail::InTracing::in_dynamic_tracing();
75
+ }
76
+
77
+ /** Return true if we are in a gradient trace (vjp, jvp, etc). */
78
+ inline bool in_grad_tracing() {
79
+ return detail::InTracing::in_grad_tracing();
80
+ }
81
+
82
+ inline bool retain_graph() {
83
+ return detail::RetainGraph::retain_graph();
84
+ }
85
+
86
+ } // namespace mlx::core::detail
@@ -0,0 +1,187 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <algorithm>
6
+ #include <cmath>
7
+ #include <cstdint>
8
+ #include <vector>
9
+
10
// Bit pattern stored for any NaN when converting float -> bfloat16:
// sign 0, all-ones exponent, top mantissa bit set (a quiet NaN).
#define __MLX_BFLOAT_NAN__ 0x7fc0
11
+
12
+ namespace mlx::core {
13
+
14
namespace {
// Lets _MLX_BFloat16 reinterpret a float's bits as a uint32_t and back.
// NOTE(review): an unnamed namespace in a header gives every including
// translation unit its own distinct copy of this type.
union float_bits_bf16 {
  float f; // the value as a float
  uint32_t u; // the same 32 bits as an unsigned integer
};
} // namespace
20
+
21
+ struct _MLX_BFloat16 {
22
+ uint16_t bits_;
23
+
24
+ // Default constructor
25
+ _MLX_BFloat16() = default;
26
+
27
+ // Default copy constructor
28
+ _MLX_BFloat16(_MLX_BFloat16 const&) = default;
29
+
30
+ // Appease std::vector<bool> for being special
31
+ _MLX_BFloat16& operator=(std::vector<bool>::reference x) {
32
+ bits_ = x;
33
+ return *this;
34
+ }
35
+
36
+ _MLX_BFloat16& operator=(const float& x) {
37
+ return (*this = _MLX_BFloat16(x));
38
+ }
39
+
40
+ // From float32
41
+ _MLX_BFloat16(const float& x) {
42
+ if (std::isnan(x)) {
43
+ bits_ = __MLX_BFLOAT_NAN__;
44
+ } else {
45
+ // Union
46
+ float_bits_bf16 in;
47
+
48
+ // Take bits
49
+ in.f = x;
50
+
51
+ // Round to nearest even
52
+ in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF);
53
+
54
+ // Take upper 16 bits
55
+ bits_ = in.u >> 16;
56
+ }
57
+ }
58
+
59
+ // To float32
60
+ operator float() const {
61
+ // Union
62
+ float_bits_bf16 out;
63
+
64
+ // Upper 16 bits are the data and lower 16 bits are 0s
65
+ out.u = ((uint32_t)bits_) << 16;
66
+
67
+ return out.f;
68
+ }
69
+ };
70
+
71
// Defines `OUT_T OPNAME(LHS_T, RHS_T)`, which casts both operands to CALC_T
// and applies OP.
#define bfloat_binop_base(OP, OPNAME, OUT_T, LHS_T, RHS_T, CALC_T) \
  inline OUT_T OPNAME(LHS_T lhs, RHS_T rhs) {                     \
    return static_cast<CALC_T>(lhs) OP static_cast<CALC_T>(rhs);  \
  }

// Defines the two mixed-operand overloads `OPNAME(_MLX_BFloat16, ITYPE)` and
// `OPNAME(ITYPE, _MLX_BFloat16)`; both compute in CALC_T and return OUT_T.
#define bfloat_binop_helper(OP, OPNAME, OUT_T, ITYPE, CALC_T)    \
  inline OUT_T OPNAME(_MLX_BFloat16 lhs, ITYPE rhs) {            \
    return static_cast<CALC_T>(lhs) OP static_cast<CALC_T>(rhs); \
  }                                                              \
  inline OUT_T OPNAME(ITYPE lhs, _MLX_BFloat16 rhs) {            \
    return static_cast<CALC_T>(lhs) OP static_cast<CALC_T>(rhs); \
  }
83
+
84
+ // Operators
85
+ #define bfloat_binop(_op_, _operator_) \
86
+ bfloat_binop_base( \
87
+ _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
88
+ bfloat_binop_helper(_op_, _operator_, float, float, float); \
89
+ bfloat_binop_helper(_op_, _operator_, double, double, double); \
90
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \
91
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
92
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
93
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
94
+ bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
95
+
96
+ bfloat_binop(+, operator+);
97
+ bfloat_binop(-, operator-);
98
+ bfloat_binop(*, operator*);
99
+ bfloat_binop(/, operator/);
100
+
101
+ #undef bfloat_binop
102
+
103
+ // Comparison ops
104
+ #define bfloat_compop(__op__, __operator__) \
105
+ bfloat_binop_base( \
106
+ __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
107
+ bfloat_binop_helper(__op__, __operator__, bool, float, float); \
108
+ bfloat_binop_helper(__op__, __operator__, bool, double, double); \
109
+ bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
110
+ bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
111
+ bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
112
+ bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
113
+
114
+ bfloat_compop(>, operator>);
115
+ bfloat_compop(<, operator<);
116
+ bfloat_compop(>=, operator>=);
117
+ bfloat_compop(<=, operator<=);
118
+ bfloat_compop(==, operator==);
119
+ bfloat_compop(!=, operator!=);
120
+
121
+ #undef bfloat_compop
122
+
123
+ // Negative
124
+ inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) {
125
+ return -static_cast<float>(lhs);
126
+ }
127
+
128
+ // Inplace ops
129
+ #define bfloat_inplace_op(__op__, __operator__) \
130
+ inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \
131
+ lhs = lhs __op__ rhs; \
132
+ return lhs; \
133
+ } \
134
+ inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \
135
+ lhs = lhs __op__ rhs; \
136
+ return lhs; \
137
+ }
138
+
139
+ bfloat_inplace_op(+, operator+=);
140
+ bfloat_inplace_op(-, operator-=);
141
+ bfloat_inplace_op(*, operator*=);
142
+ bfloat_inplace_op(/, operator/=);
143
+
144
+ #undef bfloat_inplace_op
145
+
146
+ // Bitwise ops
147
+
148
+ #define bfloat_bitop(__op__, __operator__) \
149
+ inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \
150
+ _MLX_BFloat16 out; \
151
+ out.bits_ = lhs.bits_ __op__ rhs.bits_; \
152
+ return out; \
153
+ } \
154
+ inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \
155
+ _MLX_BFloat16 out; \
156
+ out.bits_ = lhs.bits_ __op__ rhs; \
157
+ return out; \
158
+ } \
159
+ inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \
160
+ _MLX_BFloat16 out; \
161
+ out.bits_ = lhs __op__ rhs.bits_; \
162
+ return out; \
163
+ }
164
+
165
+ bfloat_bitop(|, operator|);
166
+ bfloat_bitop(&, operator&);
167
+ bfloat_bitop(^, operator^);
168
+
169
+ #undef bfloat_bitop
170
+
171
+ #define bfloat_inplace_bitop(__op__, __operator__) \
172
+ inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
173
+ lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \
174
+ return lhs; \
175
+ } \
176
+ inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \
177
+ lhs.bits_ = lhs.bits_ __op__ rhs; \
178
+ return lhs; \
179
+ }
180
+
181
+ bfloat_inplace_bitop(|, operator|=);
182
+ bfloat_inplace_bitop(&, operator&=);
183
+ bfloat_inplace_bitop(^, operator^=);
184
+
185
+ #undef bfloat_inplace_bitop
186
+
187
+ } // namespace mlx::core
@@ -0,0 +1,113 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
#include <cmath>
#include <complex>

#include "mlx/types/half_types.h"
6
+
7
+ namespace mlx::core {
8
+
9
struct complex64_t;
struct complex128_t;

// True for any type implicitly convertible to double, except complex128_t
// itself (presumably so the converting constructor never competes with
// copying).
template <typename T>
inline constexpr bool can_convert_to_complex128 =
    std::is_convertible_v<T, double> && !std::is_same_v<T, complex128_t>;

// std::complex<double> with extra conversions. It can be built from any
// double-convertible scalar, which becomes the real part with a zero
// imaginary part. It converts to float by keeping only the real part.
struct complex128_t : public std::complex<double> {
  complex128_t() : std::complex<double>() {}
  complex128_t(double re, double im) : std::complex<double>(re, im) {}
  complex128_t(std::complex<double> z) : std::complex<double>(z) {}

  template <
      typename T,
      typename = std::enable_if_t<can_convert_to_complex128<T>>>
  complex128_t(T scalar) : std::complex<double>(scalar) {}

  // Narrowing conversion: real part only, converted to float.
  operator float() const {
    return static_cast<float>(real());
  }
};
30
+
31
// Redeclared so this group does not depend on an earlier forward declaration.
struct complex64_t;

// True for any type implicitly convertible to float, except complex64_t
// itself.
template <typename T>
inline constexpr bool can_convert_to_complex64 =
    !std::is_same_v<T, complex64_t> && std::is_convertible_v<T, float>;

// std::complex<float> with extra conversions. It can be built from any
// float-convertible scalar, which becomes the real part with a zero imaginary
// part. It converts to float by keeping only the real part.
struct complex64_t : public std::complex<float> {
  complex64_t() : std::complex<float>() {}
  complex64_t(float v, float u) : std::complex<float>(v, u) {}
  complex64_t(std::complex<float> v) : std::complex<float>(v) {}

  template <
      typename T,
      typename = typename std::enable_if<can_convert_to_complex64<T>>::type>
  complex64_t(T x) : std::complex<float>(x) {}

  operator float() const {
    return real();
  }
};

// Lexicographic ordering: real parts first, imaginary parts break ties.
inline bool operator>=(const complex64_t& a, const complex64_t& b) {
  return (a.real() > b.real()) ||
      (a.real() == b.real() && a.imag() >= b.imag());
}

inline bool operator>(const complex64_t& a, const complex64_t& b) {
  return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag());
}

// Component-wise remainder. The real and imaginary parts are reduced
// independently, and a nonzero result takes the sign of the divisor component.
//
// std::fmod computes the truncated remainder exactly. The previous
// `a - b * int64_t(a / b)` form was undefined behavior whenever the quotient
// was NaN, infinite or outside int64_t. That includes every zero imaginary
// divisor. It also lost precision for large quotients (1e10 % 3 gave 0).
// A zero divisor component leaves the dividend component unchanged, so purely
// real operands (both imaginary parts 0) keep a 0 imaginary part.
inline complex64_t operator%(complex64_t a, complex64_t b) {
  auto component_rem = [](float x, float y) {
    if (y == 0) {
      return x;
    }
    float r = std::fmod(x, y);
    if (r != 0 && ((r < 0) != (y < 0))) {
      r += y;
    }
    return r;
  };
  return {component_rem(a.real(), b.real()), component_rem(a.imag(), b.imag())};
}

inline bool operator<=(const complex64_t& a, const complex64_t& b) {
  return operator>=(b, a);
}

inline bool operator<(const complex64_t& a, const complex64_t& b) {
  return operator>(b, a);
}

inline complex64_t operator-(const complex64_t& v) {
  return -static_cast<std::complex<float>>(v);
}
80
+
81
+ // clang-format off
82
+ #define complex_binop_helper(_op_, _operator_, itype) \
83
+ inline complex64_t _operator_(itype x, const complex64_t& y) { \
84
+ return static_cast<complex64_t>(x) _op_ y; \
85
+ } \
86
+ inline complex64_t _operator_(const complex64_t& x, itype y) { \
87
+ return x _op_ static_cast<complex64_t>(y); \
88
+ }
89
+
90
+ #define complex_binop(_op_, _operator_) \
91
+ inline complex64_t _operator_(const std::complex<float>& x, const complex64_t& y) { \
92
+ return x _op_ static_cast<std::complex<float>>(y); \
93
+ } \
94
+ inline complex64_t _operator_(const complex64_t& x, const std::complex<float>& y) { \
95
+ return static_cast<std::complex<float>>(x) _op_ y; \
96
+ } \
97
+ inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \
98
+ return static_cast<std::complex<float>>(x) \
99
+ _op_ static_cast<std::complex<float>>(y); \
100
+ } \
101
+ complex_binop_helper(_op_, _operator_, bool) \
102
+ complex_binop_helper(_op_, _operator_, uint32_t) \
103
+ complex_binop_helper(_op_, _operator_, uint64_t) \
104
+ complex_binop_helper(_op_, _operator_, int32_t) \
105
+ complex_binop_helper(_op_, _operator_, int64_t) \
106
+ complex_binop_helper(_op_, _operator_, float16_t) \
107
+ complex_binop_helper(_op_, _operator_, bfloat16_t) \
108
+ complex_binop_helper(_op_, _operator_, float)
109
+ // clang-format on
110
+
111
+ complex_binop(+, operator+)
112
+
113
+ } // namespace mlx::core