mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
// Copyright © 2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#pragma once
|
|
4
|
+
|
|
5
|
+
#include "mlx/backend/metal/kernels/steel/utils.h"
|
|
6
|
+
|
|
7
|
+
#include "mlx/backend/metal/kernels/steel/conv/params.h"
|
|
8
|
+
|
|
9
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
10
|
+
// Loading helper
|
|
11
|
+
///////////////////////////////////////////////////////////////////////////////
|
|
12
|
+
|
|
13
|
+
namespace mlx {
|
|
14
|
+
namespace steel {
|
|
15
|
+
|
|
16
|
+
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderLargeFilter {
  // Implicit-GEMM input tile loader for 2D convolution.
  //
  // Each of the BM tile rows corresponds to one output pixel, whose flattened
  // index (n * oS[0] * oS[1] + oh * oS[1] + ow) is offsets.y + row. Each thread
  // copies vec_size consecutive input elements starting at column bj of its
  // rows (presumably input channels in a channel-last layout -- confirm).
  //
  // Unlike Conv2DInputBlockLoaderSmallFilter, no per-tap bitmasks are
  // precomputed, so the filter size is not limited by a mask width; spatial
  // bounds are recomputed on every load_unsafe() call instead.
  //
  // Template parameters: T element type, BM/BK tile rows/columns, BN unused by
  // this loader, tgp_size threads per threadgroup, tgp_padding extra elements
  // per threadgroup row.

  // Destination dimensions
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Read dimensions
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread read shape
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Rows / strided reads within the block
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Thread location indices
  const short thread_idx;
  const short bi;
  const short bj;

  // threadgroup and device memory
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  // Filter-tap iteration counters, advanced by next()
  short weight_h;
  short weight_w;

  // Per-row read pointer for the current tap
  const device T* src[n_rows];

  // Per-row batch index and unflipped receptive-field origin (may be negative)
  int read_n[n_rows];
  int read_ih[n_rows];
  int read_iw[n_rows];

  /* Constructor */
  METAL_FUNC Conv2DInputBlockLoaderLargeFilter(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_h(0),
        weight_w(0) {
    int out_n_pixels = params->oS[0] * params->oS[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; ++i) {
      // Decompose the flattened output-pixel index into (n, oh, ow)
      int offset_nhw = offsets.y + bi + i * TROWS;
      int n = offset_nhw / out_n_pixels;
      int hw = offset_nhw % out_n_pixels;
      int oh = hw / params->oS[1];
      int ow = hw % params->oS[1];

      // Input coordinate of filter tap (0, 0); negative inside the padding
      int ih = oh * params->str[0] - params->pad[0];
      int iw = ow * params->str[1] - params->pad[1];

      // Keep the unflipped origin for bounds checks in load_unsafe()
      read_n[i] = n;
      read_ih[i] = ih;
      read_iw[i] = iw;

      // When flipping, the read pointer starts at the last filter tap
      if (params->flip) {
        ih += (params->wS[0] - 1) * params->kdil[0];
        iw += (params->wS[1] - 1) * params->kdil[1];
      }

      // Starting read pointer (only dereferenced after a bounds check)
      src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] +
          iw * params->in_strides[2] + bj;
    }
  }

  /* Load the current filter tap into threadgroup memory.
   * Rows whose input pixel lies outside the image (or past the last batch
   * element) are zero-filled. The column/channel range is NOT checked. */
  METAL_FUNC void load_unsafe() const {
    // With params->flip the constructor points src at the last tap, so
    // iteration (weight_h, weight_w) presumably reads the mirrored tap
    // (inp_jump_* assumed negated on the host -- confirm). Check that same
    // mirrored position, matching flip_h / flip_w in
    // Conv2DInputBlockLoaderSmallFilter; checking the unmirrored tap would
    // zero in-bounds pixels and read out-of-bounds ones when flipping.
    const int tap_h = params->flip ? params->wS[0] - 1 - weight_h : weight_h;
    const int tap_w = params->flip ? params->wS[1] - 1 - weight_w : weight_w;
    const int dh = tap_h * params->kdil[0];
    const int dw = tap_w * params->kdil[1];

    STEEL_PRAGMA_UNROLL
    for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) {
      // Find bounds
      int n = read_n[i];
      int ih = read_ih[i] + dh;
      int iw = read_iw[i] + dw;

      // Read from input if in bounds
      if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) &&
          (iw >= 0 && iw < params->iS[1])) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = src[i][j];
        }
      }

      // Zero pad otherwise
      else {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; ++j) {
          dst[is * dst_ld + j] = T(0);
        }
      }
    }
  }

  /* Advance to the next filter tap: w fastest, then h, then the next
   * channel block, applying the host-provided pointer jumps. */
  METAL_FUNC void next() {
    if (++weight_w < params->wS[1]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_w;
      }

      return;
    }

    weight_w = 0;

    if (++weight_h < params->wS[0]) {
      STEEL_PRAGMA_UNROLL
      for (short i = 0; i < n_rows; i++) {
        src[i] += gemm_params->inp_jump_h;
      }

      return;
    }

    weight_h = 0;

    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < n_rows; i++) {
      src[i] += gemm_params->inp_jump_c;
    }
  }
};
|
|
163
|
+
|
|
164
|
+
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DInputBlockLoaderSmallFilter {
  // Implicit-GEMM input tile loader for 2D convolution that precomputes,
  // per tile row, one validity bit per filter tap along each spatial axis.
  // Because mask_t is a 16-bit short, each axis of the filter presumably
  // must have at most 16 taps (the host is assumed to enforce this -- confirm).
  //
  // Tile row r of this thread maps to the output pixel with flattened index
  // offsets.y + bi + r * TROWS; each thread copies vec_size consecutive
  // elements starting at column bj. BN is not used by this loader.

  // Threadgroup tile: BROWS x BCOLS elements.
  STEEL_CONST short BROWS = BM;
  STEEL_CONST short BCOLS = BK;

  // Threadgroup row pitch and per-thread contiguous read width.
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4;

  // Thread grid laid over the tile: TROWS x TCOLS.
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Tile rows handled by each thread, spaced TROWS apart.
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Bit k set <=> filter tap k along that axis reads an in-bounds pixel.
  using mask_t = short;

  // Flat thread id and the (row, column) of its first element in the tile.
  const short thread_idx;
  const short bi;
  const short bj;

  // This thread's first destination element in threadgroup memory.
  threadgroup T* dst;

  const constant MLXConvParams<2>* params;
  const constant ImplicitGemmConv2DParams* gemm_params;

  // Current filter tap; advanced by next().
  short weight_h;
  short weight_w;

  // Per-row device read pointer for the current tap.
  const device T* src[n_rows];

  // Per-row tap validity bitmasks (vertical and horizontal).
  mask_t mask_h[n_rows];
  mask_t mask_w[n_rows];

  /* Constructor: resolves each row's starting pointer and tap masks. */
  METAL_FUNC Conv2DInputBlockLoaderSmallFilter(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        params(params_),
        gemm_params(gemm_params_),
        weight_h(0),
        weight_w(0) {
    const int pixels_per_image = params->oS[0] * params->oS[1];

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < n_rows; ++r) {
      // Flattened output index -> (batch, out_y, out_x).
      const int flat = offsets.y + bi + r * TROWS;
      const int batch = flat / pixels_per_image;
      const int pix = flat % pixels_per_image;
      const int out_y = pix / params->oS[1];
      const int out_x = pix % params->oS[1];

      // Input coordinate of tap (0, 0); negative inside the padding.
      const int base_y = out_y * params->str[0] - params->pad[0];
      const int base_x = out_x * params->str[1] - params->pad[1];

      // A flipped filter starts reading at its last tap.
      const int start_y = params->flip
          ? base_y + (params->wS[0] - 1) * params->kdil[0]
          : base_y;
      const int start_x = params->flip
          ? base_x + (params->wS[1] - 1) * params->kdil[1]
          : base_x;

      // Starting read pointer (dereferenced only when the masks allow it).
      src[r] = src_ + batch * params->in_strides[0] +
          start_y * params->in_strides[1] + start_x * params->in_strides[2] +
          bj;

      // Bit kh: tap kh (mirrored when flipping) hits a valid row of a valid
      // batch element.
      mask_t rows_ok = 0;
      for (short kh = 0; kh < params->wS[0]; kh++) {
        const short tap = params->flip ? params->wS[0] - kh - 1 : kh;
        const int y = base_y + tap * params->kdil[0];
        const bool ok = batch < params->N && y >= 0 && y < params->iS[0];
        rows_ok |= (ok << kh);
      }
      mask_h[r] = rows_ok;

      // Bit kw: tap kw (mirrored when flipping) hits a valid column.
      mask_t cols_ok = 0;
      for (short kw = 0; kw < params->wS[1]; kw++) {
        const short tap = params->flip ? params->wS[1] - kw - 1 : kw;
        const int x = base_x + tap * params->kdil[1];
        const bool ok = x >= 0 && x < params->iS[1];
        cols_ok |= (ok << kw);
      }
      mask_w[r] = cols_ok;
    }
  }

  /* Copy the current tap into threadgroup memory, zero-filling rows whose
   * tap falls outside the input. The column/channel range is not checked. */
  METAL_FUNC void load_unsafe() const {
    const mask_t want_h = mask_t(1) << weight_h;
    const mask_t want_w = mask_t(1) << weight_w;

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < n_rows; ++r) {
      threadgroup T* row_dst = dst + r * TROWS * dst_ld;
      const bool tap_valid = (mask_h[r] & want_h) && (mask_w[r] & want_w);

      if (tap_valid) {
        const device T* row_src = src[r];
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; ++c) {
          row_dst[c] = row_src[c];
        }
      } else {
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; ++c) {
          row_dst[c] = T(0);
        }
      }
    }
  }

  /* Step to the next tap (w fastest, then h, then next channel block) and
   * move every row pointer by the matching host-computed jump. */
  METAL_FUNC void next() {
    int jump;
    if (++weight_w < params->wS[1]) {
      jump = gemm_params->inp_jump_w;
    } else {
      weight_w = 0;
      if (++weight_h < params->wS[0]) {
        jump = gemm_params->inp_jump_h;
      } else {
        weight_h = 0;
        jump = gemm_params->inp_jump_c;
      }
    }

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < n_rows; r++) {
      src[r] += jump;
    }
  }
};
|
|
344
|
+
|
|
345
|
+
template <
    typename T,
    short BM,
    short BN,
    short BK,
    short tgp_size,
    short tgp_padding = 0>
struct Conv2DWeightBlockLoader {
  // Implicit-GEMM weight ("B") tile loader for 2D convolution.
  //
  // Tile rows index output channels (row offsets.y + bi + i is compared with
  // params->O); rows are wt_strides[0] elements apart in device memory.
  // For each channel block, next() visits all wS[0] * wS[1] filter taps,
  // stepping C / groups elements per tap, then rewinds and moves BK further.
  // BM is not used by this loader.

  // Threadgroup tile: BROWS x BCOLS elements.
  STEEL_CONST short BROWS = BN;
  STEEL_CONST short BCOLS = BK;

  // Threadgroup row pitch; a BN == 8 tile is read one element per thread.
  STEEL_CONST short dst_ld = BCOLS + tgp_padding;
  STEEL_CONST short vec_size =
      (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4);

  // Thread grid laid over the tile: TROWS x TCOLS.
  STEEL_CONST short TCOLS = BCOLS / vec_size;
  STEEL_CONST short TROWS = tgp_size / TCOLS;

  // Tile rows handled by each thread, spaced TROWS apart.
  STEEL_CONST short n_rows = BROWS / TROWS;

  // Device-memory distance between consecutive tile rows.
  // NOTE: initializer order below depends on this declaration order.
  const int src_ld;

  // Flat thread id and the (row, column) of its first element in the tile.
  const short thread_idx;
  const short bi;
  const short bj;

  // This thread's first destination / source elements.
  threadgroup T* dst;
  const device T* src;

  const constant MLXConvParams<2>* params;

  // Current tap index within the filter window and per-tap pointer stride.
  int weight_hw;
  int weight_step;

  // First output-channel row of this thread and whether all of its rows fit.
  const int read_n;
  const bool do_read;

  /* Constructor */
  METAL_FUNC Conv2DWeightBlockLoader(
      const device T* src_,
      threadgroup T* dst_,
      const int2 offsets,
      const constant MLXConvParams<2>* params_,
      const constant ImplicitGemmConv2DParams* gemm_params_,
      uint simd_group_id [[simdgroup_index_in_threadgroup]],
      uint simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(params_->wt_strides[0]),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(thread_idx / TCOLS),
        bj(vec_size * (thread_idx % TCOLS)),
        dst(dst_ + bi * dst_ld + bj),
        src(src_ + bi * src_ld + bj),
        params(params_),
        weight_hw(0),
        weight_step(params->C / params->groups),
        read_n(offsets.y + bi),
        do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}

  /* Copy the current tap into threadgroup memory. For BN != 8, or when every
   * row is known to be in range, rows are copied unconditionally; otherwise
   * rows at or beyond params->O are zero-filled. Columns are not checked. */
  METAL_FUNC void load_unsafe() const {
    const bool rows_known_valid = (BN != 8) || do_read;

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < BN; r += TROWS) {
      threadgroup T* row_dst = dst + r * dst_ld;

      if (rows_known_valid || (read_n + r) < params->O) {
        const device T* row_src = src + r * src_ld;
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; c++) {
          row_dst[c] = row_src[c];
        }
      } else {
        STEEL_PRAGMA_UNROLL
        for (short c = 0; c < vec_size; c++) {
          row_dst[c] = T(0);
        }
      }
    }
  }

  /* Step to the next filter tap; after the last tap, rewind to the first tap
   * and advance BK elements to the next channel block. */
  METAL_FUNC void next() {
    const int taps = params->wS[1] * params->wS[0];

    ++weight_hw;
    if (weight_hw < taps) {
      src += weight_step;
    } else {
      weight_hw = 0;
      src += BK - (taps - 1) * weight_step;
    }
  }
};
|
|
449
|
+
|
|
450
|
+
} // namespace steel
|
|
451
|
+
} // namespace mlx
|