mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h
@@ -0,0 +1,319 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/metal/kernels/steel/utils.h"
+
+#include "mlx/backend/metal/kernels/steel/conv/params.h"
+
+///////////////////////////////////////////////////////////////////////////////
+// Loading helper
+///////////////////////////////////////////////////////////////////////////////
+
+namespace mlx {
+namespace steel {
+
+template <short n_channels_>
+struct ChannelHelper {
+  STEEL_CONST short n_channels = n_channels_;
+  STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8;
+  STEEL_CONST short excess = vec_size - n_channels_;
+};
+
+template <>
+struct ChannelHelper<1> {
+  STEEL_CONST short n_channels = 1;
+  STEEL_CONST short vec_size = 1;
+  STEEL_CONST short excess = 0;
+};
+
+template <>
+struct ChannelHelper<2> {
+  STEEL_CONST short n_channels = 2;
+  STEEL_CONST short vec_size = 2;
+  STEEL_CONST short excess = 0;
+};
+
+template <>
+struct ChannelHelper<3> {
+  STEEL_CONST short n_channels = 3;
+  STEEL_CONST short vec_size = 4;
+  STEEL_CONST short excess = 1;
+};
+
+template <>
+struct ChannelHelper<4> {
+  STEEL_CONST short n_channels = 4;
+  STEEL_CONST short vec_size = 4;
+  STEEL_CONST short excess = 0;
+};
+
+template <
+    typename T,
+    short BM,
+    short BN,
+    short BK,
+    short tgp_size,
+    short n_channels,
+    short tgp_padding = 0>
+struct Conv2DInputBlockLoaderSmallChannels {
+  // Destination dimensions
+  STEEL_CONST short BROWS = BM;
+  STEEL_CONST short BCOLS = BK;
+
+  // Read dimensions
+  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
+  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;
+
+  // Thread read shape
+  STEEL_CONST short TCOLS = BCOLS / vec_size;
+  STEEL_CONST short TROWS = tgp_size / TCOLS;
+
+  // Rows / strided reads within the block
+  STEEL_CONST short n_rows = BROWS / TROWS;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+
+  const constant MLXConvParams<2>* params;
+  const constant ImplicitGemmConv2DParams* gemm_params;
+
+  int weight_hw;
+
+  const device T* src[n_rows];
+
+  int read_n[n_rows];
+  int read_ih[n_rows];
+  int read_iw[n_rows];
+
+  /* Constructor */
+  METAL_FUNC Conv2DInputBlockLoaderSmallChannels(
+      const device T* src_,
+      threadgroup T* dst_,
+      const int2 offsets,
+      const constant MLXConvParams<2>* params_,
+      const constant ImplicitGemmConv2DParams* gemm_params_,
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / TCOLS),
+        bj(vec_size * (thread_idx % TCOLS)),
+        dst(dst_ + bi * dst_ld + bj),
+        params(params_),
+        gemm_params(gemm_params_),
+        weight_hw(thread_idx % TCOLS) {
+    int out_n_pixels = params->oS[0] * params->oS[1];
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < n_rows; ++i) {
+      int offset_nhw = offsets.y + bi + i * TROWS;
+      int n = offset_nhw / out_n_pixels;
+      int hw = offset_nhw % out_n_pixels;
+      int oh = hw / params->oS[1];
+      int ow = hw % params->oS[1];
+
+      int ih = oh * params->str[0] - params->pad[0];
+      int iw = ow * params->str[1] - params->pad[1];
+
+      // Read from input if in bounds
+      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
+          iw * params->in_strides[2];
+
+      read_n[i] = n;
+      read_ih[i] = ih;
+      read_iw[i] = iw;
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    if (weight_hw >= params->wS[1] * params->wS[0]) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BROWS; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * dst_ld + j] = T(0);
+        }
+      }
+      return;
+    }
+
+    int wh = (weight_hw / params->wS[1]);
+    int ww = (weight_hw % params->wS[1]);
+
+    int flip_h = params->flip ? params->wS[0] - wh - 1 : wh;
+    int flip_w = params->flip ? params->wS[1] - ww - 1 : ww;
+
+    int weight_h = flip_h * params->kdil[0];
+    int weight_w = flip_w * params->kdil[1];
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
+      // Find bounds
+      int n = read_n[i];
+      int ih = read_ih[i] + weight_h;
+      int iw = read_iw[i] + weight_w;
+
+      // Read from input if in bounds
+      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
+          (iw >= 0 && iw < params->iS[1])) {
+        const device T* curr_src = src[i] + weight_h * params->in_strides[1] +
+            weight_w * params->in_strides[2];
+
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < n_channels; ++j) {
+          dst[is * dst_ld + j] = curr_src[j];
+        }
+
+        STEEL_PRAGMA_UNROLL
+        for (short j = n_channels; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = T(0);
+        }
+      }
+
+      // Zero pad otherwise
+      else {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = T(0);
+        }
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    weight_hw += TCOLS;
+  }
+};
+
+template <
+    typename T,
+    short BM,
+    short BN,
+    short BK,
+    short tgp_size,
+    short n_channels,
+    short tgp_padding = 0>
+struct Conv2DWeightBlockLoaderSmallChannels {
+  // Destination dimensions
+  STEEL_CONST short BROWS = BN;
+  STEEL_CONST short BCOLS = BK;
+
+  // Read dimensions
+  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
+  STEEL_CONST short vec_size = ChannelHelper<n_channels>::vec_size;
+
+  // Thread read shape
+  STEEL_CONST short TCOLS = BCOLS / vec_size;
+  STEEL_CONST short TROWS = tgp_size / TCOLS;
+
+  // Rows / strided reads within the block
+  STEEL_CONST short n_rows = BROWS / TROWS;
+
+  // Leading dimension for src
+  const int src_ld;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  const constant MLXConvParams<2>* params;
+
+  int weight_hw;
+
+  const int read_n;
+  const bool do_read;
+
+  /* Constructor */
+  METAL_FUNC Conv2DWeightBlockLoaderSmallChannels(
+      const device T* src_,
+      threadgroup T* dst_,
+      const int2 offsets,
+      const constant MLXConvParams<2>* params_,
+      const constant ImplicitGemmConv2DParams* gemm_params_,
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(params_->wt_strides[0]),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / TCOLS),
+        bj(vec_size * (thread_idx % TCOLS)),
+        dst(dst_ + bi * dst_ld + bj),
+        src(src_ + bi * src_ld),
+        params(params_),
+        weight_hw(thread_idx % TCOLS),
+        read_n(offsets.y + bi),
+        do_read(read_n + BN <= gemm_params_->N) {}
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    if (bi >= BROWS || bj >= BCOLS)
+      return;
+
+    if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BROWS; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * dst_ld + j] = T(0);
+        }
+      }
+
+      return;
+    }
+
+    const device T* curr_src = src + weight_hw * (params->C / params->groups);
+
+    if (BN != 8 || do_read) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BROWS; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < n_channels; j++) {
+          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+        }
+
+        STEEL_PRAGMA_UNROLL
+        for (short j = n_channels; j < vec_size; j++) {
+          dst[i * dst_ld + j] = T(0);
+        }
+      }
+    } else {
+      for (short i = 0; i < BROWS; i += TROWS) {
+        if (((read_n + i) < params->O)) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < n_channels; j++) {
+            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+          }
+
+          STEEL_PRAGMA_UNROLL
+          for (short j = n_channels; j < vec_size; j++) {
+            dst[i * dst_ld + j] = T(0);
+          }
+        } else {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = T(0);
+          }
+        }
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    weight_hw += TCOLS;
+  }
+};
+
+} // namespace steel
+} // namespace mlx
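The two loaders above split a BM x BK (input) or BN x BK (weight) threadgroup tile across tgp_size threads: ChannelHelper picks the vectorized width vec_size for a small channel count plus the number of zero-padded excess lanes, and TCOLS, TROWS and n_rows follow from it. The host-side C++ sketch below reproduces that compile-time arithmetic so the constants can be checked off-device; the block sizes BM = 32, BK = 16 and tgp_size = 128 are assumed example values for this sketch, not constants read from a particular MLX kernel instantiation.

#include <cstdio>

// Sketch of the ChannelHelper primary template from the header above. The
// device-only specializations for 1, 2 and 4 channels are omitted; for the
// 3-channel case exercised below, the primary template already yields the
// same values as the ChannelHelper<3> specialization.
template <short n_channels_>
struct ChannelHelper {
  static constexpr short n_channels = n_channels_;
  static constexpr short vec_size = n_channels_ <= 4 ? 4 : 8;
  static constexpr short excess = vec_size - n_channels_;
};

int main() {
  // Illustrative block sizes for this sketch only (assumptions, not values
  // taken from a specific MLX kernel instantiation).
  constexpr short BM = 32;        // rows of the threadgroup tile
  constexpr short BK = 16;        // reduction-dimension columns of the tile
  constexpr short tgp_size = 128; // threads per threadgroup (4 simdgroups)

  constexpr short n_channels = 3; // e.g. an RGB input
  constexpr short vec_size = ChannelHelper<n_channels>::vec_size;
  constexpr short excess = ChannelHelper<n_channels>::excess;

  constexpr short TCOLS = BK / vec_size;    // threads along the K dimension
  constexpr short TROWS = tgp_size / TCOLS; // threads along the row dimension
  constexpr short n_rows = BM / TROWS;      // tile rows loaded per thread

  static_assert(TCOLS * TROWS == tgp_size, "the tile must use every thread");
  std::printf(
      "vec_size=%d excess=%d TCOLS=%d TROWS=%d n_rows=%d\n",
      vec_size,
      excess,
      TCOLS,
      TROWS,
      n_rows);
}

For the 3-channel case this prints vec_size=4 excess=1 TCOLS=4 TROWS=32 n_rows=1: each thread loads one row of the tile as a single 4-wide slice whose last lane is zero-padded, which is exactly what the n_channels / vec_size loop pair in load_unsafe writes.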
mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h
@@ -0,0 +1,381 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/metal/kernels/steel/defines.h"
+
+///////////////////////////////////////////////////////////////////////////////
+// Loading helper
+///////////////////////////////////////////////////////////////////////////////
+
+namespace mlx {
+namespace steel {
+
+template <
+    typename T,
+    short BM,
+    short BN,
+    short BK,
+    short tgp_size,
+    short tgp_padding = 0>
+struct Conv2DInputBlockLoaderGeneral {
+  // Destination dimensions
+  STEEL_CONST short BROWS = BM;
+  STEEL_CONST short BCOLS = BK;
+
+  // Read dimensions
+  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
+  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;
+
+  // Thread read shape
+  STEEL_CONST short TCOLS = BCOLS / vec_size;
+  STEEL_CONST short TROWS = tgp_size / TCOLS;
+
+  // Rows / strided reads within the block
+  STEEL_CONST short n_rows = BROWS / TROWS;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+
+  const constant MLXConvParams<2>* params;
+  const constant Conv2DGeneralJumpParams* jump_params;
+
+  const short base_wh;
+  const short base_ww;
+
+  short weight_h;
+  short weight_w;
+
+  const device T* src[n_rows];
+
+  int read_n[n_rows];
+  int read_ih[n_rows];
+  int read_iw[n_rows];
+
+  /* Constructor */
+  METAL_FUNC Conv2DInputBlockLoaderGeneral(
+      const device T* src_,
+      threadgroup T* dst_,
+      const int4 offsets,
+      const constant MLXConvParams<2>* params_,
+      const constant Conv2DGeneralJumpParams* jump_params_,
+      const short base_wh_,
+      const short base_ww_,
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / TCOLS),
+        bj(vec_size * (thread_idx % TCOLS)),
+        dst(dst_ + bi * dst_ld + bj),
+        params(params_),
+        jump_params(jump_params_),
+        base_wh(base_wh_),
+        base_ww(base_ww_),
+        weight_h(base_wh_),
+        weight_w(base_ww_) {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < n_rows; ++i) {
+      int offset_nhw = offsets.y + bi + i * TROWS;
+      int n = offset_nhw / jump_params->adj_out_hw;
+      int hw = offset_nhw % jump_params->adj_out_hw;
+      int oh =
+          (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z;
+      int ow =
+          (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w;
+
+      int ih = oh * params->str[0] - params->pad[0];
+      int iw = ow * params->str[1] - params->pad[1];
+
+      read_n[i] = n;
+      read_ih[i] = ih;
+      read_iw[i] = iw;
+
+      // Read from input if in bounds
+      src[i] = src_ + n * params->in_strides[0] + bj;
+    }
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
+      // Find bounds
+      int n = read_n[i];
+
+      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
+      int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
+
+      int ih_dil = read_ih[i] + h_flip * params->kdil[0];
+      int iw_dil = read_iw[i] + w_flip * params->kdil[1];
+
+      int ih = ih_dil / params->idil[0];
+      int iw = iw_dil / params->idil[1];
+
+      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
+
+      // Read from input if in bounds
+      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
+          (iw_dil >= 0 && iw < params->iS[1])) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = (src[i])[offset + j];
+        }
+      }
+
+      // Zero pad otherwise
+      else {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = T(0);
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void load_safe(const short remaining_k) const {
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
+      // Find bounds
+      int n = read_n[i];
+
+      int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h;
+      int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w;
+
+      int ih_dil = read_ih[i] + h_flip * params->kdil[0];
+      int iw_dil = read_iw[i] + w_flip * params->kdil[1];
+
+      int ih = ih_dil / params->idil[0];
+      int iw = iw_dil / params->idil[1];
+
+      size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2];
+
+      // Read from input if in bounds
+      if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) &&
+          (iw_dil >= 0 && iw < params->iS[1])) {
+        if (bj + vec_size <= remaining_k) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; ++j) {
+            dst[is * dst_ld + j] = (src[i])[offset + j];
+          }
+        } else {
+          for (short j = 0; j < vec_size; ++j) {
+            if (bj + j < remaining_k) {
+              dst[is * dst_ld + j] = (src[i])[offset + j];
+            } else {
+              dst[is * dst_ld + j] = T(0);
+            }
+          }
+        }
+      }
+
+      // Zero pad otherwise
+      else {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = T(0);
+        }
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    weight_w += jump_params->f_wgt_jump_w;
+    if (weight_w < params->wS[1]) {
+      return;
+    }
+
+    weight_w = base_ww;
+
+    weight_h += jump_params->f_wgt_jump_h;
+    if (weight_h < params->wS[0]) {
+      return;
+    }
+
+    weight_h = base_wh;
+
+    STEEL_PRAGMA_UNROLL
+    for (short i = 0; i < n_rows; i++) {
+      src[i] += BK;
+    }
+  }
+};
+
+template <
+    typename T,
+    short BM,
+    short BN,
+    short BK,
+    short tgp_size,
+    short tgp_padding = 0>
+struct Conv2DWeightBlockLoaderGeneral {
+  // Destination dimensions
+  STEEL_CONST short BROWS = BN;
+  STEEL_CONST short BCOLS = BK;
+
+  // Read dimensions
+  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
+  STEEL_CONST short vec_size =
+      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);
+
+  // Thread read shape
+  STEEL_CONST short TCOLS = BCOLS / vec_size;
+  STEEL_CONST short TROWS = tgp_size / TCOLS;
+
+  // Rows / strided reads within the block
+  STEEL_CONST short n_rows = BROWS / TROWS;
+
+  // Leading dimension for src
+  const int src_ld;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  const constant MLXConvParams<2>* params;
+  const constant Conv2DGeneralJumpParams* jump_params;
+
+  const short base_wh;
+  const short base_ww;
+
+  short weight_h;
+  short weight_w;
+
+  const int start_row;
+
+  /* Constructor */
+  METAL_FUNC Conv2DWeightBlockLoaderGeneral(
+      const device T* src_,
+      threadgroup T* dst_,
+      const int2 offsets,
+      const constant MLXConvParams<2>* params_,
+      const constant Conv2DGeneralJumpParams* jump_params_,
+      const short base_wh_,
+      const short base_ww_,
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : src_ld(params_->wt_strides[0]),
+        thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / TCOLS),
+        bj(vec_size * (thread_idx % TCOLS)),
+        dst(dst_ + bi * dst_ld + bj),
+        src(src_ + bi * src_ld + bj),
+        params(params_),
+        jump_params(jump_params_),
+        base_wh(base_wh_),
+        base_ww(base_ww_),
+        weight_h(base_wh_),
+        weight_w(base_ww_),
+        start_row(offsets.y + bi) {}
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+    const device T* curr_src = src + weight_h * params->wt_strides[1] +
+        weight_w * params->wt_strides[2];
+
+    if ((start_row + BN <= params->O)) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BN; i += TROWS) {
+        STEEL_PRAGMA_UNROLL
+        for (short j = 0; j < vec_size; j++) {
+          dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+        }
+      }
+    } else {
+      for (short i = 0; i < BN; i += TROWS) {
+        if ((start_row + i) < params->O) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+          }
+        } else {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = T(0);
+          }
+        }
+      }
+    }
+  }
+
+  METAL_FUNC void load_safe(const short remaining_k) const {
+    const device T* curr_src = src + weight_h * params->wt_strides[1] +
+        weight_w * params->wt_strides[2];
+
+    if ((start_row + BN <= params->O)) {
+      STEEL_PRAGMA_UNROLL
+      for (short i = 0; i < BN; i += TROWS) {
+        if (bj + vec_size <= remaining_k) {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+          }
+        } else {
+          for (short j = 0; j < vec_size; j++) {
+            if (bj + j < remaining_k) {
+              dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+            } else {
+              dst[i * dst_ld + j] = T(0);
+            }
+          }
+        }
+      }
+    } else {
+      for (short i = 0; i < BN; i += TROWS) {
+        if ((start_row + i) < params->O) {
+          if (bj + vec_size <= remaining_k) {
+            STEEL_PRAGMA_UNROLL
+            for (short j = 0; j < vec_size; j++) {
+              dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+            }
+          } else {
+            for (short j = 0; j < vec_size; j++) {
+              if (bj + j < remaining_k) {
+                dst[i * dst_ld + j] = curr_src[i * src_ld + j];
+              } else {
+                dst[i * dst_ld + j] = T(0);
+              }
+            }
+          }
+        } else {
+          STEEL_PRAGMA_UNROLL
+          for (short j = 0; j < vec_size; j++) {
+            dst[i * dst_ld + j] = T(0);
+          }
+        }
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    weight_w += jump_params->f_wgt_jump_w;
+    if (weight_w < params->wS[1]) {
+      return;
+    }
+
+    weight_w = base_ww;
+
+    weight_h += jump_params->f_wgt_jump_h;
+    if (weight_h < params->wS[0]) {
+      return;
+    }
+
+    weight_h = base_wh;
+
+    src += BK;
+  }
+};
+
+} // namespace steel
+} // namespace mlx
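The general loaders above walk the kernel taps via the jump parameters (f_wgt_jump_h/w, adj_out_hw, adj_out_w) and, for each tile row, map the output pixel back to an input location through the optionally flipped tap index, kernel dilation (kdil) and input dilation (idil), zero-filling whenever that location falls outside the batch or spatial bounds. The plain-C++ sketch below mirrors only that per-row bounds-and-copy decision from Conv2DInputBlockLoaderGeneral::load_unsafe; the Params2D struct, the load_row helper and the NHWC sizes in main are assumptions made for this sketch, not MLX API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical host-side mirror of the conv parameters the loader consults.
struct Params2D {
  int N;                // batch size
  int iS[2];            // input spatial size (H, W)
  int wS[2];            // weight spatial size (wH, wW)
  int kdil[2];          // kernel dilation
  int idil[2];          // input dilation
  size_t in_strides[3]; // element strides for the (n, h, w) axes of NHWC input
  bool flip;            // whether the kernel is flipped (true convolution)
};

// Copy vec_size contiguous channels for one tile row, or write zeros when the
// mapped input location is out of bounds -- the same decision load_unsafe makes.
void load_row(
    const float* src_row,
    float* dst_row,
    int vec_size,
    int n,
    int read_ih,
    int read_iw,
    int weight_h,
    int weight_w,
    const Params2D& p) {
  int h_flip = p.flip ? p.wS[0] - weight_h - 1 : weight_h;
  int w_flip = p.flip ? p.wS[1] - weight_w - 1 : weight_w;

  int ih_dil = read_ih + h_flip * p.kdil[0];
  int iw_dil = read_iw + w_flip * p.kdil[1];

  int ih = ih_dil / p.idil[0];
  int iw = iw_dil / p.idil[1];

  bool in_bounds = (n < p.N) && (ih_dil >= 0 && ih < p.iS[0]) &&
      (iw_dil >= 0 && iw < p.iS[1]);
  size_t offset = ih * p.in_strides[1] + iw * p.in_strides[2];

  for (int j = 0; j < vec_size; ++j) {
    dst_row[j] = in_bounds ? src_row[offset + j] : 0.0f;
  }
}

int main() {
  // 1 x 4 x 4 x 4 NHWC input filled with ones; strides are in elements.
  Params2D p{1, {4, 4}, {3, 3}, {1, 1}, {1, 1}, {64, 16, 4}, false};
  std::vector<float> src(64, 1.0f);
  float dst[4];

  load_row(src.data(), dst, 4, 0, 1, 1, 0, 0, p); // in bounds: copies ones
  std::printf("in-bounds row: %g\n", dst[0]);

  load_row(src.data(), dst, 4, 0, -1, 0, 0, 0, p); // above the input: zeros
  std::printf("padded row:    %g\n", dst[0]);
}

The jump-parameter bookkeeping that decides which kernel taps each next() call visits, and the remaining_k edge handling of load_safe, are deliberately left out of the sketch.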