mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,319 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/steel/utils.h"
6
+
7
+ #include "mlx/backend/metal/kernels/steel/conv/params.h"
8
+
9
+ ///////////////////////////////////////////////////////////////////////////////
10
+ // Loading helper
11
+ ///////////////////////////////////////////////////////////////////////////////
12
+
13
+ namespace mlx {
14
+ namespace steel {
15
+
16
// Compile-time padding description for a given channel count.
//
//   n_channels : the channel count the helper is instantiated with
//   vec_size   : 1 or 2 when n_channels is 1 or 2, 4 for any other count up
//                to 4, and 8 for anything larger
//   excess     : vec_size - n_channels
//
// One definition covers every channel count, so no explicit specializations
// are needed.
template <short n_channels_>
struct ChannelHelper {
  STEEL_CONST short n_channels = n_channels_;
  STEEL_CONST short vec_size = (n_channels_ == 1 || n_channels_ == 2)
      ? n_channels_
      : (n_channels_ <= 4 ? 4 : 8);
  STEEL_CONST short excess = vec_size - n_channels_;
};
50
+
51
// Loads a BM x BK tile of convolution input into threadgroup memory for
// inputs with a small number of channels. Each thread handles one filter tap
// (weight_hw) and writes vec_size elements per row it visits: n_channels
// values read from the input followed by zeros up to vec_size.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short n_channels,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallChannels {
  // Shape of the destination tile
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Destination row pitch and number of elements written per thread per row
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;

  // Arrangement of the threads over the tile
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Number of strided rows visited by each thread
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Linear thread index and the (row, column) it maps to inside the tile
  const short thread_idx;
  const short bi;
  const short bj;

  // Destination pointer, already advanced to (bi, bj)
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  // Flattened filter tap currently handled by this thread
  int weight_hw;

  // Per-row source pointer at the tap-independent input position
  const device T* src[n_rows];

  // Per-row batch index and tap-independent input coordinates
  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderSmallChannels(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_hw(thread_idx % TCOLS) {
    const int pixels_per_image = params->oS[0] * params->oS[1];

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < n_rows; ++r) {
      // Decompose the flattened row index into (batch, output row, output col)
      const int flat_idx = offsets.y + bi + r * TROWS;
      const int batch = flat_idx / pixels_per_image;
      const int pixel = flat_idx % pixels_per_image;
      const int out_h = pixel / params->oS[1];
      const int out_w = pixel % params->oS[1];

      // Input coordinate before the filter tap offset is added
      read_n[r] = batch;
      read_ih[r] = out_h * params->str[0] - params->pad[0];
      read_iw[r] = out_w * params->str[1] - params->pad[1];

      // The pointer may lie outside the input; load_unsafe() checks the
      // coordinates before dereferencing it
      src[r] = src_ + batch * params->in_strides[0] +
          read_ih[r] * params->in_strides[1] +
          read_iw[r] * params->in_strides[2];
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    // A tap index past the end of the filter contributes only zeros
    if (weight_hw >= params->wS[1] * params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short row = 0; row < BROWS; row += TROWS) {
        threadgroup T* out = dst + row * dst_ld;
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; c++) {
          out[c] = T(0);
        }
      }
      return;
    }

    // Split the flattened tap into (row, col), mirroring it when flip is set
    const int tap_h = weight_hw / params->wS[1];
    const int tap_w = weight_hw % params->wS[1];

    int weight_h = tap_h;
    int weight_w = tap_w;
    if (params->flip) {
      weight_h = params->wS[0] - tap_h - 1;
      weight_w = params->wS[1] - tap_w - 1;
    }

    // Apply the kernel dilation
    weight_h *= params->kdil[0];
    weight_w *= params->kdil[1];

    STEEL_PRAGMA_UNROLL
    for (short r = 0, row = 0; r < n_rows; ++r, row += TROWS) {
      threadgroup T* out = dst + row * dst_ld;

      const int ih = read_ih[r] + weight_h;
      const int iw = read_iw[r] + weight_w;

      const bool outside = (read_n[r] >= params->N) || (ih < 0) ||
          (ih >= params->iS[0]) || (iw < 0) || (iw >= params->iS[1]);

      // Zero pad when the read would fall outside the input
      if (outside) {
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; ++c) {
          out[c] = T(0);
        }
        continue;
      }

      // Read from input: copy the channels, then zero the padding elements
      const device T* curr_src = src[r] + weight_h * params->in_strides[1] +
          weight_w * params->in_strides[2];

      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < n_channels; ++c) {
        out[c] = curr_src[c];
      }

      STEEL_PRAGMA_UNROLL
      for (short c = n_channels; c < vec_size; ++c) {
        out[c] = T(0);
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    // Advance to the next group of TCOLS filter taps
    weight_hw += TCOLS;
  }
};
194
+
195
// Loads a BN x BK tile of convolution weights into threadgroup memory for
// inputs with a small number of channels. Each thread handles one filter tap
// (weight_hw) and writes vec_size elements per row it visits: n_channels
// values read from the weights followed by zeros up to vec_size.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short n_channels,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderSmallChannels {
  // Shape of the destination tile
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Destination row pitch and number of elements written per thread per row
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;

  // Arrangement of the threads over the tile
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Number of strided rows visited by each thread
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Source row pitch, taken from params->wt_strides[0]
  const int src_ld;

  // Linear thread index and the (row, column) it maps to inside the tile
  const short thread_idx;
  const short bi;
  const short bj;

  // Destination advanced to (bi, bj); source advanced to row bi
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;

  // Flattened filter tap currently handled by this thread
  int weight_hw;

  // First weight row read by this thread, and whether the per-row bound
  // check can be skipped when BN == 8
  const int read_n;
  const bool do_read;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld),
        params(params_),
        weight_hw(thread_idx % TCOLS),
        read_n(offsets.y + bi),
        do_read(read_n + BN <= gemm_params_->N) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    // Threads mapped outside the tile have nothing to write
    if (bi >= BROWS || bj >= BCOLS) {
      return;
    }

    // Rows past the last filter and taps past the end of the filter
    // contribute only zeros
    const bool past_last_row = read_n >= params->O;
    if (past_last_row || weight_hw >= params->wS[1] * params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short row = 0; row < BROWS; row += TROWS) {
        threadgroup T* out = dst + row * dst_ld;
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; c++) {
          out[c] = T(0);
        }
      }

      return;
    }

    // Move to the tap handled by this thread
    const device T* curr_src = src + weight_hw * (params->C / params->groups);

    if (BN == 8 && !do_read) {
      // Check every row against params->O before reading it
      for (short row = 0; row < BROWS; row += TROWS) {
        threadgroup T* out = dst + row * dst_ld;

        if (((read_n + row) < params->O)) {
          const device T* in = curr_src + row * src_ld;

          STEEL_PRAGMA_UNROLL
          for (short c = 0; c < n_channels; c++) {
            out[c] = in[c];
          }

          STEEL_PRAGMA_UNROLL
          for (short c = n_channels; c < vec_size; c++) {
            out[c] = T(0);
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short c = 0; c < vec_size; c++) {
            out[c] = T(0);
          }
        }
      }
      return;
    }

    // No per-row check: copy the channels, then zero the padding elements
    STEEL_PRAGMA_UNROLL
    for (short row = 0; row < BROWS; row += TROWS) {
      threadgroup T* out = dst + row * dst_ld;
      const device T* in = curr_src + row * src_ld;

      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < n_channels; c++) {
        out[c] = in[c];
      }

      STEEL_PRAGMA_UNROLL
      for (short c = n_channels; c < vec_size; c++) {
        out[c] = T(0);
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    // Advance to the next group of TCOLS filter taps
    weight_hw += TCOLS;
  }
};
317
+
318
+ } // namespace steel
319
+ } // namespace mlx
@@ -0,0 +1,381 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/steel/defines.h"
6
+
7
+ ///////////////////////////////////////////////////////////////////////////////
8
+ // Loading helper
9
+ ///////////////////////////////////////////////////////////////////////////////
10
+
11
+ namespace mlx {
12
+ namespace steel {
13
+
14
// Loads a BM x BK tile of convolution input into threadgroup memory for the
// general 2D convolution path. The filter tap (weight_h, weight_w) is stepped
// by the jump parameters in next(); once both tap coordinates wrap, the
// source pointers advance by BK.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderGeneral {
  // Shape of the destination tile
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Destination row pitch and number of elements written per thread per row
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Arrangement of the threads over the tile
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Number of strided rows visited by each thread
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Linear thread index and the (row, column) it maps to inside the tile
  const short thread_idx;
  const short bi;
  const short bj;

  // Destination pointer, already advanced to (bi, bj)
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  // Tap coordinates that weight_h / weight_w are reset to when they wrap
  const short base_wh;
  const short base_ww;

  // Filter tap currently being loaded
  short weight_h;
  short weight_w;

  // Per-row source pointer, offset by batch and by column bj only
  const device T* src[n_rows];

  // Per-row batch index and tap-independent input coordinates
  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int4 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_) {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < n_rows; ++r) {
      // Decompose the flattened row index using the adjusted output extents
      const int flat_idx = offsets.y + bi + r * TROWS;
      const int batch = flat_idx / jump_params->adj_out_hw;
      const int pixel = flat_idx % jump_params->adj_out_hw;

      const int out_h = (pixel / jump_params->adj_out_w) *
              jump_params->f_out_jump_h +
          offsets.z;
      const int out_w = (pixel % jump_params->adj_out_w) *
              jump_params->f_out_jump_w +
          offsets.w;

      // Input coordinate before the filter tap offset is added
      read_n[r] = batch;
      read_ih[r] = out_h * params->str[0] - params->pad[0];
      read_iw[r] = out_w * params->str[1] - params->pad[1];

      // The spatial offset is added at load time, once the tap is known
      src[r] = src_ + batch * params->in_strides[0] + bj;
    }
  }

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    // Tap to read at, mirrored when flip is set (same for every row)
    const int h_tap = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
    const int w_tap = params->flip ? params->wS[1] - weight_w - 1 : weight_w;

    STEEL_PRAGMA_UNROLL
    for (short r = 0, row = 0; r < n_rows; ++r, row += TROWS) {
      threadgroup T* out = dst + row * dst_ld;

      // Coordinates with kernel dilation applied
      const int ih_dil = read_ih[r] + h_tap * params->kdil[0];
      const int iw_dil = read_iw[r] + w_tap * params->kdil[1];

      // Coordinates after dividing by the input dilation
      const int ih = ih_dil / params->idil[0];
      const int iw = iw_dil / params->idil[1];

      const bool in_bounds = (read_n[r] < params->N) &&
          (ih_dil >= 0 && ih < params->iS[0]) &&
          (iw_dil >= 0 && iw < params->iS[1]);

      if (in_bounds) {
        // Read from input
        const size_t offset =
            ih * params->in_strides[1] + iw * params->in_strides[2];
        const device T* in = src[r] + offset;

        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; ++c) {
          out[c] = in[c];
        }
      } else {
        // Zero pad
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; ++c) {
          out[c] = T(0);
        }
      }
    }
  }

  // Same as load_unsafe(), except that columns at or beyond remaining_k are
  // written as zero instead of being read from the input.
  METAL_FUNC void load_safe(const short remaining_k) const {
    // Tap to read at, mirrored when flip is set (same for every row)
    const int h_tap = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
    const int w_tap = params->flip ? params->wS[1] - weight_w - 1 : weight_w;

    STEEL_PRAGMA_UNROLL
    for (short r = 0, row = 0; r < n_rows; ++r, row += TROWS) {
      threadgroup T* out = dst + row * dst_ld;

      // Coordinates with kernel dilation applied
      const int ih_dil = read_ih[r] + h_tap * params->kdil[0];
      const int iw_dil = read_iw[r] + w_tap * params->kdil[1];

      // Coordinates after dividing by the input dilation
      const int ih = ih_dil / params->idil[0];
      const int iw = iw_dil / params->idil[1];

      const bool in_bounds = (read_n[r] < params->N) &&
          (ih_dil >= 0 && ih < params->iS[0]) &&
          (iw_dil >= 0 && iw < params->iS[1]);

      if (!in_bounds) {
        // Zero pad
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; ++c) {
          out[c] = T(0);
        }
        continue;
      }

      const size_t offset =
          ih * params->in_strides[1] + iw * params->in_strides[2];
      const device T* in = src[r] + offset;

      if (bj + vec_size <= remaining_k) {
        // Every column of this thread is within remaining_k
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; ++c) {
          out[c] = in[c];
        }
      } else {
        // Check each column against remaining_k
        for (short c = 0; c < vec_size; ++c) {
          if (bj + c < remaining_k) {
            out[c] = in[c];
          } else {
            out[c] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    // Step the tap along the width; on wrap, reset it and step the height;
    // when the height wraps as well, reset it and advance the sources by BK
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w >= params->wS[1]) {
      weight_w = base_ww;

      weight_h += jump_params->f_wgt_jump_h;
      if (weight_h >= params->wS[0]) {
        weight_h = base_wh;

        STEEL_PRAGMA_UNROLL
        for (short r = 0; r < n_rows; r++) {
          src[r] += BK;
        }
      }
    }
  }
};
208
+
209
// Loads a BN x BK tile of convolution weights into threadgroup memory for the
// general 2D convolution path. The filter tap (weight_h, weight_w) is stepped
// by the jump parameters in next(); once both tap coordinates wrap, the
// source pointer advances by BK.
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoaderGeneral {
  // Shape of the destination tile
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Destination row pitch and number of elements written per thread per row
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Arrangement of the threads over the tile
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Number of strided rows visited by each thread
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Source row pitch, taken from params->wt_strides[0]
  const int src_ld;

  // Linear thread index and the (row, column) it maps to inside the tile
  const short thread_idx;
  const short bi;
  const short bj;

  // Destination and source pointers, both advanced to (bi, bj)
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;
  const constant Conv2DGeneralJumpParams* jump_params;

  // Tap coordinates that weight_h / weight_w are reset to when they wrap
  const short base_wh;
  const short base_ww;

  // Filter tap currently being loaded
  short weight_h;
  short weight_w;

  // First weight row read by this thread (offsets.y + bi)
  const int start_row;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoaderGeneral(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant Conv2DGeneralJumpParams* jump_params_,
      const short base_wh_,
      const short base_ww_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        jump_params(jump_params_),
        base_wh(base_wh_),
        base_ww(base_ww_),
        weight_h(base_wh_),
        weight_w(base_ww_),
        start_row(offsets.y + bi) {}

  /* Load from device memory into threadgroup memory - without bound checking */
  METAL_FUNC void load_unsafe() const {
    // Move to the current filter tap
    const device T* curr_src = src + weight_h * params->wt_strides[1] +
        weight_w * params->wt_strides[2];

    if (!(start_row + BN <= params->O)) {
      // Check every row against params->O and zero the ones past it
      for (short row = 0; row < BN; row += TROWS) {
        threadgroup T* out = dst + row * dst_ld;

        if ((start_row + row) < params->O) {
          const device T* in = curr_src + row * src_ld;
          STEEL_PRAGMA_UNROLL
          for (short c = 0; c < vec_size; c++) {
            out[c] = in[c];
          }
        } else {
          STEEL_PRAGMA_UNROLL
          for (short c = 0; c < vec_size; c++) {
            out[c] = T(0);
          }
        }
      }
      return;
    }

    // No per-row check needed
    STEEL_PRAGMA_UNROLL
    for (short row = 0; row < BN; row += TROWS) {
      threadgroup T* out = dst + row * dst_ld;
      const device T* in = curr_src + row * src_ld;
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        out[c] = in[c];
      }
    }
  }

  // Same as load_unsafe(), except that columns at or beyond remaining_k are
  // written as zero instead of being read from the weights.
  METAL_FUNC void load_safe(const short remaining_k) const {
    // Move to the current filter tap
    const device T* curr_src = src + weight_h * params->wt_strides[1] +
        weight_w * params->wt_strides[2];

    // True when every column of this thread is within remaining_k
    const bool full_vec = bj + vec_size <= remaining_k;

    if (!(start_row + BN <= params->O)) {
      // Check every row against params->O and zero the ones past it
      for (short row = 0; row < BN; row += TROWS) {
        threadgroup T* out = dst + row * dst_ld;

        if (!((start_row + row) < params->O)) {
          STEEL_PRAGMA_UNROLL
          for (short c = 0; c < vec_size; c++) {
            out[c] = T(0);
          }
          continue;
        }

        const device T* in = curr_src + row * src_ld;

        if (full_vec) {
          STEEL_PRAGMA_UNROLL
          for (short c = 0; c < vec_size; c++) {
            out[c] = in[c];
          }
        } else {
          // Check each column against remaining_k
          for (short c = 0; c < vec_size; c++) {
            if (bj + c < remaining_k) {
              out[c] = in[c];
            } else {
              out[c] = T(0);
            }
          }
        }
      }
      return;
    }

    // No per-row check needed
    STEEL_PRAGMA_UNROLL
    for (short row = 0; row < BN; row += TROWS) {
      threadgroup T* out = dst + row * dst_ld;
      const device T* in = curr_src + row * src_ld;

      if (full_vec) {
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; c++) {
          out[c] = in[c];
        }
      } else {
        // Check each column against remaining_k
        for (short c = 0; c < vec_size; c++) {
          if (bj + c < remaining_k) {
            out[c] = in[c];
          } else {
            out[c] = T(0);
          }
        }
      }
    }
  }

  /* Iteration helper */
  METAL_FUNC void next() {
    // Step the tap along the width; on wrap, reset it and step the height;
    // when the height wraps as well, reset it and advance the source by BK
    weight_w += jump_params->f_wgt_jump_w;
    if (weight_w >= params->wS[1]) {
      weight_w = base_ww;

      weight_h += jump_params->f_wgt_jump_h;
      if (weight_h >= params->wS[0]) {
        weight_h = base_wh;

        src += BK;
      }
    }
  }
};
379
+
380
+ } // namespace steel
381
+ } // namespace mlx