mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,90 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <metal_simdgroup>
4
+ #include <metal_stdlib>
5
+
6
+ template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
7
+ METAL_FUNC void gemm_loop_aligned(
8
+ threadgroup T* As,
9
+ threadgroup T* Bs,
10
+ thread mma_t& mma_op,
11
+ thread loader_a_t& loader_a,
12
+ thread loader_b_t& loader_b,
13
+ const int k_iterations) {
14
+ for (int k = 0; k < k_iterations; k++) {
15
+ threadgroup_barrier(mem_flags::mem_threadgroup);
16
+
17
+ // Load elements into threadgroup memory
18
+ loader_a.load_unsafe();
19
+ loader_b.load_unsafe();
20
+
21
+ threadgroup_barrier(mem_flags::mem_threadgroup);
22
+
23
+ // Multiply and accumulate threadgroup elements
24
+ mma_op.mma(As, Bs);
25
+
26
+ // Prepare for next iteration
27
+ loader_a.next();
28
+ loader_b.next();
29
+ }
30
+ }
31
+
32
+ template <
33
+ bool rows_aligned,
34
+ bool cols_aligned,
35
+ bool transpose,
36
+ typename T,
37
+ typename mma_t,
38
+ typename loader_a_t,
39
+ typename loader_b_t>
40
+ METAL_FUNC void gemm_loop_unaligned(
41
+ threadgroup T* As,
42
+ threadgroup T* Bs,
43
+ thread mma_t& mma_op,
44
+ thread loader_a_t& loader_a,
45
+ thread loader_b_t& loader_b,
46
+ const int k_iterations,
47
+ const short tgp_bm,
48
+ const short tgp_bn,
49
+ const short tgp_bk) {
50
+ for (int k = 0; k < k_iterations; k++) {
51
+ threadgroup_barrier(mem_flags::mem_threadgroup);
52
+
53
+ // Load elements into threadgroup memory
54
+ if (rows_aligned) {
55
+ loader_a.load_unsafe();
56
+ } else {
57
+ loader_a.load_safe(short2(tgp_bk, tgp_bm));
58
+ }
59
+ if (cols_aligned) {
60
+ loader_b.load_unsafe();
61
+ } else {
62
+ loader_b.load_safe(
63
+ transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
64
+ }
65
+
66
+ threadgroup_barrier(mem_flags::mem_threadgroup);
67
+
68
+ // Multiply and accumulate threadgroup elements
69
+ mma_op.mma(As, Bs);
70
+
71
+ // Prepare for next iteration
72
+ loader_a.next();
73
+ loader_b.next();
74
+ }
75
+ }
76
+
77
// Epilogue of the GEMM K-loop: stage the final (possibly partial) K-slice
// with bounds-checked loads, publish it to the threadgroup, and fold it into
// the accumulator. `tile_a`/`tile_b` give the valid extents of the last tile.
// No trailing barrier/next: callers are done with the loaders after this.
template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
METAL_FUNC void gemm_loop_finalize(
    threadgroup T* As,
    threadgroup T* Bs,
    thread mma_t& mma_op,
    thread loader_a_t& loader_a,
    thread loader_b_t& loader_b,
    const short2 tile_a,
    const short2 tile_b) {
  loader_a.load_safe(tile_a);
  loader_b.load_safe(tile_b);
  threadgroup_barrier(mem_flags::mem_threadgroup);
  mma_op.mma(As, Bs);
}
@@ -0,0 +1,5 @@
1
+ #pragma once
2
+ #include "mlx/backend/metal/kernels/reduction/reduce_all.h"
3
+ #include "mlx/backend/metal/kernels/reduction/reduce_col.h"
4
+ #include "mlx/backend/metal/kernels/reduction/reduce_init.h"
5
+ #include "mlx/backend/metal/kernels/reduction/reduce_row.h"
@@ -0,0 +1,6 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/atomic.h"
6
+ #include "mlx/backend/metal/kernels/reduction/ops.h"
@@ -0,0 +1,275 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <metal_atomic>
6
+ #include <metal_simdgroup>
7
+
8
// Generates the simd_reduce entry points for a reduction-op struct:
//  - types narrower than 8 bytes defer to the op's simd_reduce_impl
//    (which typically maps to a hardware simdgroup reduction);
//  - 8-byte types fall back to a shuffle-down butterfly built on the op's
//    operator(), halving the stride each step from simd_size / 2 down to 1.
#define DEFINE_SIMD_REDUCE() \
  template <typename T, metal::enable_if_t<sizeof(T) < 8, bool> = true> \
  T simd_reduce(T val) { \
    return simd_reduce_impl(val); \
  } \
  \
  template <typename T, metal::enable_if_t<sizeof(T) == 8, bool> = true> \
  T simd_reduce(T val) { \
    for (short i = simd_size / 2; i > 0; i /= 2) { \
      val = operator()(val, simd_shuffle_down(val, i)); \
    } \
    return val; \
  }
21
+
22
// Threads per simdgroup; also bounds the per-threadgroup scratch arrays used
// by the shuffle-based reductions.
static constant constexpr const uint8_t simd_size = 32;
23
+
24
// Type-punning helper: view a bool4 lane mask as a single 32-bit word so one
// atomic AND/OR can update an individual bool lane (see And/Or atomic_update).
union bool4_or_uint {
  bool4 b;
  unsigned int i;
};
28
+
29
// "No-op" reduction: the atomic update simply stores the value without
// combining it with the previous contents.
struct None {
  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
    mlx_atomic_store_explicit(out, val, offset);
  }
};
35
+
36
// Logical-AND reduction over booleans.
template <typename U = bool>
struct And {
  DEFINE_SIMD_REDUCE()

  bool simd_reduce_impl(bool val) {
    return simd_all(val);
  }

  // Identity element: AND over an empty set is true.
  static constexpr constant bool init = true;

  // Update one lane of a bool4 packed into a 32-bit word: fetch_and with a
  // mask that is all-true except at elem_idx clears just that lane. The
  // atomic is only issued when val is false (true cannot change the result).
  void atomic_update(
      device mlx_atomic<unsigned int>* out,
      bool val,
      int elem_idx,
      size_t offset = 0) {
    if (!val) {
      bool4_or_uint update;
      update.b = {true, true, true, true};
      update.b[elem_idx] = false;
      mlx_atomic_fetch_and_explicit(out, update.i, offset);
    }
  }

  // Plain bool output: storing false is idempotent, so a store (guarded so
  // it only happens when the result must become false) suffices — no RMW.
  void
  atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
    if (!val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out &= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a && b;
  }
};
76
+
77
// Logical-OR reduction over booleans (mirror image of And).
template <typename U = bool>
struct Or {
  DEFINE_SIMD_REDUCE()

  bool simd_reduce_impl(bool val) {
    return simd_any(val);
  }

  // Identity element: OR over an empty set is false.
  static constexpr constant bool init = false;

  // Update one lane of a bool4 packed into a 32-bit word: fetch_or with a
  // mask that is all-false except at elem_idx sets just that lane. The
  // atomic is only issued when val is true (false cannot change the result).
  void atomic_update(
      device mlx_atomic<unsigned int>* out,
      bool val,
      int elem_idx,
      size_t offset = 0) {
    if (val) {
      bool4_or_uint update;
      update.b = {false, false, false, false};
      update.b[elem_idx] = true;
      mlx_atomic_fetch_or_explicit(out, update.i, offset);
    }
  }

  // Plain bool output: storing true is idempotent, so a guarded store
  // suffices — no RMW needed.
  void
  atomic_update(device mlx_atomic<bool>* out, bool val, size_t offset = 0) {
    if (val) {
      mlx_atomic_store_explicit(out, val, offset);
    }
  }

  // Non atomic update
  void update(device bool* out, bool val) {
    *out |= val;
  }

  // Operator
  bool operator()(bool a, bool b) {
    return a || b;
  }
};
117
+
118
// Sum reduction.
template <typename U>
struct Sum {
  DEFINE_SIMD_REDUCE()

  template <typename T>
  T simd_reduce_impl(T val) {
    return simd_sum(val);
  }

  // Identity element for addition.
  static constexpr constant U init = U(0);

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
    mlx_atomic_fetch_add_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a + b;
  }
};
139
+
140
// Product reduction.
template <typename U>
struct Prod {
  DEFINE_SIMD_REDUCE()

  template <typename T>
  T simd_reduce_impl(T val) {
    return simd_product(val);
  }

  // Identity element for multiplication.
  static constexpr constant U init = U(1);

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
    mlx_atomic_fetch_mul_explicit(out, val, offset);
  }

  // Operator
  U operator()(U a, U b) {
    return a * b;
  }
};
161
+
162
+ template <typename U>
163
+ struct Min {
164
+ DEFINE_SIMD_REDUCE()
165
+
166
+ template <typename T>
167
+ metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
168
+ return simd_min(val);
169
+ }
170
+
171
+ template <typename T>
172
+ metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
173
+ if (simd_any(val != val)) {
174
+ return static_cast<T>(NAN);
175
+ }
176
+ return simd_min(val);
177
+ }
178
+
179
+ static constexpr constant U init = Limits<U>::max;
180
+
181
+ template <typename T>
182
+ void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
183
+ mlx_atomic_fetch_min_explicit(out, val, offset);
184
+ }
185
+
186
+ // Operator
187
+ template <typename T>
188
+ metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
189
+ return a < b ? a : b;
190
+ }
191
+
192
+ template <typename T>
193
+ metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
194
+ if (metal::isnan(a) || metal::isnan(b)) {
195
+ return static_cast<T>(NAN);
196
+ } else {
197
+ return a < b ? a : b;
198
+ }
199
+ }
200
+
201
+ template <>
202
+ complex64_t operator()(complex64_t a, complex64_t b) {
203
+ bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
204
+ bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);
205
+
206
+ if (!real_is_nan && !imag_is_nan) {
207
+ return a < b ? a : b;
208
+ } else if (real_is_nan && !imag_is_nan) {
209
+ return complex64_t(
210
+ static_cast<float>(NAN), a.imag < b.imag ? a.imag : b.imag);
211
+ } else if (!real_is_nan && imag_is_nan) {
212
+ return complex64_t(
213
+ a.real < b.real ? a.real : b.real, static_cast<float>(NAN));
214
+ } else {
215
+ return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
216
+ }
217
+ };
218
+ };
219
// Maximum reduction. NaN handling mirrors Min: for non-integral types any
// NaN operand (or lane) makes the result NaN; for complex values the non-NaN
// component pair is still reduced component-wise.
template <typename U>
struct Max {
  DEFINE_SIMD_REDUCE()

  // Integral types map straight to the hardware simdgroup maximum.
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
    return simd_max(val);
  }

  // Floating point: detect NaN via (val != val) across the simdgroup and
  // propagate it, since simd_max alone would otherwise drop NaNs.
  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> simd_reduce_impl(T val) {
    if (simd_any(val != val)) {
      return static_cast<T>(NAN);
    }
    return simd_max(val);
  }

  // Identity element for max.
  static constexpr constant U init = Limits<U>::min;

  template <typename T>
  void atomic_update(device mlx_atomic<T>* out, T val, size_t offset = 0) {
    mlx_atomic_fetch_max_explicit(out, val, offset);
  }

  // Operator
  template <typename T>
  metal::enable_if_t<metal::is_integral_v<T>, T> operator()(T a, T b) {
    return a > b ? a : b;
  }

  template <typename T>
  metal::enable_if_t<!metal::is_integral_v<T>, T> operator()(T a, T b) {
    if (metal::isnan(a) || metal::isnan(b)) {
      return static_cast<T>(NAN);
    } else {
      return a > b ? a : b;
    }
  }

  // Complex: when fully finite, defer to complex64_t's operator> (declared
  // elsewhere); otherwise reduce the non-NaN component(s) and propagate NaN
  // in the other.
  template <>
  complex64_t operator()(complex64_t a, complex64_t b) {
    bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real);
    bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag);

    if (!real_is_nan && !imag_is_nan) {
      return a > b ? a : b;
    } else if (real_is_nan && !imag_is_nan) {
      return complex64_t(
          static_cast<float>(NAN), a.imag > b.imag ? a.imag : b.imag);
    } else if (!real_is_nan && imag_is_nan) {
      return complex64_t(
          a.real > b.real ? a.real : b.real, static_cast<float>(NAN));
    } else {
      return complex64_t(static_cast<float>(NAN), static_cast<float>(NAN));
    }
  }
};
@@ -0,0 +1,66 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
// Reduce each row of `row_size` elements of `in` down to one value of `out`
// with reduction op `Op`; grid dimension y selects the row (for a full
// all-reduce the host presumably launches follow-up passes until one value
// remains — TODO confirm against the caller).
//
// Each thread strides through its row reading N_READS contiguous elements
// per step, then the partial results are combined first within each
// simdgroup and then across simdgroups via threadgroup memory.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT = int64_t,
    int N_READS = REDUCE_N_READS>
[[kernel]] void all_reduce(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& in_size [[buffer(2)]],
    const constant size_t& row_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  // One slot per simdgroup for the cross-simdgroup combine.
  threadgroup U shared_vals[simd_size];

  U total = Op::init;
  // Start of this threadgroup's row; the last row may be shorter when
  // in_size is not a multiple of row_size.
  IdxT start_idx = gid.y * IdxT(row_size);
  IdxT actual_row =
      (start_idx + row_size <= in_size) ? row_size : in_size - start_idx;
  // Full blocks of (threads * N_READS) elements, plus this thread's share
  // of the leftover tail (extra may go negative for threads past the tail).
  IdxT blocks = actual_row / (lsize.x * N_READS);
  int extra = actual_row - blocks * (lsize.x * N_READS);
  extra -= lid.x * N_READS;
  start_idx += lid.x * N_READS;
  in += start_idx;

  // If this thread's tail share is a full N_READS chunk, fold it into the
  // unconditional block loop instead of the tail loop.
  if (extra >= N_READS) {
    blocks++;
    extra = 0;
  }

  // Bulk phase: N_READS contiguous reads per step, stride = whole group.
  for (IdxT b = 0; b < blocks; b++) {
    for (int i = 0; i < N_READS; i++) {
      total = op(static_cast<U>(in[i]), total);
    }
    in += lsize.x * N_READS;
  }
  // Tail phase: fewer than N_READS elements remain for this thread.
  if (extra > 0) {
    for (int i = 0; i < extra; i++) {
      total = op(static_cast<U>(in[i]), total);
    }
  }

  // Reduction within simd group
  total = op.simd_reduce(total);
  if (simd_per_group > 1) {
    // Lane 0 of each simdgroup publishes its partial result.
    if (simd_lane_id == 0) {
      shared_vals[simd_group_id] = total;
    }

    // Reduction within thread group
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // First simdgroup reduces the per-simdgroup partials; other lanes
    // contribute the identity.
    total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init;
    total = op.simd_reduce(total);
  }

  // Thread 0 writes the row's result.
  if (lid.x == 0) {
    out[gid.y] = total;
  }
}