mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,517 @@
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+ #include <cassert>
5
+
6
+ #include "mlx/array.h"
7
+ #include "mlx/backend/common/binary.h"
8
+ #include "mlx/backend/common/utils.h"
9
+
10
+ #include "mlx/backend/cpu/encoder.h"
11
+ #include "mlx/backend/cpu/simd/simd.h"
12
+
13
+ namespace mlx::core {
14
+
15
+ template <typename Op>
16
+ struct VectorScalar {
17
+ template <typename T, typename U>
18
+ void operator()(const T* a, const T* b, U* dst, int size) {
19
+ T scalar = *b;
20
+ constexpr int N = simd::max_size<T>;
21
+ while (size >= N) {
22
+ simd::store(dst, Op{}(simd::load<T, N>(a), simd::Simd<T, N>(scalar)));
23
+ dst += N;
24
+ a += N;
25
+ size -= N;
26
+ }
27
+ while (size-- > 0) {
28
+ *dst = Op{}(*a, scalar);
29
+ dst++;
30
+ a++;
31
+ }
32
+ }
33
+ };
34
+
35
+ template <typename Op>
36
+ struct ScalarVector {
37
+ template <typename T, typename U>
38
+ void operator()(const T* a, const T* b, U* dst, int size) {
39
+ T scalar = *a;
40
+ constexpr int N = simd::max_size<T>;
41
+ while (size >= N) {
42
+ simd::store(dst, Op{}(simd::Simd<T, N>(scalar), simd::load<T, N>(b)));
43
+ dst += N;
44
+ b += N;
45
+ size -= N;
46
+ }
47
+ while (size-- > 0) {
48
+ *dst = Op{}(scalar, *b);
49
+ dst++;
50
+ b++;
51
+ }
52
+ }
53
+ };
54
+
55
+ template <typename Op>
56
+ struct VectorVector {
57
+ template <typename T, typename U>
58
+ void operator()(const T* a, const T* b, U* dst, int size) {
59
+ constexpr int N = simd::max_size<T>;
60
+ while (size >= N) {
61
+ simd::store(dst, Op{}(simd::load<T, N>(a), simd::load<T, N>(b)));
62
+ dst += N;
63
+ a += N;
64
+ b += N;
65
+ size -= N;
66
+ }
67
+ while (size-- > 0) {
68
+ *dst = Op{}(*a, *b);
69
+ dst++;
70
+ a++;
71
+ b++;
72
+ }
73
+ }
74
+ };
75
+
76
+ template <typename T, typename U, typename Op, int D, bool Strided>
77
+ void binary_op_dims(
78
+ const T* a,
79
+ const T* b,
80
+ U* out,
81
+ const Shape& shape,
82
+ const Strides& a_strides,
83
+ const Strides& b_strides,
84
+ const Strides& out_strides,
85
+ int axis) {
86
+ auto stride_a = a_strides[axis];
87
+ auto stride_b = b_strides[axis];
88
+ auto stride_out = out_strides[axis];
89
+ auto N = shape[axis];
90
+
91
+ for (int i = 0; i < N; i++) {
92
+ if constexpr (D > 1) {
93
+ binary_op_dims<T, U, Op, D - 1, Strided>(
94
+ a, b, out, shape, a_strides, b_strides, out_strides, axis + 1);
95
+ } else {
96
+ if constexpr (Strided) {
97
+ Op{}(a, b, out, stride_out);
98
+ } else {
99
+ *out = Op{}(*a, *b);
100
+ }
101
+ }
102
+ out += stride_out;
103
+ a += stride_a;
104
+ b += stride_b;
105
+ }
106
+ }
107
+
108
+ template <typename T, typename U, bool Strided, typename Op>
109
+ void binary_op_dispatch_dims(
110
+ const T* a,
111
+ const T* b,
112
+ U* out,
113
+ int dim,
114
+ int size,
115
+ const Shape& shape,
116
+ const Strides& a_strides,
117
+ const Strides& b_strides,
118
+ const Strides& out_strides) {
119
+ switch (dim) {
120
+ case 1:
121
+ binary_op_dims<T, U, Op, 1, Strided>(
122
+ a, b, out, shape, a_strides, b_strides, out_strides, 0);
123
+ return;
124
+ case 2:
125
+ binary_op_dims<T, U, Op, 2, Strided>(
126
+ a, b, out, shape, a_strides, b_strides, out_strides, 0);
127
+ return;
128
+ case 3:
129
+ binary_op_dims<T, U, Op, 3, Strided>(
130
+ a, b, out, shape, a_strides, b_strides, out_strides, 0);
131
+ return;
132
+ }
133
+
134
+ ContiguousIterator a_it(shape, a_strides, dim - 3);
135
+ ContiguousIterator b_it(shape, b_strides, dim - 3);
136
+ auto stride = out_strides[dim - 4];
137
+ for (int64_t elem = 0; elem < size; elem += stride) {
138
+ binary_op_dims<T, U, Op, 3, Strided>(
139
+ a + a_it.loc,
140
+ b + b_it.loc,
141
+ out + elem,
142
+ shape,
143
+ a_strides,
144
+ b_strides,
145
+ out_strides,
146
+ dim - 3);
147
+ a_it.step();
148
+ b_it.step();
149
+ }
150
+ }
151
+
152
+ template <typename T, typename U, typename Op>
153
+ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
154
+ // The full computation is scalar scalar so call the base op once
155
+ auto a_ptr = a.data<T>();
156
+ auto b_ptr = b.data<T>();
157
+
158
+ auto out_ptr = out.data<U>();
159
+ if (bopt == BinaryOpType::ScalarScalar) {
160
+ *out_ptr = Op{}(*a_ptr, *b_ptr);
161
+ return;
162
+ }
163
+
164
+ // The full computation is scalar vector so delegate to the op
165
+ if (bopt == BinaryOpType::ScalarVector) {
166
+ ScalarVector<Op>{}(a_ptr, b_ptr, out_ptr, b.data_size());
167
+ return;
168
+ }
169
+
170
+ // The full computation is vector scalar so delegate to the op
171
+ if (bopt == BinaryOpType::VectorScalar) {
172
+ VectorScalar<Op>{}(a_ptr, b_ptr, out_ptr, a.data_size());
173
+ return;
174
+ }
175
+
176
+ // The full computation is vector vector so delegate to the op
177
+ if (bopt == BinaryOpType::VectorVector) {
178
+ VectorVector<Op>{}(a_ptr, b_ptr, out_ptr, a.size());
179
+ return;
180
+ }
181
+
182
+ // General computation so let's try to optimize
183
+ auto [new_shape, new_strides] = collapse_contiguous_dims(
184
+ a.shape(), {a.strides(), b.strides(), out.strides()});
185
+ auto& a_strides = new_strides[0];
186
+ auto& b_strides = new_strides[1];
187
+ auto& strides = new_strides[2];
188
+
189
+ // Get the left-most dim such that the array is row contiguous after
190
+ auto leftmost_rc_dim = [&strides](const auto& arr_strides) {
191
+ int d = arr_strides.size() - 1;
192
+ for (; d >= 0 && arr_strides[d] == strides[d]; d--) {
193
+ }
194
+ return d + 1;
195
+ };
196
+ auto a_rc_dim = leftmost_rc_dim(a_strides);
197
+ auto b_rc_dim = leftmost_rc_dim(b_strides);
198
+
199
+ // Get the left-most dim such that the array is a broadcasted "scalar" after
200
+ auto leftmost_s_dim = [](const auto& arr_strides) {
201
+ int d = arr_strides.size() - 1;
202
+ for (; d >= 0 && arr_strides[d] == 0; d--) {
203
+ }
204
+ return d + 1;
205
+ };
206
+ auto a_s_dim = leftmost_s_dim(a_strides);
207
+ auto b_s_dim = leftmost_s_dim(b_strides);
208
+
209
+ auto ndim = new_shape.size();
210
+
211
+ // Case 1: LxM and FxM where L and F are broadcastable and M is row
212
+ // contiguous
213
+ int dim = ndim;
214
+ if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
215
+ bopt = BinaryOpType::VectorVector;
216
+ dim = d;
217
+ // Case 2: LxM and Fx1 where L and F are broadcastable and M is row
218
+ // contiguous
219
+ } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
220
+ bopt = BinaryOpType::VectorScalar;
221
+ dim = d;
222
+ // Case 3: Lx1 and FxM where L and F are broadcastable and M is row
223
+ // contiguous
224
+ } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
225
+ bopt = BinaryOpType::ScalarVector;
226
+ dim = d;
227
+ }
228
+
229
+ // Can be sure dim > 0 since otherwise we would have used one of the fully
230
+ // contiguous methods above. Except for the case that the flags do not
231
+ // correspond to the underlying contiguity.
232
+ if (dim == 0 || strides[dim - 1] < 16) {
233
+ bopt = BinaryOpType::General;
234
+ dim = ndim;
235
+ }
236
+
237
+ switch (bopt) {
238
+ case BinaryOpType::VectorVector:
239
+ binary_op_dispatch_dims<T, U, true, VectorVector<Op>>(
240
+ a_ptr,
241
+ b_ptr,
242
+ out_ptr,
243
+ dim,
244
+ a.size(),
245
+ new_shape,
246
+ a_strides,
247
+ b_strides,
248
+ strides);
249
+ break;
250
+ case BinaryOpType::VectorScalar:
251
+ binary_op_dispatch_dims<T, U, true, VectorScalar<Op>>(
252
+ a_ptr,
253
+ b_ptr,
254
+ out_ptr,
255
+ dim,
256
+ a.size(),
257
+ new_shape,
258
+ a_strides,
259
+ b_strides,
260
+ strides);
261
+ break;
262
+ case BinaryOpType::ScalarVector:
263
+ binary_op_dispatch_dims<T, U, true, ScalarVector<Op>>(
264
+ a_ptr,
265
+ b_ptr,
266
+ out_ptr,
267
+ dim,
268
+ a.size(),
269
+ new_shape,
270
+ a_strides,
271
+ b_strides,
272
+ strides);
273
+ break;
274
+ default:
275
+ binary_op_dispatch_dims<T, U, false, Op>(
276
+ a_ptr,
277
+ b_ptr,
278
+ out_ptr,
279
+ dim,
280
+ a.size(),
281
+ new_shape,
282
+ a_strides,
283
+ b_strides,
284
+ strides);
285
+ break;
286
+ }
287
+ }
288
+
289
+ template <typename T, typename Op>
290
+ void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) {
291
+ binary_op<T, T, Op>(a, b, out, bopt);
292
+ }
293
+
294
+ template <typename Op>
295
+ void binary_op_cpu(
296
+ const array& a,
297
+ const array& b,
298
+ array& out,
299
+ Op op,
300
+ Stream stream) {
301
+ auto bopt = get_binary_op_type(a, b);
302
+ set_binary_op_output_data(a, b, out, bopt);
303
+
304
+ auto& encoder = cpu::get_command_encoder(stream);
305
+ encoder.set_input_array(a);
306
+ encoder.set_input_array(b);
307
+ encoder.set_output_array(out);
308
+ encoder.dispatch([a = array::unsafe_weak_copy(a),
309
+ b = array::unsafe_weak_copy(b),
310
+ out = array::unsafe_weak_copy(out),
311
+ bopt]() mutable {
312
+ switch (out.dtype()) {
313
+ case bool_:
314
+ binary_op<bool, Op>(a, b, out, bopt);
315
+ break;
316
+ case uint8:
317
+ binary_op<uint8_t, Op>(a, b, out, bopt);
318
+ break;
319
+ case uint16:
320
+ binary_op<uint16_t, Op>(a, b, out, bopt);
321
+ break;
322
+ case uint32:
323
+ binary_op<uint32_t, Op>(a, b, out, bopt);
324
+ break;
325
+ case uint64:
326
+ binary_op<uint64_t, Op>(a, b, out, bopt);
327
+ break;
328
+ case int8:
329
+ binary_op<int8_t, Op>(a, b, out, bopt);
330
+ break;
331
+ case int16:
332
+ binary_op<int16_t, Op>(a, b, out, bopt);
333
+ break;
334
+ case int32:
335
+ binary_op<int32_t, Op>(a, b, out, bopt);
336
+ break;
337
+ case int64:
338
+ binary_op<int64_t, Op>(a, b, out, bopt);
339
+ break;
340
+ case float16:
341
+ binary_op<float16_t, Op>(a, b, out, bopt);
342
+ break;
343
+ case float32:
344
+ binary_op<float, Op>(a, b, out, bopt);
345
+ break;
346
+ case float64:
347
+ binary_op<double, Op>(a, b, out, bopt);
348
+ break;
349
+ case bfloat16:
350
+ binary_op<bfloat16_t, Op>(a, b, out, bopt);
351
+ break;
352
+ case complex64:
353
+ binary_op<complex64_t, Op>(a, b, out, bopt);
354
+ break;
355
+ }
356
+ });
357
+ }
358
+
359
+ template <typename Op>
360
+ void comparison_op_cpu(
361
+ const array& a,
362
+ const array& b,
363
+ array& out,
364
+ Op op,
365
+ Stream stream) {
366
+ auto bopt = get_binary_op_type(a, b);
367
+ set_binary_op_output_data(a, b, out, bopt);
368
+
369
+ auto& encoder = cpu::get_command_encoder(stream);
370
+ encoder.set_input_array(a);
371
+ encoder.set_input_array(b);
372
+ encoder.set_output_array(out);
373
+ encoder.dispatch([a = array::unsafe_weak_copy(a),
374
+ b = array::unsafe_weak_copy(b),
375
+ out = array::unsafe_weak_copy(out),
376
+ bopt]() mutable {
377
+ switch (a.dtype()) {
378
+ case bool_:
379
+ binary_op<bool, bool, Op>(a, b, out, bopt);
380
+ break;
381
+ case uint8:
382
+ binary_op<uint8_t, bool, Op>(a, b, out, bopt);
383
+ break;
384
+ case uint16:
385
+ binary_op<uint16_t, bool, Op>(a, b, out, bopt);
386
+ break;
387
+ case uint32:
388
+ binary_op<uint32_t, bool, Op>(a, b, out, bopt);
389
+ break;
390
+ case uint64:
391
+ binary_op<uint64_t, bool, Op>(a, b, out, bopt);
392
+ break;
393
+ case int8:
394
+ binary_op<int8_t, bool, Op>(a, b, out, bopt);
395
+ break;
396
+ case int16:
397
+ binary_op<int16_t, bool, Op>(a, b, out, bopt);
398
+ break;
399
+ case int32:
400
+ binary_op<int32_t, bool, Op>(a, b, out, bopt);
401
+ break;
402
+ case int64:
403
+ binary_op<int64_t, bool, Op>(a, b, out, bopt);
404
+ break;
405
+ case float16:
406
+ binary_op<float16_t, bool, Op>(a, b, out, bopt);
407
+ break;
408
+ case float32:
409
+ binary_op<float, bool, Op>(a, b, out, bopt);
410
+ break;
411
+ case float64:
412
+ binary_op<double, bool, Op>(a, b, out, bopt);
413
+ break;
414
+ case bfloat16:
415
+ binary_op<bfloat16_t, bool, Op>(a, b, out, bopt);
416
+ break;
417
+ case complex64:
418
+ binary_op<complex64_t, bool, Op>(a, b, out, bopt);
419
+ break;
420
+ }
421
+ });
422
+ }
423
+
424
+ template <typename Op>
425
+ void binary_float_op_cpu(
426
+ const array& a,
427
+ const array& b,
428
+ array& out,
429
+ Op op,
430
+ Stream stream) {
431
+ auto bopt = get_binary_op_type(a, b);
432
+ set_binary_op_output_data(a, b, out, bopt);
433
+
434
+ auto& encoder = cpu::get_command_encoder(stream);
435
+ encoder.set_input_array(a);
436
+ encoder.set_input_array(b);
437
+ encoder.set_output_array(out);
438
+ encoder.dispatch([a = array::unsafe_weak_copy(a),
439
+ b = array::unsafe_weak_copy(b),
440
+ out = array::unsafe_weak_copy(out),
441
+ bopt]() mutable {
442
+ switch (out.dtype()) {
443
+ case float16:
444
+ binary_op<float16_t, Op>(a, b, out, bopt);
445
+ break;
446
+ case float32:
447
+ binary_op<float, Op>(a, b, out, bopt);
448
+ break;
449
+ case float64:
450
+ binary_op<double, Op>(a, b, out, bopt);
451
+ break;
452
+ case bfloat16:
453
+ binary_op<bfloat16_t, Op>(a, b, out, bopt);
454
+ break;
455
+ case complex64:
456
+ binary_op<complex64_t, Op>(a, b, out, bopt);
457
+ break;
458
+ default:
459
+ throw std::runtime_error(
460
+ "[binary_float] Only supports floating point types.");
461
+ }
462
+ });
463
+ }
464
+
465
+ template <typename Op>
466
+ void binary_int_op_cpu(
467
+ const array& a,
468
+ const array& b,
469
+ array& out,
470
+ Op op,
471
+ Stream stream) {
472
+ auto bopt = get_binary_op_type(a, b);
473
+ set_binary_op_output_data(a, b, out, bopt);
474
+
475
+ auto& encoder = cpu::get_command_encoder(stream);
476
+ encoder.set_input_array(a);
477
+ encoder.set_input_array(b);
478
+ encoder.set_output_array(out);
479
+ encoder.dispatch([a = array::unsafe_weak_copy(a),
480
+ b = array::unsafe_weak_copy(b),
481
+ out = array::unsafe_weak_copy(out),
482
+ bopt]() mutable {
483
+ switch (out.dtype()) {
484
+ case bool_:
485
+ binary_op<bool, Op>(a, b, out, bopt);
486
+ case uint8:
487
+ binary_op<uint8_t, Op>(a, b, out, bopt);
488
+ break;
489
+ case uint16:
490
+ binary_op<uint16_t, Op>(a, b, out, bopt);
491
+ break;
492
+ case uint32:
493
+ binary_op<uint32_t, Op>(a, b, out, bopt);
494
+ break;
495
+ case uint64:
496
+ binary_op<uint64_t, Op>(a, b, out, bopt);
497
+ break;
498
+ case int8:
499
+ binary_op<int8_t, Op>(a, b, out, bopt);
500
+ break;
501
+ case int16:
502
+ binary_op<int16_t, Op>(a, b, out, bopt);
503
+ break;
504
+ case int32:
505
+ binary_op<int32_t, Op>(a, b, out, bopt);
506
+ break;
507
+ case int64:
508
+ binary_op<int64_t, Op>(a, b, out, bopt);
509
+ break;
510
+ default:
511
+ throw std::runtime_error("[binary_int] Type not supported");
512
+ break;
513
+ }
514
+ });
515
+ }
516
+
517
+ } // namespace mlx::core
@@ -0,0 +1,98 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/cpu/simd/simd.h"
6
+
7
+ namespace mlx::core::detail {
8
+
9
+ using namespace mlx::core::simd;
10
+
11
// Adds a scalar overload `T operator()(T, T)` to an op struct. It wraps both
// arguments in width-1 `Simd` values, calls the struct's vector overload, and
// unwraps the lane with `.value`.
#define BINARY_SINGLE()                                \
  template <typename T>                                \
  T operator()(T x, T y) {                             \
    auto lane = (*this)(Simd<T, 1>(x), Simd<T, 1>(y)); \
    return lane.value;                                 \
  }
16
+
17
+ #define DEFAULT_BINARY_OP(Op, op) \
18
+ struct Op { \
19
+ template <int N, typename T> \
20
+ Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
21
+ return op(x, y); \
22
+ } \
23
+ BINARY_SINGLE() \
24
+ };
25
+
26
+ DEFAULT_BINARY_OP(Add, operator+)
27
+ DEFAULT_BINARY_OP(ArcTan2, atan2)
28
+ DEFAULT_BINARY_OP(Divide, operator/)
29
+ DEFAULT_BINARY_OP(Multiply, operator*)
30
+ DEFAULT_BINARY_OP(Subtract, operator-)
31
+ DEFAULT_BINARY_OP(LogicalAnd, operator&&)
32
+ DEFAULT_BINARY_OP(LogicalOr, operator||)
33
+ DEFAULT_BINARY_OP(BitwiseAnd, operator&)
34
+ DEFAULT_BINARY_OP(BitwiseOr, operator|)
35
+ DEFAULT_BINARY_OP(BitwiseXor, operator^)
36
+ DEFAULT_BINARY_OP(LeftShift, operator<<)
37
+ DEFAULT_BINARY_OP(RightShift, operator>>)
38
+ DEFAULT_BINARY_OP(Remainder, remainder)
39
+ DEFAULT_BINARY_OP(Maximum, maximum)
40
+ DEFAULT_BINARY_OP(Minimum, minimum)
41
+ DEFAULT_BINARY_OP(Power, pow)
42
+
43
+ #define DEFAULT_BOOL_OP(Op, op) \
44
+ struct Op { \
45
+ template <int N, typename T> \
46
+ Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) { \
47
+ return op(x, y); \
48
+ } \
49
+ template <typename T> \
50
+ bool operator()(T x, T y) { \
51
+ return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value; \
52
+ } \
53
+ };
54
+
55
+ DEFAULT_BOOL_OP(Equal, operator==)
56
+ DEFAULT_BOOL_OP(Greater, operator>)
57
+ DEFAULT_BOOL_OP(GreaterEqual, operator>=)
58
+ DEFAULT_BOOL_OP(Less, operator<)
59
+ DEFAULT_BOOL_OP(LessEqual, operator<=)
60
+ DEFAULT_BOOL_OP(NotEqual, operator!=)
61
+
62
+ struct NaNEqual {
63
+ template <int N, typename T>
64
+ Simd<bool, N> operator()(Simd<T, N> x, Simd<T, N> y) {
65
+ return x == y || (isnan(x) && isnan(y));
66
+ }
67
+ template <typename T>
68
+ bool operator()(T x, T y) {
69
+ return (*this)(Simd<T, 1>(x), Simd<T, 1>(y)).value;
70
+ }
71
+ };
72
+
73
+ struct LogAddExp {
74
+ template <int N, typename T>
75
+ Simd<T, N> operator()(Simd<T, N> x, Simd<T, N> y) {
76
+ auto maxval = maximum(x, y);
77
+ auto minval = minimum(x, y);
78
+ auto mask = minval == -inf || maxval == inf;
79
+ auto out = maxval + log1p(exp(minval - maxval));
80
+ return select(mask, Simd<T, N>(maxval), Simd<T, N>(out));
81
+ }
82
+ BINARY_SINGLE()
83
+ };
84
+
85
+ struct Select {
86
+ template <typename T>
87
+ T operator()(bool condition, T x, T y) {
88
+ return (*this)(Simd<bool, 1>(condition), Simd<T, 1>(x), Simd<T, 1>(y))
89
+ .value;
90
+ }
91
+
92
+ template <int N, typename T>
93
+ Simd<T, N> operator()(Simd<bool, N> condition, Simd<T, N> x, Simd<T, N> y) {
94
+ return select(condition, x, y);
95
+ }
96
+ };
97
+
98
+ } // namespace mlx::core::detail