mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231) hide show
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,451 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/metal/kernels/steel/utils.h"
6
+
7
+ #include "mlx/backend/metal/kernels/steel/conv/params.h"
8
+
9
+ ///////////////////////////////////////////////////////////////////////////////
10
+ // Loading helper
11
+ ///////////////////////////////////////////////////////////////////////////////
12
+
13
+ namespace mlx {
14
+ namespace steel {
15
+
16
+ template <
17
+ typename T,
18
+ short BM,
19
+ short BN,
20
+ short BK,
21
+ short tgp_size,
22
+ short tgp_padding = 0>
23
+ struct Conv2DInputBlockLoaderLargeFilter { // Copies a BM x BK input tile from device to threadgroup memory, testing bounds on every load
24
+ // Destination tile dimensions in threadgroup memory: BM rows by BK columns
25
+ STEEL_CONST short BROWS = BM;
26
+ STEEL_CONST short BCOLS = BK;
27
+
28
+ // Destination row stride (tile width plus padding) and elements copied per thread per row
29
+ STEEL_CONST short dst_ld = BCOLS + tgp_padding;
30
+ STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
31
+
32
+ // Thread layout over the tile: TCOLS threads cover one row, TROWS rows are covered per pass
33
+ STEEL_CONST short TCOLS = BCOLS / vec_size;
34
+ STEEL_CONST short TROWS = tgp_size / TCOLS;
35
+
36
+ // Number of rows each thread handles; consecutive rows are TROWS apart
37
+ STEEL_CONST short n_rows = BROWS / TROWS;
38
+
39
+ // Thread location: linear index, then row (bi) and first column (bj) within the tile
40
+ const short thread_idx;
41
+ const short bi;
42
+ const short bj;
43
+
44
+ // threadgroup destination, already offset to this thread's (bi, bj) position
45
+ threadgroup T* dst;
46
+
47
+ const constant MLXConvParams<2>* params;
48
+ const constant ImplicitGemmConv2DParams* gemm_params;
49
+
50
+ short weight_h; // current filter index along h, advanced by next()
51
+ short weight_w; // current filter index along w, advanced by next()
52
+
53
+ const device T* src[n_rows]; // one device read pointer per handled row
54
+
55
+ int read_n[n_rows]; // index n decoded from the flattened output position (compared to params->N)
56
+ int read_ih[n_rows]; // oh * str[0] - pad[0], stored before any flip adjustment
57
+ int read_iw[n_rows]; // ow * str[1] - pad[1], stored before any flip adjustment
58
+
59
+ /* Constructor: decodes this thread's output positions and sets up one read pointer per row */
60
+ METAL_FUNC Conv2DInputBlockLoaderLargeFilter(
61
+ const device T* src_,
62
+ threadgroup T* dst_,
63
+ const int2 offsets,
64
+ const constant MLXConvParams<2>* params_,
65
+ const constant ImplicitGemmConv2DParams* gemm_params_,
66
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
67
+ uint simd_lane_id [[thread_index_in_simdgroup]])
68
+ : thread_idx(simd_group_id * 32 + simd_lane_id),
69
+ bi(thread_idx / TCOLS),
70
+ bj(vec_size * (thread_idx % TCOLS)),
71
+ dst(dst_ + bi * dst_ld + bj),
72
+ params(params_),
73
+ gemm_params(gemm_params_),
74
+ weight_h(0),
75
+ weight_w(0) {
76
+ int out_n_pixels = params->oS[0] * params->oS[1];
77
+
78
+ STEEL_PRAGMA_UNROLL
79
+ for (short i = 0; i < n_rows; ++i) {
80
+ int offset_nhw = offsets.y + bi + i * TROWS; // flattened (n, oh, ow) position of this row
81
+ int n = offset_nhw / out_n_pixels;
82
+ int hw = offset_nhw % out_n_pixels;
83
+ int oh = hw / params->oS[1];
84
+ int ow = hw % params->oS[1];
85
+
86
+ int ih = oh * params->str[0] - params->pad[0];
87
+ int iw = ow * params->str[1] - params->pad[1];
88
+
89
+ read_n[i] = n;
90
+ read_ih[i] = ih;
91
+ read_iw[i] = iw;
92
+
93
+ // Adjust for flip: only the pointer origin moves; read_ih / read_iw above keep the unflipped values
94
+ if (params->flip) {
95
+ ih += (params->wS[0] - 1) * params->kdil[0];
96
+ iw += (params->wS[1] - 1) * params->kdil[1];
97
+ }
98
+
99
+ // Base read pointer for this row (may point out of bounds; load_unsafe() tests before reading)
100
+ src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
101
+ iw * params->in_strides[2] + bj;
102
+ }
103
+ }
104
+
105
+ /* Load from device memory into threadgroup memory; out-of-bounds positions are written as zero */
106
+ METAL_FUNC void load_unsafe() const {
107
+ STEEL_PRAGMA_UNROLL
108
+ for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
109
+ // Coordinates for the current filter index. NOTE(review): no flip adjustment here, unlike src in the constructor - confirm intended
110
+ int n = read_n[i];
111
+ int ih = read_ih[i] + weight_h * params->kdil[0];
112
+ int iw = read_iw[i] + weight_w * params->kdil[1];
113
+
114
+ // Read from input if in bounds
115
+ if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
116
+ (iw >= 0 && iw < params->iS[1])) {
117
+ STEEL_PRAGMA_UNROLL
118
+ for (short j = 0; j < vec_size; ++j) {
119
+ dst[is * dst_ld + j] = src[i][j];
120
+ }
121
+ }
122
+
123
+ // Zero pad otherwise
124
+ else {
125
+ STEEL_PRAGMA_UNROLL
126
+ for (short j = 0; j < vec_size; ++j) {
127
+ dst[is * dst_ld + j] = T(0);
128
+ }
129
+ }
130
+ }
131
+ }
132
+
133
+ /* Iteration helper: advances weight_w, then weight_h on wrap, applying the matching gemm_params jump to every src pointer */
134
+ METAL_FUNC void next() {
135
+ if (++weight_w < params->wS[1]) {
136
+ STEEL_PRAGMA_UNROLL
137
+ for (short i = 0; i < n_rows; i++) {
138
+ src[i] += gemm_params->inp_jump_w;
139
+ }
140
+
141
+ return;
142
+ }
143
+
144
+ weight_w = 0;
145
+
146
+ if (++weight_h < params->wS[0]) {
147
+ STEEL_PRAGMA_UNROLL
148
+ for (short i = 0; i < n_rows; i++) {
149
+ src[i] += gemm_params->inp_jump_h;
150
+ }
151
+
152
+ return;
153
+ }
154
+
155
+ weight_h = 0; // both filter indices wrapped: apply inp_jump_c below
156
+
157
+ STEEL_PRAGMA_UNROLL
158
+ for (short i = 0; i < n_rows; i++) {
159
+ src[i] += gemm_params->inp_jump_c;
160
+ }
161
+ }
162
+ };
163
+
164
+ template <
165
+ typename T,
166
+ short BM,
167
+ short BN,
168
+ short BK,
169
+ short tgp_size,
170
+ short tgp_padding = 0>
171
+ struct Conv2DInputBlockLoaderSmallFilter { // Copies a BM x BK input tile from device to threadgroup memory using in-bounds bitmasks built in the constructor
172
+ // Destination tile dimensions in threadgroup memory: BM rows by BK columns
173
+ STEEL_CONST short BROWS = BM;
174
+ STEEL_CONST short BCOLS = BK;
175
+
176
+ // Destination row stride (tile width plus padding) and elements copied per thread per row
177
+ STEEL_CONST short dst_ld = BCOLS + tgp_padding;
178
+ STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
179
+
180
+ // Thread layout over the tile: TCOLS threads cover one row, TROWS rows are covered per pass
181
+ STEEL_CONST short TCOLS = BCOLS / vec_size;
182
+ STEEL_CONST short TROWS = tgp_size / TCOLS;
183
+
184
+ // Number of rows each thread handles; consecutive rows are TROWS apart
185
+ STEEL_CONST short n_rows = BROWS / TROWS;
186
+
187
+ using mask_t = short; // bitmask type: bit k holds the in-bounds flag for filter index k
188
+
189
+ // Thread location: linear index, then row (bi) and first column (bj) within the tile
190
+ const short thread_idx;
191
+ const short bi;
192
+ const short bj;
193
+
194
+ // threadgroup destination, already offset to this thread's (bi, bj) position
195
+ threadgroup T* dst;
196
+
197
+ const constant MLXConvParams<2>* params;
198
+ const constant ImplicitGemmConv2DParams* gemm_params;
199
+
200
+ short weight_h; // current filter index along h, advanced by next()
201
+ short weight_w; // current filter index along w, advanced by next()
202
+
203
+ const device T* src[n_rows]; // one device read pointer per handled row
204
+
205
+ mask_t mask_h[n_rows]; // per row: bit kh set when filter index kh lands inside [0, iS[0]) and n < N
206
+ mask_t mask_w[n_rows]; // per row: bit kw set when filter index kw lands inside [0, iS[1])
207
+
208
+ /* Constructor: decodes output positions, sets up one read pointer per row, and precomputes the bounds masks */
209
+ METAL_FUNC Conv2DInputBlockLoaderSmallFilter(
210
+ const device T* src_,
211
+ threadgroup T* dst_,
212
+ const int2 offsets,
213
+ const constant MLXConvParams<2>* params_,
214
+ const constant ImplicitGemmConv2DParams* gemm_params_,
215
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
216
+ uint simd_lane_id [[thread_index_in_simdgroup]])
217
+ : thread_idx(simd_group_id * 32 + simd_lane_id),
218
+ bi(thread_idx / TCOLS),
219
+ bj(vec_size * (thread_idx % TCOLS)),
220
+ dst(dst_ + bi * dst_ld + bj),
221
+ params(params_),
222
+ gemm_params(gemm_params_),
223
+ weight_h(0),
224
+ weight_w(0) {
225
+ int out_n_pixels = params->oS[0] * params->oS[1];
226
+
227
+ int read_n[n_rows]; // constructor-local scratch, used only to build the masks below
228
+ int read_ih[n_rows];
229
+ int read_iw[n_rows];
230
+
231
+ STEEL_PRAGMA_UNROLL
232
+ for (short i = 0; i < n_rows; ++i) {
233
+ int offset_nhw = offsets.y + bi + i * TROWS; // flattened (n, oh, ow) position of this row
234
+ int n = offset_nhw / out_n_pixels;
235
+ int hw = offset_nhw % out_n_pixels;
236
+ int oh = hw / params->oS[1];
237
+ int ow = hw % params->oS[1];
238
+
239
+ int ih = oh * params->str[0] - params->pad[0];
240
+ int iw = ow * params->str[1] - params->pad[1];
241
+
242
+ read_n[i] = n;
243
+ read_ih[i] = ih;
244
+ read_iw[i] = iw;
245
+
246
+ // Adjust for flip: only the pointer origin moves; read_ih / read_iw above keep the unflipped values
247
+ if (params->flip) {
248
+ ih += (params->wS[0] - 1) * params->kdil[0];
249
+ iw += (params->wS[1] - 1) * params->kdil[1];
250
+ }
251
+
252
+ // Base read pointer for this row (may point out of bounds; load_unsafe() consults the masks before reading)
253
+ src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
254
+ iw * params->in_strides[2] + bj;
255
+ }
256
+
257
+ STEEL_PRAGMA_UNROLL
258
+ for (short i = 0; i < n_rows; ++i) {
259
+ mask_h[i] = 0;
260
+ mask_w[i] = 0;
261
+ }
262
+
263
+ for (short kh = 0; kh < params->wS[0]; kh++) { // build mask_h: one bit per filter index along h
264
+ short flip_h = params->flip ? params->wS[0] - kh - 1 : kh; // flipped filters are tested at the mirrored offset
265
+ STEEL_PRAGMA_UNROLL
266
+ for (short i = 0; i < n_rows; ++i) {
267
+ int n = read_n[i];
268
+ int ih = read_ih[i] + flip_h * params->kdil[0];
269
+
270
+ bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0];
271
+
272
+ mask_h[i] |= (in_bounds << kh);
273
+ }
274
+ }
275
+
276
+ for (short kw = 0; kw < params->wS[1]; kw++) { // build mask_w: one bit per filter index along w
277
+ short flip_w = params->flip ? params->wS[1] - kw - 1 : kw; // flipped filters are tested at the mirrored offset
278
+ STEEL_PRAGMA_UNROLL
279
+ for (short i = 0; i < n_rows; ++i) {
280
+ int iw = read_iw[i] + flip_w * params->kdil[1];
281
+
282
+ bool in_bounds = iw >= 0 && iw < params->iS[1];
283
+
284
+ mask_w[i] |= (in_bounds << kw);
285
+ }
286
+ }
287
+ }
288
+
289
+ /* Load from device memory into threadgroup memory; positions whose mask bit is clear are written as zero */
290
+ METAL_FUNC void load_unsafe() const {
291
+ mask_t h_mask = mask_t(1) << weight_h; // selects the bit for the current filter index along h
292
+ mask_t w_mask = mask_t(1) << weight_w; // selects the bit for the current filter index along w
293
+
294
+ STEEL_PRAGMA_UNROLL
295
+ for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
296
+ // Read from input if both precomputed mask bits mark this position in bounds
297
+ if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) {
298
+ STEEL_PRAGMA_UNROLL
299
+ for (short j = 0; j < vec_size; ++j) {
300
+ dst[is * dst_ld + j] = src[i][j];
301
+ }
302
+ }
303
+
304
+ // Zero pad otherwise
305
+ else {
306
+ STEEL_PRAGMA_UNROLL
307
+ for (short j = 0; j < vec_size; ++j) {
308
+ dst[is * dst_ld + j] = T(0);
309
+ }
310
+ }
311
+ }
312
+ }
313
+
314
+ /* Iteration helper: advances weight_w, then weight_h on wrap, applying the matching gemm_params jump to every src pointer */
315
+ METAL_FUNC void next() {
316
+ if (++weight_w < params->wS[1]) {
317
+ STEEL_PRAGMA_UNROLL
318
+ for (short i = 0; i < n_rows; i++) {
319
+ src[i] += gemm_params->inp_jump_w;
320
+ }
321
+
322
+ return;
323
+ }
324
+
325
+ weight_w = 0;
326
+
327
+ if (++weight_h < params->wS[0]) {
328
+ STEEL_PRAGMA_UNROLL
329
+ for (short i = 0; i < n_rows; i++) {
330
+ src[i] += gemm_params->inp_jump_h;
331
+ }
332
+
333
+ return;
334
+ }
335
+
336
+ weight_h = 0; // both filter indices wrapped: apply inp_jump_c below
337
+
338
+ STEEL_PRAGMA_UNROLL
339
+ for (short i = 0; i < n_rows; i++) {
340
+ src[i] += gemm_params->inp_jump_c;
341
+ }
342
+ }
343
+ };
344
+
345
+ template <
346
+ typename T,
347
+ short BM,
348
+ short BN,
349
+ short BK,
350
+ short tgp_size,
351
+ short tgp_padding = 0>
352
+ struct Conv2DWeightBlockLoader { // Copies a BN x BK weight tile from device memory into threadgroup memory
353
+ // Destination tile dimensions in threadgroup memory: BN rows by BK columns
354
+ STEEL_CONST short BROWS = BN;
355
+ STEEL_CONST short BCOLS = BK;
356
+
357
+ // Destination row stride (tile width plus padding) and elements copied per thread per row (1 when BN == 8)
358
+ STEEL_CONST short dst_ld = BCOLS + tgp_padding;
359
+ STEEL_CONST short vec_size =
360
+ (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);
361
+
362
+ // Thread layout over the tile: TCOLS threads cover one row, TROWS rows are covered per pass
363
+ STEEL_CONST short TCOLS = BCOLS / vec_size;
364
+ STEEL_CONST short TROWS = tgp_size / TCOLS;
365
+
366
+ // Number of rows each thread handles; consecutive rows are TROWS apart
367
+ STEEL_CONST short n_rows = BROWS / TROWS;
368
+
369
+ // Leading dimension for src, taken from params->wt_strides[0]
370
+ const int src_ld;
371
+
372
+ // Thread location: linear index, then row (bi) and first column (bj) within the tile
373
+ const short thread_idx;
374
+ const short bi;
375
+ const short bj;
376
+
377
+ // threadgroup destination and device source, both already offset to this thread's (bi, bj) position
378
+ threadgroup T* dst;
379
+ const device T* src;
380
+
381
+ const constant MLXConvParams<2>* params;
382
+
383
+ int weight_hw; // flattened filter position counter; next() wraps it at wS[0] * wS[1]
384
+ int weight_step; // src advance per filter position: params->C / params->groups
385
+
386
+ const int read_n; // offsets.y + bi: index of this thread's first row, compared to params->O in load_unsafe()
387
+ const bool do_read; // true when read_n + n_rows * TROWS <= gemm_params_->N; selects the unchecked copy path
388
+
389
+ /* Constructor: positions src / dst at this thread's tile location and decides whether per-row checks are needed */
390
+ METAL_FUNC Conv2DWeightBlockLoader(
391
+ const device T* src_,
392
+ threadgroup T* dst_,
393
+ const int2 offsets,
394
+ const constant MLXConvParams<2>* params_,
395
+ const constant ImplicitGemmConv2DParams* gemm_params_,
396
+ uint simd_group_id [[simdgroup_index_in_threadgroup]],
397
+ uint simd_lane_id [[thread_index_in_simdgroup]])
398
+ : src_ld(params_->wt_strides[0]),
399
+ thread_idx(simd_group_id * 32 + simd_lane_id),
400
+ bi(thread_idx / TCOLS),
401
+ bj(vec_size * (thread_idx % TCOLS)),
402
+ dst(dst_ + bi * dst_ld + bj),
403
+ src(src_ + bi * src_ld + bj),
404
+ params(params_),
405
+ weight_hw(0),
406
+ weight_step(params->C / params->groups),
407
+ read_n(offsets.y + bi),
408
+ do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
409
+
410
+ /* Load from device memory into threadgroup memory; rows are checked against params->O only when BN == 8 and do_read is false */
411
+ METAL_FUNC void load_unsafe() const {
412
+ if (BN != 8 || do_read) { // unchecked path: every row is copied as-is
413
+ STEEL_PRAGMA_UNROLL
414
+ for (short i = 0; i < BN; i += TROWS) {
415
+ STEEL_PRAGMA_UNROLL
416
+ for (short j = 0; j < vec_size; j++) {
417
+ dst[i * dst_ld + j] = src[i * src_ld + j];
418
+ }
419
+ }
420
+ } else { // checked path: rows at or past params->O are zero-filled
421
+ for (short i = 0; i < BN; i += TROWS) {
422
+ if ((read_n + i) < params->O) {
423
+ STEEL_PRAGMA_UNROLL
424
+ for (short j = 0; j < vec_size; j++) {
425
+ dst[i * dst_ld + j] = src[i * src_ld + j];
426
+ }
427
+ } else {
428
+ STEEL_PRAGMA_UNROLL
429
+ for (short j = 0; j < vec_size; j++) {
430
+ dst[i * dst_ld + j] = T(0);
431
+ }
432
+ }
433
+ }
434
+ }
435
+ }
436
+
437
+ /* Iteration helper: steps src by weight_step per filter position; after the last one, undoes those steps and advances by BK */
438
+ METAL_FUNC void next() {
439
+ if (++weight_hw < (params->wS[1] * params->wS[0])) {
440
+ src += weight_step;
441
+ return;
442
+ }
443
+
444
+ weight_hw = 0;
445
+
446
+ src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; // rewind the (wS[0] * wS[1] - 1) steps taken, then move BK columns on
447
+ }
448
+ };
449
+
450
+ } // namespace steel
451
+ } // namespace mlx