mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. mlx/__main__.py +27 -0
  2. mlx/_reprlib_fix.py +16 -0
  3. mlx/extension.py +88 -0
  4. mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
  5. mlx/include/mlx/allocator.h +73 -0
  6. mlx/include/mlx/array.h +645 -0
  7. mlx/include/mlx/backend/common/binary.h +97 -0
  8. mlx/include/mlx/backend/common/broadcasting.h +11 -0
  9. mlx/include/mlx/backend/common/buffer_cache.h +157 -0
  10. mlx/include/mlx/backend/common/compiled.h +77 -0
  11. mlx/include/mlx/backend/common/copy.h +50 -0
  12. mlx/include/mlx/backend/common/hadamard.h +109 -0
  13. mlx/include/mlx/backend/common/matmul.h +67 -0
  14. mlx/include/mlx/backend/common/reduce.h +59 -0
  15. mlx/include/mlx/backend/common/slicing.h +20 -0
  16. mlx/include/mlx/backend/common/ternary.h +85 -0
  17. mlx/include/mlx/backend/common/unary.h +29 -0
  18. mlx/include/mlx/backend/common/utils.h +205 -0
  19. mlx/include/mlx/backend/cpu/arange.h +28 -0
  20. mlx/include/mlx/backend/cpu/available.h +9 -0
  21. mlx/include/mlx/backend/cpu/binary.h +517 -0
  22. mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
  23. mlx/include/mlx/backend/cpu/binary_two.h +166 -0
  24. mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
  25. mlx/include/mlx/backend/cpu/copy.h +36 -0
  26. mlx/include/mlx/backend/cpu/encoder.h +67 -0
  27. mlx/include/mlx/backend/cpu/eval.h +12 -0
  28. mlx/include/mlx/backend/cpu/gemm.h +26 -0
  29. mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  30. mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
  31. mlx/include/mlx/backend/cpu/lapack.h +80 -0
  32. mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  33. mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  34. mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
  35. mlx/include/mlx/backend/cpu/simd/math.h +193 -0
  36. mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  37. mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
  38. mlx/include/mlx/backend/cpu/simd/type.h +11 -0
  39. mlx/include/mlx/backend/cpu/slicing.h +21 -0
  40. mlx/include/mlx/backend/cpu/ternary.h +154 -0
  41. mlx/include/mlx/backend/cpu/threefry.h +21 -0
  42. mlx/include/mlx/backend/cpu/unary.h +281 -0
  43. mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
  44. mlx/include/mlx/backend/cuda/allocator.h +89 -0
  45. mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
  46. mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
  47. mlx/include/mlx/backend/cuda/cuda.h +10 -0
  48. mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
  49. mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
  50. mlx/include/mlx/backend/cuda/device/config.h +12 -0
  51. mlx/include/mlx/backend/cuda/device.h +189 -0
  52. mlx/include/mlx/backend/cuda/event.h +78 -0
  53. mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  54. mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
  55. mlx/include/mlx/backend/cuda/jit_module.h +119 -0
  56. mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
  57. mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  58. mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
  59. mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  60. mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
  61. mlx/include/mlx/backend/cuda/utils.h +46 -0
  62. mlx/include/mlx/backend/cuda/worker.h +55 -0
  63. mlx/include/mlx/backend/gpu/available.h +9 -0
  64. mlx/include/mlx/backend/gpu/copy.h +57 -0
  65. mlx/include/mlx/backend/gpu/eval.h +18 -0
  66. mlx/include/mlx/backend/gpu/slicing.h +36 -0
  67. mlx/include/mlx/backend/metal/allocator.h +79 -0
  68. mlx/include/mlx/backend/metal/binary.h +33 -0
  69. mlx/include/mlx/backend/metal/device.h +283 -0
  70. mlx/include/mlx/backend/metal/jit/includes.h +57 -0
  71. mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
  72. mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
  73. mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
  74. mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
  75. mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
  76. mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
  77. mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
  78. mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
  79. mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
  80. mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
  81. mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
  82. mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
  83. mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
  84. mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
  85. mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
  86. mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  87. mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
  88. mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
  89. mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
  90. mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
  91. mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
  92. mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  93. mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
  94. mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  95. mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  96. mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  97. mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  98. mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
  99. mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  100. mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  101. mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
  102. mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
  103. mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  104. mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  105. mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
  106. mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  107. mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  108. mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  109. mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  110. mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  111. mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  112. mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
  113. mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
  114. mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
  115. mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
  116. mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  117. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
  118. mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  119. mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  120. mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  121. mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  122. mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  123. mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  124. mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  125. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  126. mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  127. mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  128. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  129. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  130. mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  131. mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  132. mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
  133. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  134. mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
  135. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  136. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
  137. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  138. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
  139. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  140. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  141. mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  142. mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  143. mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  144. mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  145. mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
  146. mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  147. mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  148. mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  149. mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
  150. mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
  151. mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  152. mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
  153. mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
  154. mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
  155. mlx/include/mlx/backend/metal/matmul.h +144 -0
  156. mlx/include/mlx/backend/metal/metal.h +22 -0
  157. mlx/include/mlx/backend/metal/reduce.h +41 -0
  158. mlx/include/mlx/backend/metal/resident.h +32 -0
  159. mlx/include/mlx/backend/metal/scan.h +17 -0
  160. mlx/include/mlx/backend/metal/ternary.h +21 -0
  161. mlx/include/mlx/backend/metal/unary.h +21 -0
  162. mlx/include/mlx/backend/metal/utils.h +84 -0
  163. mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
  164. mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
  165. mlx/include/mlx/compile.h +44 -0
  166. mlx/include/mlx/compile_impl.h +69 -0
  167. mlx/include/mlx/device.h +31 -0
  168. mlx/include/mlx/distributed/distributed.h +60 -0
  169. mlx/include/mlx/distributed/distributed_impl.h +59 -0
  170. mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
  171. mlx/include/mlx/distributed/mpi/mpi.h +12 -0
  172. mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
  173. mlx/include/mlx/distributed/nccl/nccl.h +12 -0
  174. mlx/include/mlx/distributed/ops.h +56 -0
  175. mlx/include/mlx/distributed/primitives.h +156 -0
  176. mlx/include/mlx/distributed/reduction_ops.h +38 -0
  177. mlx/include/mlx/distributed/ring/ring.h +12 -0
  178. mlx/include/mlx/distributed/utils.h +67 -0
  179. mlx/include/mlx/dtype.h +115 -0
  180. mlx/include/mlx/dtype_utils.h +119 -0
  181. mlx/include/mlx/einsum.h +22 -0
  182. mlx/include/mlx/event.h +58 -0
  183. mlx/include/mlx/export.h +136 -0
  184. mlx/include/mlx/export_impl.h +98 -0
  185. mlx/include/mlx/fast.h +102 -0
  186. mlx/include/mlx/fast_primitives.h +427 -0
  187. mlx/include/mlx/fence.h +39 -0
  188. mlx/include/mlx/fft.h +167 -0
  189. mlx/include/mlx/graph_utils.h +66 -0
  190. mlx/include/mlx/io/gguf.h +20 -0
  191. mlx/include/mlx/io/load.h +175 -0
  192. mlx/include/mlx/io.h +61 -0
  193. mlx/include/mlx/linalg.h +111 -0
  194. mlx/include/mlx/memory.h +78 -0
  195. mlx/include/mlx/mlx.h +25 -0
  196. mlx/include/mlx/ops.h +1627 -0
  197. mlx/include/mlx/primitives.h +2524 -0
  198. mlx/include/mlx/random.h +282 -0
  199. mlx/include/mlx/scheduler.h +188 -0
  200. mlx/include/mlx/small_vector.h +540 -0
  201. mlx/include/mlx/stream.h +41 -0
  202. mlx/include/mlx/threadpool.h +133 -0
  203. mlx/include/mlx/transforms.h +229 -0
  204. mlx/include/mlx/transforms_impl.h +86 -0
  205. mlx/include/mlx/types/bf16.h +187 -0
  206. mlx/include/mlx/types/complex.h +113 -0
  207. mlx/include/mlx/types/fp16.h +234 -0
  208. mlx/include/mlx/types/half_types.h +58 -0
  209. mlx/include/mlx/types/limits.h +70 -0
  210. mlx/include/mlx/utils.h +175 -0
  211. mlx/include/mlx/version.h +20 -0
  212. mlx/lib/libmlx.so +0 -0
  213. mlx/py.typed +1 -0
  214. mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
  215. mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
  216. mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
  217. mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  218. mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  219. mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
  220. mlx/share/cmake/MLX/extension.cmake +50 -0
  221. mlx/utils.py +325 -0
  222. mlx_cpu-0.30.1.dist-info/METADATA +142 -0
  223. mlx_cpu-0.30.1.dist-info/RECORD +231 -0
  224. mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
  225. mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
  226. mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
  227. mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
  228. mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
  229. mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
  230. mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
  231. mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h
@@ -0,0 +1,264 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/metal/kernels/steel/defines.h"
+
+///////////////////////////////////////////////////////////////////////////////
+// Loading helper
+///////////////////////////////////////////////////////////////////////////////
+
+namespace mlx {
+namespace steel {
+
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short dst_ld,
+    short reduction_dim,
+    short tgp_size,
+    short alignment = 1,
+    short n_reads = (BCOLS * BROWS) / (tgp_size),
+    short TCOLS = BCOLS / n_reads,
+    short TROWS = tgp_size / TCOLS>
+struct BlockLoader {
+  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
+  STEEL_CONST short vec_size = n_reads;
+
+  // Leading dimension for src
+  const int src_ld;
+  const int tile_stride;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  struct alignas(alignment * sizeof(T)) ReadVector {
+    uint8_t v[sizeof(T) * vec_size];
+  };
+
+  /* Constructor */
+  METAL_FUNC BlockLoader(
+      const device T* src_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / TCOLS),
+        bj(vec_size * (thread_idx % TCOLS)),
+        dst(dst_ + bi * dst_ld + bj),
+        src(src_ + bi * src_ld + bj) {}
+
+  /* Apply operation to threadgroup without bound checking */
+  template <typename UnaryOp>
+  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]);
+      }
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      *((threadgroup ReadVector*)(&dst[i * dst_ld])) =
+          *((const device ReadVector*)(&src[i * src_ld]));
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - with bound checking */
+  METAL_FUNC void load_safe(short2 src_tile_dim) const {
+    src_tile_dim = src_tile_dim - short2(bj, bi);
+
+    // Skip loading if thread has no valid reads
+    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BROWS; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * dst_ld + j] = T(0);
+        }
+      }
+      return;
+    }
+
+    // Use fast thread memory for bound checks
+    bool tmp_idx[vec_size];
+    T tmp_val[vec_size];
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      // Make sure tmp_idx only contains valid indices
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
+      }
+
+      // Read valid indices into tmp_val
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
+      }
+
+      // Zero out unneeded values
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
+      }
+
+      // Copy values to threadgroup memory
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * dst_ld + j] = tmp_val[j];
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    src += tile_stride;
+  }
+};
+
+template <int R, int C>
+struct CShape {
+  STEEL_CONST int kRows = R;
+  STEEL_CONST int kCols = C;
+};
+
+template <
+    typename T,
+    short BROWS,
+    short BCOLS,
+    short kDstStrRow,
+    short kDstStrCol,
+    short reduction_dim,
+    short tgp_size,
+    short n_reads = (BCOLS * BROWS) / (tgp_size),
+    short TCOLS = BCOLS / n_reads,
+    short TROWS = tgp_size / TCOLS>
+struct BlockLoaderT {
+  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
+  STEEL_CONST short vec_size = n_reads;
+
+  // Leading dimension for src
+  const int src_ld;
+  const int tile_stride;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  /* Constructor */
+  METAL_FUNC BlockLoaderT(
+      const device T* src_,
+      const int src_ld_,
+      threadgroup T* dst_,
+      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
+      ushort simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(src_ld_),
+        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / TCOLS),
+        bj(vec_size * (thread_idx % TCOLS)),
+        dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
+        src(src_ + bi * src_ld + bj) {}
+
+  /* Apply operation to threadgroup without bound checking */
+  template <typename UnaryOp>
+  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * kDstStrRow + j * kDstStrCol] =
+            op.apply(dst[i * kDstStrRow + j * kDstStrCol]);
+      }
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j];
+      }
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - with bound checking */
+  METAL_FUNC void load_safe(short2 src_tile_dim) const {
+    src_tile_dim = src_tile_dim - short2(bj, bi);
+
+    // Skip loading if thread has no valid reads
+    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BROWS; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * kDstStrRow + j * kDstStrCol] = T(0);
+        }
+      }
+      return;
+    }
+
+    // Use fast thread memory for bound checks
+    bool tmp_idx[vec_size];
+    T tmp_val[vec_size];
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < BROWS; i += TROWS) {
+      // Make sure tmp_idx only contains valid indices
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
+      }
+
+      // Read valid indices into tmp_val
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
+      }
+
+      // Zero out unneeded values
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
+      }
+
+      // Copy values to threadgroup memory
+      STEEL_PRAGMA_UNROLL
+      for (short j = 0; j < vec_size; j++) {
+        dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j];
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    src += tile_stride;
+  }
+};
+
+} // namespace steel
+} // namespace mlx
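
For orientation, below is a minimal usage sketch of the BlockLoader defined above. It is an assumption-laden illustration, not a kernel shipped in this wheel: the kernel name (tile_copy_example), the 32x16 tile shape, the 128-thread threadgroup, and the buffer bindings are all hypothetical, and the include path is assumed to mirror the header's own #include of steel/defines.h. Each threadgroup builds a loader over a row-major matrix A, uses load_unsafe() for full tiles and load_safe() for the ragged last tile, and steps along K with next().

// Hypothetical sketch only; names, sizes, and bindings are illustrative assumptions.
#include <metal_stdlib>

#include "mlx/backend/metal/kernels/steel/attn/loader.h"

using namespace metal;
using namespace mlx::steel;

[[kernel]] void tile_copy_example(
    const device float* A [[buffer(0)]],
    device float* out [[buffer(1)]],
    const constant int& lda [[buffer(2)]],
    const constant int& K [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    ushort simd_gid [[simdgroup_index_in_threadgroup]],
    ushort simd_lid [[thread_index_in_simdgroup]]) {
  // Shared staging buffer for one 32x16 tile.
  threadgroup float As[32 * 16];

  // 128 threads split the 32*16 elements into 4 contiguous reads per thread;
  // reduction_dim = 1 makes next() advance BCOLS (= 16) columns along K.
  using LoaderA = BlockLoader<
      float,
      /*BROWS=*/32,
      /*BCOLS=*/16,
      /*dst_ld=*/16,
      /*reduction_dim=*/1,
      /*tgp_size=*/128>;
  LoaderA loader_a(A + tid.y * 32 * lda, lda, As, simd_gid, simd_lid);

  for (int k = 0; k < K; k += 16) {
    if (K - k < 16) {
      // Ragged final tile: bounds-checked path zero-fills out-of-range elements.
      loader_a.load_safe(short2(K - k, 32));
    } else {
      // Full tile: vectorized, unchecked path.
      loader_a.load_unsafe();
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // ... consume As here (e.g. feed simdgroup MMA fragments) ...
    loader_a.next();
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write the last staged tile out (one tile per threadgroup) so the effect
  // of the load is observable in this sketch.
  for (int i = simd_gid * 32 + simd_lid; i < 32 * 16; i += 128) {
    out[tid.y * 32 * 16 + i] = As[i];
  }
}

BlockLoaderT follows the same pattern; the only difference visible in the header is that the destination index is computed from explicit kDstStrRow/kDstStrCol strides instead of a single dst_ld, so the same cooperative load can write a transposed or otherwise strided threadgroup layout.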