mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
using namespace mlx::steel;

// Pipeline function constants, fixed by the host when the kernel is
// specialised:
//   segments_contiguous - layout of `segments` (see segmented_mm below).
//   align_M / align_N   - when set, every tile is treated as holding full
//                         BM rows / BN columns and the corresponding bounds
//                         checks are compiled out.
constant bool segments_contiguous [[function_constant(199)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];

// Segmented matrix multiply. For batch entry z = tid.z it computes
// (in the logical, untransposed view)
//   C[z] = A[:, k_start:k_end] @ B[k_start:k_end, :]
// where [k_start, k_end) is segment z of the shared K dimension, read from
// `segments`:
//   segments_contiguous: segment z is (segments[2z], segments[2z + 1]).
//   otherwise:           segment z is (segments[z], segments[z + 1]).
// Each threadgroup produces one BM x BN tile of C[z]. tid.x picks the tile
// column and tid.y the tile row. Consecutive outputs C[z] are
// params->batch_stride_d elements apart.
//
// T is the element type of A, B and C, and AccumType is the MMA accumulator
// type. BM/BN/BK are tile sizes, WM x WN is the simdgroup layout (32 threads
// each), and transpose_a/transpose_b select the operand layouts. All of these
// are forwarded to GEMMKernel.
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* segments [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  // Threadgroups past the last output tile have nothing to do. The test only
  // depends on tid (the threadgroup position), so the whole threadgroup exits
  // together and no later barrier is reached by only part of it.
  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Find the block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  // Number of valid rows / columns in this tile. These are below BM / BN only
  // for the last tile row / column when the matching align flag is not set.
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  // Move the pointers to the output tile
  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;

  // Read this batch entry's K range [k_start, k_end).
  uint32_t k_start, k_end;
  if (segments_contiguous) {
    k_start = segments[2 * tid.z];
    k_end = segments[2 * tid.z + 1];
  } else {
    // Non-contiguous layout: consecutive segments share a boundary, so the
    // end of segment z is the start of segment z + 1. In other words, the last
    // two strides of the segments array are both 1.
    k_start = segments[tid.z];
    k_end = segments[tid.z + 1];
  }

  // Move the pointers to the start of the segment. Promote to size_t before
  // multiplying by the leading dimension, as for the tile offsets above.
  // uint32_t * int is evaluated in 32-bit unsigned arithmetic and would wrap
  // once k_start * ld reaches 2^32.
  const size_t k_start_long = size_t(k_start);
  A += transpose_a ? k_start_long * params->lda : k_start_long;
  B += transpose_b ? k_start_long : k_start_long * params->ldb;
  C += tid.z * params->batch_stride_d;

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Each branch below runs full BK blocks while k <= k_end. After the loop,
  // k_end < k <= k_end + BK (given k_start <= k_end), so
  // k_remain = BK - (k - k_end) is the 0..BK-1 segment elements not covered by
  // a full block. Those are loaded with bounds checks.

  if (align_M && align_N) {
    // Matrix-level alignment: only K needs bounds checks.
    uint32_t k = k_start + BK;
    for (; k <= k_end; k += BK) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }
    short k_remain = BK - short(k - k_end);
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
    if (k_remain > 0) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);
      threadgroup_barrier(mem_flags::mem_threadgroup);
      mma_op.mma(As, Bs);
    }
    mma_op.store_result(C, params->ldd);
  } else {
    // This tile happens to be full in both M and N: same as the aligned path.
    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      uint32_t k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = BK - short(k - k_end);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result(C, params->ldd);
    }

    // Columns full, rows partial: bounds-check the rows of A and the output.
    else if (align_N || tgp_bn == BN) {
      uint32_t k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_safe(
            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = BK - short(k - k_end);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }

    // Rows full, columns partial: bounds-check the columns of B and the output.
    else if (align_M || tgp_bm == BM) {
      uint32_t k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_safe(
            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = BK - short(k - k_end);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }

    // Neither rows nor columns full: bounds-check both operands.
    else {
      uint32_t k = k_start + BK;
      for (; k <= k_end; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_safe(
            transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm));
        loader_b.load_safe(
            transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK));

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }
      short k_remain = BK - short(k - k_end);
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);
      if (k_remain > 0) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(As, Bs);
      }
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

// Split-K GEMM, partial-product pass.
//
// K is cut into params->split_k_partitions slices of
// params->split_k_partition_size elements. Threadgroup (x, y, z) multiplies
// slice z of the BM-row block y of A by slice z of the BN-column block x of
// B. It stores the BM x BN partial result into partition z of C; partitions
// are params->split_k_partition_stride elements apart. The partial results
// are presumably summed afterwards by gemm_splitk_accum /
// gemm_splitk_accum_axpby — confirm against the host dispatch.
//
// The last partition also processes the K range beyond its own slice.
//
// MN_aligned: every tile is treated as a full BM x BN tile (no M/N checks).
// K_aligned:  forwarded to GEMMKernel and to the tail loop's LoopAlignment.
template <
    typename T,
    typename U,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device U* C [[buffer(2)]],
    const constant GEMMSpiltKParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  (void)lid; // part of the signature, not used

  using gemm_kernel = GEMMKernel<
      T, U, BM, BN, BK, WM, WN,
      transpose_a, transpose_b, MN_aligned, K_aligned>;
  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  const int tile_col = tid.x;
  const int tile_row = tid.y;
  const int part = tid.z;

  // Threadgroups outside the output tile grid have no work.
  if (tile_col >= params->tiles_n || tile_row >= params->tiles_m) {
    return;
  }

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Origin of this output tile and of this partition's K slice.
  const int row0 = tile_row * BM;
  const int col0 = tile_col * BN;
  const int k0 = params->split_k_partition_size * part;

  const size_t row0_l = size_t(row0);
  const size_t col0_l = size_t(col0);
  const size_t k0_l = size_t(k0);

  // Advance A and B to (tile, slice) and C to (partition, tile).
  if (transpose_a) {
    A += row0_l + k0_l * params->lda;
  } else {
    A += k0_l + row0_l * params->lda;
  }
  if (transpose_b) {
    B += k0_l + col0_l * params->ldb;
  } else {
    B += col0_l + k0_l * params->ldb;
  }
  C += (size_t(params->split_k_partition_stride) * part) +
      (row0_l * params->ldc + col0_l);

  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  // Valid extent of this tile and the K remainder modulo BK.
  short tgp_bm = min(BM, params->M - row0);
  short tgp_bn = min(BN, params->N - col0);
  short leftover_bk = params->K % BK;

  const bool full_rows = (tgp_bm == BM);
  const bool full_cols = (tgp_bn == BN);

  // Main slice: choose the loop variant that skips every bounds check the
  // tile shape allows.
  if (MN_aligned || (full_rows && full_cols)) {
    gemm_kernel::gemm_loop(
        As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
        tgp_bm, tgp_bn, leftover_bk, LoopAlignment<true, true, true>{});
  } else if (full_cols) {
    gemm_kernel::gemm_loop(
        As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
        tgp_bm, tgp_bn, leftover_bk, LoopAlignment<false, true, true>{});
  } else if (full_rows) {
    gemm_kernel::gemm_loop(
        As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
        tgp_bm, tgp_bn, leftover_bk, LoopAlignment<true, false, true>{});
  } else {
    gemm_kernel::gemm_loop(
        As, Bs, gemm_k_iterations, loader_a, loader_b, mma_op,
        tgp_bm, tgp_bn, leftover_bk, LoopAlignment<false, false, true>{});
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // The last partition continues over the K range past its slice: whole BK
  // blocks, plus (presumably, inside gemm_loop) the leftover_bk tail when
  // K_aligned is false.
  const bool is_last_partition = (part + 1) == params->split_k_partitions;
  if (is_last_partition) {
    int tail_iters =
        (params->K - (k0 + params->split_k_partition_size)) / BK;
    if (!K_aligned || tail_iters > 0) {
      gemm_kernel::gemm_loop(
          As, Bs, tail_iters, loader_a, loader_b, mma_op,
          tgp_bm, tgp_bn, leftover_bk, LoopAlignment<false, false, K_aligned>{});
    }
  }

  // Partial tiles are written with bounds checks.
  if (!MN_aligned && !(tgp_bm == BM && tgp_bn == BN)) {
    mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
  } else {
    mma_op.store_result(C, params->ldc);
  }
}
|
|
163
|
+
|
|
164
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
165
|
+
// Split k accumulation kernel
|
|
166
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
167
|
+
|
|
168
|
+
template <
    typename AccT,
    typename OutT,
    typename Epilogue = TransformNone<OutT, AccT>>
// Split-K reduction. Each thread sums the k_partitions partial results for
// one element (column gid.x, row gid.y). Partitions are partition_stride
// elements apart in C_split. The thread writes Epilogue::apply(sum) to D.
// D and C_split use the same element offset, gid.x + gid.y * ldd.
[[kernel]] void gemm_splitk_accum(
    const device AccT* C_split [[buffer(0)]],
    device OutT* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    uint2 gid [[thread_position_in_grid]]) {
  // Adjust: element offset shared by D and the split buffer.
  const size_t idx = gid.x + gid.y * size_t(ldd);

  AccT acc = 0;
  for (int p = 0; p < k_partitions; p++) {
    acc += C_split[idx + size_t(p) * size_t(partition_stride)];
  }

  // Write output
  D[idx] = Epilogue::apply(acc);
}
|
|
194
|
+
|
|
195
|
+
template <
    typename AccT,
    typename OutT,
    typename Epilogue = TransformAxpby<OutT, AccT>>
// Split-K reduction with an epilogue that also reads C. Each thread sums the
// k_partitions partial results for element (column gid.x, row gid.y) of
// C_split. It then writes Epilogue(alpha, beta).apply(sum, C[gid]) to D.
// C is addressed with column stride fdc and row stride ldc. D and C_split
// share the offset gid.x + gid.y * ldd. The default TransformAxpby presumably
// computes alpha * sum + beta * C; confirm in transforms.h.
[[kernel]] void gemm_splitk_accum_axpby(
    const device AccT* C_split [[buffer(0)]],
    device OutT* D [[buffer(1)]],
    const constant int& k_partitions [[buffer(2)]],
    const constant int& partition_stride [[buffer(3)]],
    const constant int& ldd [[buffer(4)]],
    const device OutT* C [[buffer(5)]],
    const constant int& ldc [[buffer(6)]],
    const constant int& fdc [[buffer(7)]],
    const constant float& alpha [[buffer(8)]],
    const constant float& beta [[buffer(9)]],
    uint2 gid [[thread_position_in_grid]]) {
  // Adjust: offsets of this element in C and in D / the split buffer.
  const size_t c_idx = gid.x * size_t(fdc) + gid.y * size_t(ldc);
  const size_t idx = gid.x + gid.y * size_t(ldd);

  AccT acc = 0;
  for (int p = 0; p < k_partitions; p++) {
    acc += C_split[idx + size_t(p) * size_t(partition_stride)];
  }

  // Write output
  Epilogue op(alpha, beta);
  D[idx] = op.apply(acc, C[c_idx]);
}
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
6
|
+
|
|
7
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
8
|
+
// Loading helper
|
|
9
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
10
|
+
|
|
11
|
+
namespace mlx {
namespace steel {

// Cooperative tile loader. Each of the tgp_size threads of a threadgroup owns
// a strip of vec_size (= n_reads) contiguous elements per row-step. It copies
// that strip from a device-memory tile (rows src_ld apart) into a
// threadgroup-memory tile (rows dst_ld apart).
//
// Template parameters:
//   T             element type
//   BROWS, BCOLS  tile height / width, in elements
//   dst_ld        row stride of the threadgroup destination tile
//   reduction_dim chooses how next() advances src:
//                   nonzero -> by BCOLS elements
//                   zero    -> by BROWS * src_ld elements
//   tgp_size      number of cooperating threads
//   alignment     alignment, in elements, declared on ReadVector for the
//                 vectorized copy in load_unsafe()
//   n_reads       elements per thread per row-step (default: the tile split
//                 evenly across tgp_size threads)
//   TCOLS         threads spanning one tile row
//   TROWS         tile rows covered per row-step
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  // Row-steps needed to cover BROWS rows, i.e. ceil(BROWS / TROWS).
  // Not referenced inside this struct.
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  // Elements each thread moves per row-step.
  STEEL_CONST short vec_size = n_reads;

  // Source row stride, and the pointer advance applied by next().
  const int src_ld;
  const int tile_stride;

  // Flat thread index and this thread's origin (row bi, column bj) in the tile.
  const short thread_idx;
  const short bi;
  const short bj;

  // Per-thread cursors: destination in threadgroup memory, source in device
  // memory. Both already include this thread's (bi, bj) offset.
  threadgroup T* dst;
  const device T* src;

  // A strip of vec_size elements viewed as raw bytes, so the whole strip can
  // be moved with a single assignment.
  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor: positions this thread inside the tile. The flat index
   * assumes 32 lanes per SIMD group. */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld_),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld_ + bj) {}

  /* Apply op element-wise, in place, to this thread's strips of dst.
   * No bounds checks are performed. */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short row = 0; row < BROWS; row += TROWS) {
      threadgroup T* line = dst + row * dst_ld;
      STEEL_PRAGMA_UNROLL
      for (short col = 0; col < vec_size; col++) {
        line[col] = op.apply(line[col]);
      }
    }
  }

  /* Copy this thread's strips from src to dst with one ReadVector move per
   * row-step and no bounds checks. Assumes the whole strip is in range and
   * the strip addresses satisfy ReadVector's alignment. */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short row = 0; row < BROWS; row += TROWS) {
      const device ReadVector* from =
          (const device ReadVector*)(src + row * src_ld);
      threadgroup ReadVector* to = (threadgroup ReadVector*)(dst + row * dst_ld);
      *to = *from;
    }
  }

  /* Copy this thread's strips from src to dst element by element. Any element
   * outside src_tile_dim is written as zero; src_tile_dim.x is the count of
   * valid columns and src_tile_dim.y the count of valid rows, both measured
   * from the tile origin. */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    // Re-express the valid extent relative to this thread's origin (bj, bi).
    const short2 remaining = src_tile_dim - short2(bj, bi);

    // This thread's origin is already out of range: zero-fill without
    // reading src at all.
    if (!(remaining.x > 0 && remaining.y > 0)) {
      STEEL_PRAGMA_UNROLL
      for (short row = 0; row < BROWS; row += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short col = 0; col < vec_size; col++) {
          dst[row * dst_ld + col] = T(0);
        }
      }
      return;
    }

    // Per-thread scratch arrays for the bounds mask and the gathered values.
    bool in_bounds[vec_size];
    T staged[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short row = 0; row < BROWS; row += TROWS) {
      // Mark which elements of this strip exist in the source.
      STEEL_PRAGMA_UNROLL
      for (short col = 0; col < vec_size; col++) {
        in_bounds[col] = (row < remaining.y) && (col < remaining.x);
      }

      // Gather. Masked-off lanes read src[0] instead, which is in range here
      // (remaining is positive), so no load leaves the source.
      STEEL_PRAGMA_UNROLL
      for (short col = 0; col < vec_size; col++) {
        staged[col] = src[in_bounds[col] ? row * src_ld + col : 0];
      }

      // Store, replacing every masked-off value with zero.
      STEEL_PRAGMA_UNROLL
      for (short col = 0; col < vec_size; col++) {
        dst[row * dst_ld + col] = in_bounds[col] ? staged[col] : T(0);
      }
    }
  }

  /* Advance src by tile_stride elements, to the next tile. */
  METAL_FUNC void next() {
    src = src + tile_stride;
  }
};

} // namespace steel
} // namespace mlx
|