mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/jit/indexing.h
@@ -0,0 +1,76 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ constexpr std::string_view gather_kernels = R"(
+ [[kernel]] void gather{0}_{3}_{6}_{7}(
+     const device {1}* src [[buffer(0)]],
+     device {1}* out [[buffer(1)]],
+     const constant int* src_shape [[buffer(2)]],
+     const constant int64_t* src_strides [[buffer(3)]],
+     const constant size_t& src_ndim [[buffer(4)]],
+     const constant int* slice_sizes [[buffer(5)]],
+     const constant int* axes [[buffer(6)]],
+     const constant int* idx_shapes [[buffer(7)]],
+     const constant int64_t* idx_strides [[buffer(8)]],
+     const constant bool* idx_contigs [[buffer(9)]],
+     const constant int& idx_ndim [[buffer(10)]],
+     {4}
+     uint3 index [[thread_position_in_grid]],
+     uint3 grid_dim [[threads_per_grid]]) {{
+   Indices<{2}, {3}> idxs{{
+       {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
+
+   return gather_impl<{1}, {2}, {3}, {6}, {7}>(
+       src,
+       out,
+       src_shape,
+       src_strides,
+       src_ndim,
+       slice_sizes,
+       axes,
+       idxs,
+       index,
+       grid_dim);
+ }}
+ )";
+
+ constexpr std::string_view scatter_kernels = R"(
+ [[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}(
+     const device {1}* updates [[buffer(1)]],
+     device mlx_atomic<{1}>* out [[buffer(2)]],
+     const constant int* upd_shape [[buffer(3)]],
+     const constant int64_t* upd_strides [[buffer(4)]],
+     const constant size_t& upd_ndim [[buffer(5)]],
+     const constant size_t& upd_size [[buffer(6)]],
+     const constant int* out_shape [[buffer(7)]],
+     const constant int64_t* out_strides [[buffer(8)]],
+     const constant size_t& out_ndim [[buffer(9)]],
+     const constant int* axes [[buffer(10)]],
+     const constant int* idx_shapes [[buffer(11)]],
+     const constant int64_t* idx_strides [[buffer(12)]],
+     const constant bool* idx_contigs [[buffer(13)]],
+     const constant int& idx_ndim [[buffer(14)]],
+     const constant size_t& idx_size [[buffer(15)]],
+     {5}
+     uint2 gid [[thread_position_in_grid]]) {{
+   Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
+
+   return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>(
+       updates,
+       out,
+       upd_shape,
+       upd_strides,
+       upd_ndim,
+       upd_size,
+       out_shape,
+       out_strides,
+       out_ndim,
+       axes,
+       idx_size,
+       idxs,
+       gid);
+ }}
+ )";
+
+ constexpr std::string_view masked_assign_kernel = R"(
+ template [[host_name("{0}")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>;
+ )";
mlx/include/mlx/backend/metal/kernels/arange.h
@@ -0,0 +1,9 @@
+ // Copyright © 2023-2024 Apple Inc.
+ template <typename T>
+ [[kernel]] void arange(
+     constant const T& start,
+     constant const T& step,
+     device T* out,
+     uint index [[thread_position_in_grid]]) {
+   out[index] = start + index * step;
+ }
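Note: each thread of this kernel writes exactly one element, with thread `index` storing start + index * step. A host-side C++ loop with the same rule (the values below are made up for illustration) produces the sequence the kernel fills in parallel.

#include <cstdio>
#include <vector>

int main() {
  float start = 2.0f, step = 0.5f;
  std::vector<float> out(8);
  for (unsigned index = 0; index < out.size(); ++index) {
    out[index] = start + index * step;  // same rule as the kernel body above
  }
  for (float v : out) std::printf("%g ", v);  // 2 2.5 3 3.5 4 4.5 5 5.5
  std::printf("\n");
}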
mlx/include/mlx/backend/metal/kernels/atomic.h
@@ -0,0 +1,345 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <metal_atomic>
+ #include <metal_stdlib>
+
+ using namespace metal;
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Atomic utils
+ ///////////////////////////////////////////////////////////////////////////////
+
+ #pragma METAL internals : enable
+ template <typename T>
+ constexpr constant bool is_metal_atomic = _disjunction<
+     is_same<T, int>,
+     is_same<T, uint>,
+     is_same<T, ulong>,
+     is_same<T, float>>::value;
+
+ #pragma METAL internals : disable
+
+ template <typename T, typename = void>
+ struct mlx_atomic {
+   atomic<uint> val;
+ };
+
+ template <typename T>
+ struct mlx_atomic<T, enable_if_t<is_metal_atomic<T>>> {
+   atomic<T> val;
+ };
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Native metal atomics
+ ///////////////////////////////////////////////////////////////////////////////
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC T
+ mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
+   return atomic_load_explicit(&(object[offset].val), memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void
+ mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
+   atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_and_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_or_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_min_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_max_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_add_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_mul_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   T expected = mlx_atomic_load_explicit(object, offset);
+   while (!mlx_atomic_compare_exchange_weak_explicit(
+       object, &expected, val * expected, offset)) {
+   }
+ }
+
+ template <typename T, enable_if_t<is_metal_atomic<T>, bool> = true>
+ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
+     device mlx_atomic<T>* object,
+     thread T* expected,
+     T val,
+     size_t offset) {
+   return atomic_compare_exchange_weak_explicit(
+       &(object[offset].val),
+       expected,
+       val,
+       memory_order_relaxed,
+       memory_order_relaxed);
+ }
+
+ // Specialization for float since it does not atomic_fetch_min_explicit
+ template <>
+ METAL_FUNC void mlx_atomic_fetch_min_explicit<float>(
+     device mlx_atomic<float>* object,
+     float val,
+     size_t offset) {
+   float expected = mlx_atomic_load_explicit(object, offset);
+   while (val < expected) {
+     if (mlx_atomic_compare_exchange_weak_explicit(
+             object, &expected, val, offset)) {
+       return;
+     }
+   }
+ }
+
+ // Specialization for float since it does not atomic_fetch_max_explicit
+ template <>
+ METAL_FUNC void mlx_atomic_fetch_max_explicit<float>(
+     device mlx_atomic<float>* object,
+     float val,
+     size_t offset) {
+   float expected = mlx_atomic_load_explicit(object, offset);
+   while (val > expected) {
+     if (mlx_atomic_compare_exchange_weak_explicit(
+             object, &expected, val, offset)) {
+       return;
+     }
+   }
+ }
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Custom atomics
+ ///////////////////////////////////////////////////////////////////////////////
+
+ namespace {
+
+ template <typename T>
+ constexpr constant uint packing_size = sizeof(uint) / sizeof(T);
+
+ template <typename T>
+ union uint_or_packed {
+   T val[packing_size<T>];
+   uint bits;
+ };
+
+ template <typename T, typename Op>
+ struct mlx_atomic_update_helper {
+   uint operator()(uint_or_packed<T> init, T update, size_t elem_offset) {
+     Op op;
+     init.val[elem_offset] = op(update, init.val[elem_offset]);
+     return init.bits;
+   }
+ };
+
+ template <typename T, typename Op>
+ METAL_FUNC void mlx_atomic_update_and_store(
+     device mlx_atomic<T>* object,
+     T update,
+     size_t offset) {
+   size_t pack_offset = offset / packing_size<T>;
+   size_t elem_offset = offset % packing_size<T>;
+
+   mlx_atomic_update_helper<T, Op> helper;
+   uint_or_packed<T> expected;
+   expected.bits =
+       atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
+
+   while (Op::condition(update, expected.val[elem_offset]) &&
+          !mlx_atomic_compare_exchange_weak_explicit(
+              object,
+              &(expected.bits),
+              helper(expected, update, elem_offset),
+              pack_offset)) {
+   }
+ }
+
+ template <typename T>
+ struct __None {
+   static bool condition(T a, T b) {
+ #pragma unused(a)
+ #pragma unused(b)
+     return true;
+   }
+
+   T operator()(T a, T b) {
+ #pragma unused(b)
+     return a;
+   }
+ };
+
+ template <typename T>
+ struct __Add {
+   static bool condition(T a, T b) {
+ #pragma unused(a)
+ #pragma unused(b)
+     return true;
+   }
+
+   T operator()(T a, T b) {
+     return a + b;
+   }
+ };
+
+ template <typename T>
+ struct __Mul {
+   static bool condition(T a, T b) {
+ #pragma unused(a)
+     return b != 0;
+   }
+
+   T operator()(T a, T b) {
+     return a * b;
+   }
+ };
+
+ template <typename T>
+ struct __Max {
+   static bool condition(T a, T b) {
+     return a > b;
+   }
+
+   T operator()(T a, T b) {
+     return max(a, b);
+   }
+ };
+
+ template <typename T>
+ struct __Min {
+   static bool condition(T a, T b) {
+     return a < b;
+   }
+
+   T operator()(T a, T b) {
+     return min(a, b);
+   }
+ };
+
+ } // namespace
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC T
+ mlx_atomic_load_explicit(device mlx_atomic<T>* object, size_t offset) {
+   size_t pack_offset = offset / sizeof(T);
+   size_t elem_offset = offset % sizeof(T);
+   uint_or_packed<T> packed_val;
+   packed_val.bits =
+       atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed);
+   return packed_val.val[elem_offset];
+ }
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void
+ mlx_atomic_store_explicit(device mlx_atomic<T>* object, T val, size_t offset) {
+   mlx_atomic_update_and_store<T, __None<T>>(object, val, offset);
+ }
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_and_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   size_t pack_offset = offset / packing_size<T>;
+   size_t elem_offset = offset % packing_size<T>;
+   uint_or_packed<T> identity;
+   identity.bits = __UINT32_MAX__;
+   identity.val[elem_offset] = val;
+
+   atomic_fetch_and_explicit(
+       &(object[pack_offset].val), identity.bits, memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_or_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   size_t pack_offset = offset / packing_size<T>;
+   size_t elem_offset = offset % packing_size<T>;
+   uint_or_packed<T> identity;
+   identity.bits = 0;
+   identity.val[elem_offset] = val;
+
+   atomic_fetch_or_explicit(
+       &(object[pack_offset].val), identity.bits, memory_order_relaxed);
+ }
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_min_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   mlx_atomic_update_and_store<T, __Min<T>>(object, val, offset);
+ }
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_max_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   mlx_atomic_update_and_store<T, __Max<T>>(object, val, offset);
+ }
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_add_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   mlx_atomic_update_and_store<T, __Add<T>>(object, val, offset);
+ }
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC void mlx_atomic_fetch_mul_explicit(
+     device mlx_atomic<T>* object,
+     T val,
+     size_t offset) {
+   mlx_atomic_update_and_store<T, __Mul<T>>(object, val, offset);
+ }
+
+ template <typename T, enable_if_t<!is_metal_atomic<T>, bool> = true>
+ METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit(
+     device mlx_atomic<T>* object,
+     thread uint* expected,
+     uint val,
+     size_t offset) {
+   return atomic_compare_exchange_weak_explicit(
+       &(object[offset].val),
+       expected,
+       val,
+       memory_order_relaxed,
+       memory_order_relaxed);
+ }
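Note: for types without native Metal atomics (the !is_metal_atomic overloads above), every update goes through a packed 32-bit word. The offset is split into a word index and a lane index, the candidate word is built by applying the op to just that lane, and a compare-exchange loop retries until the update lands or Op::condition reports it is no longer needed. The sketch below shows the same pattern in standard C++ with std::atomic; it is not MLX code, and the uint16_t element type and function name are assumptions for the example.

#include <atomic>
#include <cstdint>
#include <cstdio>

constexpr std::size_t kPack = sizeof(uint32_t) / sizeof(uint16_t);  // 2 lanes per word

// Type punning through a union, as the Metal header above does.
union UintOrPacked {
  uint16_t val[kPack];
  uint32_t bits;
};

// CAS-based "fetch max" on one 16-bit lane of a 32-bit atomic word.
void fetch_max_u16(std::atomic<uint32_t>* words, std::size_t offset, uint16_t update) {
  std::size_t pack_offset = offset / kPack;  // which 32-bit word
  std::size_t elem_offset = offset % kPack;  // which lane inside it
  UintOrPacked expected;
  expected.bits = words[pack_offset].load(std::memory_order_relaxed);
  while (update > expected.val[elem_offset]) {  // Op::condition for Max
    UintOrPacked desired = expected;
    desired.val[elem_offset] = update;          // apply the op to one lane only
    if (words[pack_offset].compare_exchange_weak(
            expected.bits, desired.bits, std::memory_order_relaxed)) {
      return;  // our packed word was installed
    }
    // on failure, expected.bits now holds the current word; re-check and retry
  }
}

int main() {
  std::atomic<uint32_t> word(0);
  fetch_max_u16(&word, 1, 7);  // update lane 1 of the single word
  UintOrPacked view;
  view.bits = word.load();
  std::printf("%u %u\n", view.val[0], view.val[1]);  // prints: 0 7
}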
mlx/include/mlx/backend/metal/kernels/bf16.h
@@ -0,0 +1,16 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include <metal_stdlib>
+
+ using namespace metal;
+
+ typedef bfloat bfloat16_t;
+ inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
+   return as_type<uint16_t>(x);
+ }
+
+ inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
+   return as_type<bfloat16_t>(x);
+ }
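Note: these helpers only reinterpret bits between Metal's native bfloat type and uint16_t. The underlying relationship they rely on is that a bfloat16 is the upper 16 bits of an IEEE-754 float32. A host-side sketch of that relationship follows (simple truncation rather than round-to-nearest; the function names are illustrative, not MLX APIs).

#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t float_to_bfloat16_bits(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));       // view the float32 bit pattern
  return static_cast<uint16_t>(bits >> 16);   // keep sign, exponent, top 7 mantissa bits
}

float bfloat16_bits_to_float(uint16_t b) {
  uint32_t bits = static_cast<uint32_t>(b) << 16;  // low mantissa bits become zero
  float x;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  uint16_t b = float_to_bfloat16_bits(3.14159f);
  std::printf("0x%04x -> %g\n", b, bfloat16_bits_to_float(b));  // 0x4049 -> 3.140625
}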