mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/metal/kernels/steel/defines.h"
|
|
6
|
+
|
|
7
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
8
|
+
// Loading helper
|
|
9
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
10
|
+
|
|
11
|
+
namespace mlx {
|
|
12
|
+
namespace steel {
|
|
13
|
+
|
|
14
|
+
// Cooperative tile loader: the threads of a threadgroup copy a
// BROWS x BCOLS tile from device memory into threadgroup memory.
// Each thread owns `vec_size` contiguous columns of every TROWS-th row.
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short alignment = 1,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoader {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  // Aligned POD wrapper so a whole per-thread row segment moves in one
  // vectorized copy in load_unsafe().
  struct alignas(alignment * sizeof(T)) ReadVector {
    uint8_t v[sizeof(T) * vec_size];
  };

  /* Constructor: derive this thread's (bi, bj) tile coordinates from its
     simdgroup id/lane and pre-offset both pointers to its first element. */
  METAL_FUNC BlockLoader(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj) {}

  /* Apply a unary op in place to this thread's slice of the tile
     (no bound checking). */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < BROWS; r += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        dst[r * dst_ld + c] = op.apply(dst[r * dst_ld + c]);
      }
    }
  }

  /* Load from device memory into threadgroup memory — without bound
     checking; each row segment is moved as one aligned ReadVector. */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < BROWS; r += TROWS) {
      *((threadgroup ReadVector*)(&dst[r * dst_ld])) =
          *((const device ReadVector*)(&src[r * src_ld]));
    }
  }

  /* Load from device memory into threadgroup memory — with bound checking
     against src_tile_dim = (valid cols, valid rows) of the source tile. */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    // Re-express the bounds relative to this thread's (bj, bi) offset.
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Nothing in range for this thread: zero-fill its slice and bail.
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short r = 0; r < BROWS; r += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; c++) {
          dst[r * dst_ld + c] = T(0);
        }
      }
      return;
    }

    // Stage validity flags and values in fast thread memory.
    bool valid[vec_size];
    T staged[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < BROWS; r += TROWS) {
      // Mark which columns of this row fall inside the source tile.
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        valid[c] = (r < src_tile_dim.y) && (c < src_tile_dim.x);
      }

      // Gather, redirecting out-of-range reads to element 0 (safe address).
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        staged[c] = src[(valid[c] ? r * src_ld + c : 0)];
      }

      // Replace the redirected reads with zeros.
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        staged[c] = valid[c] ? staged[c] : T(0);
      }

      // Commit the row segment to threadgroup memory.
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        dst[r * dst_ld + c] = staged[c];
      }
    }
  }

  /* Iteration helper: advance src to the next tile along the
     dimension selected by reduction_dim. */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};
|
|
135
|
+
|
|
136
|
+
// Compile-time 2-D shape tag carrying row/column extents as constants.
template <int R, int C>
struct CShape {
  STEEL_CONST int kRows = R;
  STEEL_CONST int kCols = C;
};
|
|
141
|
+
|
|
142
|
+
// Cooperative tile loader with an explicitly strided (possibly transposed)
// threadgroup destination: dst addressing uses kDstStrRow/kDstStrCol
// instead of a single leading dimension, so loads are element-wise
// (no vectorized ReadVector path).
template <
    typename T,
    short BROWS,
    short BCOLS,
    short kDstStrRow,
    short kDstStrCol,
    short reduction_dim,
    short tgp_size,
    short n_reads = (BCOLS * BROWS) / (tgp_size),
    short TCOLS = BCOLS / n_reads,
    short TROWS = tgp_size / TCOLS>
struct BlockLoaderT {
  STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
  STEEL_CONST short vec_size = n_reads;

  // Leading dimension for src
  const int src_ld;
  const int tile_stride;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;
  const device T* src;

  /* Constructor: derive this thread's (bi, bj) tile coordinates from its
     simdgroup id/lane and pre-offset both pointers to its first element. */
  METAL_FUNC BlockLoaderT(
      const device T* src_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * kDstStrRow + bj * kDstStrCol),
        src(src_ + bi * src_ld + bj) {}

  /* Apply a unary op in place to this thread's slice of the tile
     (no bound checking). */
  template <typename UnaryOp>
  METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < BROWS; r += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        dst[r * kDstStrRow + c * kDstStrCol] =
            op.apply(dst[r * kDstStrRow + c * kDstStrCol]);
      }
    }
  }

  /* Load from device memory into threadgroup memory — without bound
     checking; element-wise because dst may be strided/transposed. */
  METAL_FUNC void load_unsafe() const {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < BROWS; r += TROWS) {
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        dst[r * kDstStrRow + c * kDstStrCol] = src[r * src_ld + c];
      }
    }
  }

  /* Load from device memory into threadgroup memory — with bound checking
     against src_tile_dim = (valid cols, valid rows) of the source tile. */
  METAL_FUNC void load_safe(short2 src_tile_dim) const {
    // Re-express the bounds relative to this thread's (bj, bi) offset.
    src_tile_dim = src_tile_dim - short2(bj, bi);

    // Nothing in range for this thread: zero-fill its slice and bail.
    if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
      STEEL_PRAGMA_UNROLL
      for (short r = 0; r < BROWS; r += TROWS) {
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; c++) {
          dst[r * kDstStrRow + c * kDstStrCol] = T(0);
        }
      }
      return;
    }

    // Stage validity flags and values in fast thread memory.
    bool valid[vec_size];
    T staged[vec_size];

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < BROWS; r += TROWS) {
      // Mark which columns of this row fall inside the source tile.
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        valid[c] = (r < src_tile_dim.y) && (c < src_tile_dim.x);
      }

      // Gather, redirecting out-of-range reads to element 0 (safe address).
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        staged[c] = src[(valid[c] ? r * src_ld + c : 0)];
      }

      // Replace the redirected reads with zeros.
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        staged[c] = valid[c] ? staged[c] : T(0);
      }

      // Commit the row segment to threadgroup memory.
      STEEL_PRAGMA_UNROLL
      for (short c = 0; c < vec_size; c++) {
        dst[r * kDstStrRow + c * kDstStrCol] = staged[c];
      }
    }
  }

  /* Iteration helper: advance src to the next tile along the
     dimension selected by reduction_dim. */
  METAL_FUNC void next() {
    src += tile_stride;
  }
};
|
|
262
|
+
|
|
263
|
+
} // namespace steel
|
|
264
|
+
} // namespace mlx
|