mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
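Of the files listed above, only one is expanded below: by its +719 line count the hunk matches mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h, a Metal header defining two block_masked_gemm kernels, one templated over output/operand mask types (which may scale tiles as well as gate them) and one boolean-only variant. As a rough guide to the semantics, here is a scalar sketch in plain C++ written for this diff; the names, layout, and float masks are illustrative assumptions, not MLX API:

// Illustrative scalar reference for block-masked GEMM (not MLX API).
// Assumes square mask blocks of size `block` (the kernel asserts BM == BN)
// and row-major A (M x K), B (K x N), D (M x N).
#include <cstddef>
#include <vector>

void block_masked_gemm_ref(
    const std::vector<float>& A,
    const std::vector<float>& B,
    std::vector<float>& D,
    const std::vector<float>& out_mask, // ceil(M/block) x ceil(N/block)
    const std::vector<float>& lhs_mask, // ceil(M/block) x ceil(K/block)
    const std::vector<float>& rhs_mask, // ceil(K/block) x ceil(N/block)
    size_t M,
    size_t N,
    size_t K,
    size_t block) {
  const size_t kb = (K + block - 1) / block; // mask blocks along K
  const size_t nb = (N + block - 1) / block; // mask blocks along N
  for (size_t i = 0; i < M; i++) {
    for (size_t j = 0; j < N; j++) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; k++) {
        // Operand masks scale (or, when 0, skip) whole block x block tiles.
        const float l = lhs_mask[(i / block) * kb + (k / block)];
        const float r = rhs_mask[(k / block) * nb + (j / block)];
        acc += (l * A[i * K + k]) * (r * B[k * N + j]);
      }
      // The output mask scales (or zeroes) the finished output tile.
      D[i * N + j] = out_mask[(i / block) * nb + (j / block)] * acc;
    }
  }
}

With bool-valued masks this degenerates to gating whole tiles on or off, which is the fast path the second kernel in the hunk takes.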
@@ -0,0 +1,719 @@
// Copyright © 2024 Apple Inc.

#include "mlx/backend/metal/kernels/steel/defines.h"
using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

struct _NoMask {
  char x;

  constexpr METAL_FUNC operator bool() {
    return true;
  }
  constexpr METAL_FUNC operator bool() const threadgroup {
    return true;
  }
  constexpr METAL_FUNC operator bool() const device {
    return true;
  }
  constexpr METAL_FUNC operator bool() const constant {
    return true;
  }
};

template <typename OutT, typename InT = OutT>
struct ScaleOp {
  OutT scale;

  METAL_FUNC OutT apply(InT x) const {
    return static_cast<OutT>(x) * scale;
  }
};

typedef struct _NoMask nomask_t;

template <
    typename T,
    typename out_mask_t,
    typename op_mask_t,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant int64_t* batch_strides [[buffer(7)]],
    const device out_mask_t* out_mask [[buffer(10)]],
    const device op_mask_t* lhs_mask [[buffer(11)]],
    const device op_mask_t* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  static_assert(
      BM == BN,
      "block_masked_gemm must have the same block M and block N size");
  static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0");

  constexpr bool has_operand_mask = !metal::is_same_v<op_mask_t, nomask_t>;
  constexpr bool has_output_mask = !metal::is_same_v<out_mask_t, nomask_t>;

  constexpr bool has_mul_operand_mask =
      has_operand_mask && !metal::is_same_v<op_mask_t, bool>;
  constexpr bool has_mul_output_mask =
      has_output_mask && !metal::is_same_v<out_mask_t, bool>;

  constexpr short k_mask_factor = short(BM / BK);

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;
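  // Illustrative: with swizzle_log = 2, a threadgroup at tid = (13, 5) maps
  // to tid_y = (5 << 2) + (13 & 3) = 21 and tid_x = 13 >> 2 = 3, so
  // consecutive tid.x values sweep a column of 1 << swizzle_log output tiles
  // before moving to the next tile column.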

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  const constant auto* mask_batch_strides =
      batch_strides + 2 * params->batch_ndim;

  if (params->batch_ndim > 1) {
    if (has_output_mask) {
      out_mask += elem_to_loc(
          tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      const constant auto* mask_strides_lhs = mask_batch_strides;
      const constant auto* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    if (has_output_mask) {
      out_mask += tid.z * mask_batch_strides[0];
      mask_batch_strides += params->batch_ndim;
    }

    if (has_operand_mask) {
      lhs_mask += tid.z * mask_batch_strides[0];
      rhs_mask += tid.z * mask_batch_strides[params->batch_ndim];
    }
  }

  // Adjust for batch
  if (params->batch_ndim > 1) {
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  const constant int* out_mask_strides = mask_strides;
  const constant int* lhs_mask_strides =
      mask_strides + (has_output_mask ? 2 : 0);
  const constant int* rhs_mask_strides =
      lhs_mask_strides + (has_operand_mask ? 2 : 0);
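  // mask_strides holds two ints per mask that is actually present, packed in
  // the order out_mask, lhs_mask, rhs_mask; absent masks contribute no
  // entries, hence the conditional 2-element offsets above.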

  const int out_mask_offset = !has_output_mask
      ? 0
      : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0];
  int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1];
  int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0];
  const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0];
  const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1];
  short k_factor_cnt = k_mask_factor;
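  // A mask block spans BM rows/cols while each loop step advances K by BK,
  // so one operand-mask entry covers k_mask_factor = BM / BK iterations;
  // k_factor_cnt counts those down before the mask offsets advance.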

  ScaleOp<float> out_mask_op;
  ScaleOp<T> lhs_mask_op;
  ScaleOp<T> rhs_mask_op;

  if (has_output_mask) {
    auto mask_out = out_mask[out_mask_offset];

    if (has_mul_output_mask) {
      out_mask_op.scale = float(mask_out);
    }

    // Write zeros and return
    if (!mask_out) {
      constexpr short tgp_size = WM * WN * 32;
      constexpr short vec_size = 4;

      // Tile threads in threadgroup
      constexpr short TN = BN / vec_size;
      constexpr short TM = tgp_size / TN;

      const short thread_idx = simd_group_id * 32 + simd_lane_id;
      const short bi = thread_idx / TN;
      const short bj = vec_size * (thread_idx % TN);
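      // Illustrative: the WM * WN * 32 threads form a TM x TN grid; thread
      // (bi, bj / vec_size) zero-fills vec_size contiguous columns and steps
      // TM rows per iteration below.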

      D += bi * params->ldd + bj;

      short tgp_bm = min(BM, params->M - c_row);
      short tgp_bn = min(BN, params->N - c_col);

      if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
        for (short ti = 0; ti < BM; ti += TM) {
          STEEL_PRAGMA_UNROLL
          for (short j = 0; j < vec_size; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      } else {
        short jmax = tgp_bn - bj;
        jmax = jmax < vec_size ? jmax : vec_size;
        for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
          for (short j = 0; j < jmax; j++) {
            D[ti * params->ldd + j] = T(0.);
          }
        }
      }

      return;
    }
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm =
      MN_aligned ? short(BM) : short(min(BM, params->M - c_row));
  const short tgp_bn =
      MN_aligned ? short(BN) : short(min(BN, params->N - c_col));

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  ///////////////////////////////////////////////////////////////////////////////
  // Do unaligned K iterations first
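  // The K tail is processed before the main loop: jump the loaders to the
  // last aligned K offset, do a bounds-checked load of the remainder,
  // accumulate, then rewind so the steady-state loop can load unchecked.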
  if (!K_aligned) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int mask_idx_last = k_last / BM;

    if (!has_operand_mask ||
        (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) &&
         bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) {
      if (has_mul_operand_mask) {
        lhs_mask_op.scale =
            lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step];
        rhs_mask_op.scale =
            rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step];
      }

      // Move loader source ahead to end
      const int k_remain = params->K - k_last;
      const size_t k_jump_a =
          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
      const size_t k_jump_b =
          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

      loader_a.src += k_jump_a;
      loader_b.src += k_jump_b;

      // Load tile
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      if (has_mul_operand_mask) {
        loader_a.apply_inplace_op(lhs_mask_op);
        loader_b.apply_inplace_op(rhs_mask_op);
      }

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Do matmul
      mma_op.mma(As, Bs);

      // Reset source back to start
      loader_a.src -= k_jump_a;
      loader_b.src -= k_jump_b;
    }
  }

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else {
    const bool M_aligned = (tgp_bm == BM);
    const bool N_aligned = (tgp_bn == BN);

    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (; gemm_k_iterations > 0; gemm_k_iterations--) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (bool(lhs_mask[lhs_mask_offset]) &&
           bool(rhs_mask[rhs_mask_offset]))) {
        if (has_mul_operand_mask) {
          lhs_mask_op.scale = lhs_mask[lhs_mask_offset];
          rhs_mask_op.scale = rhs_mask[rhs_mask_offset];
        }

        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        if (has_mul_operand_mask) {
          loader_a.apply_inplace_op(lhs_mask_op);
          loader_b.apply_inplace_op(rhs_mask_op);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();

      k_factor_cnt--;
      lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0;
      rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0;
      k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt;
    }

    if (has_mul_output_mask) {
      mma_op.apply_epilogue(out_mask_op);
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}

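// The overload below is an all-boolean variant: the mask buffers are typed
// `const device bool*`, the output mask is always present, operand masking
// is toggled by the `has_operand_mask` template flag, and masked-off blocks
// are skipped or zeroed outright rather than scaled.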
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    bool MN_aligned,
    bool K_aligned,
    bool has_operand_mask = false>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
block_masked_gemm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    device T* D [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    const constant int* batch_shape [[buffer(6)]],
    const constant int64_t* batch_strides [[buffer(7)]],
    const device bool* out_mask [[buffer(10)]],
    const device bool* lhs_mask [[buffer(11)]],
    const device bool* rhs_mask [[buffer(12)]],
    const constant int* mask_strides [[buffer(13)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {
  // Appease the compiler
  (void)lid;

  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      MN_aligned,
      K_aligned>;

  const int tid_y = ((tid.y) << params->swizzle_log) +
      ((tid.x) & ((1 << params->swizzle_log) - 1));
  const int tid_x = (tid.x) >> params->swizzle_log;

  if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
    return;
  }

  if (params->batch_ndim > 1) {
    const constant auto* mask_batch_strides =
        batch_strides + 2 * params->batch_ndim;
    out_mask +=
        elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim);

    if (has_operand_mask) {
      const constant auto* mask_strides_lhs =
          mask_batch_strides + params->batch_ndim;
      const constant auto* mask_strides_rhs =
          mask_strides_lhs + params->batch_ndim;

      ulong2 batch_offsets = elem_to_loc_broadcast(
          tid.z,
          batch_shape,
          mask_strides_lhs,
          mask_strides_rhs,
          params->batch_ndim);

      lhs_mask += batch_offsets.x;
      rhs_mask += batch_offsets.y;
    }
  } else {
    out_mask += tid.z * batch_strides[2 * params->batch_ndim];
    if (has_operand_mask) {
      lhs_mask += tid.z * batch_strides[3 * params->batch_ndim];
      rhs_mask += tid.z * batch_strides[4 * params->batch_ndim];
    }
  }

  // Adjust for batch
  if (params->batch_ndim > 1) {
    const constant auto* A_bstrides = batch_strides;
    const constant auto* B_bstrides = batch_strides + params->batch_ndim;

    ulong2 batch_offsets = elem_to_loc_broadcast(
        tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim);

    A += batch_offsets.x;
    B += batch_offsets.y;

  } else {
    A += params->batch_stride_a * tid.z;
    B += params->batch_stride_b * tid.z;
  }

  D += params->batch_stride_d * tid.z;

  // Find block in A, B, C
  const int c_row = tid_y * BM;
  const int c_col = tid_x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  D += c_row_long * params->ldd + c_col_long;

  bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]];

  // Write zeros and return
  if (!mask_out) {
    constexpr short tgp_size = WM * WN * 32;
    constexpr short vec_size = 4;

    // Tile threads in threadgroup
    constexpr short TN = BN / vec_size;
    constexpr short TM = tgp_size / TN;

    const short thread_idx = simd_group_id * 32 + simd_lane_id;
    const short bi = thread_idx / TN;
    const short bj = vec_size * (thread_idx % TN);

    D += bi * params->ldd + bj;

    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);

    if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
      for (short ti = 0; ti < BM; ti += TM) {
        STEEL_PRAGMA_UNROLL
        for (short j = 0; j < vec_size; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    } else {
      short jmax = tgp_bn - bj;
      jmax = jmax < vec_size ? jmax : vec_size;
      for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) {
        for (short j = 0; j < jmax; j++) {
          D[ti * params->ldd + j] = T(0.);
        }
      }
    }

    return;
  }

  threadgroup_barrier(mem_flags::mem_none);

  // Prepare threadgroup mma operation
  thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id);

  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Prepare threadgroup loading operations
  thread typename gemm_kernel::loader_a_t loader_a(
      A, params->lda, As, simd_group_id, simd_lane_id);
  thread typename gemm_kernel::loader_b_t loader_b(
      B, params->ldb, Bs, simd_group_id, simd_lane_id);

  ///////////////////////////////////////////////////////////////////////////////
  // MNK aligned loop
  if (MN_aligned) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

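      // (k * BK) / BM selects the BM-sized mask block containing this BK
      // step; strides [2..5] index the lhs and rhs boolean masks directly.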
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    threadgroup_barrier(mem_flags::mem_none);

    // Loop tail
    if (!K_aligned) {
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask
               [(params->K / BM) * mask_strides[5] +
                tid_x * mask_strides[4]])) {
        int lbk = params->K - params->gemm_k_iterations_aligned * BK;
        short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
        short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);

        loader_a.load_safe(tile_dims_A);
        loader_b.load_safe(tile_dims_B);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    // Store results to device memory
    mma_op.store_result(D, params->ldd);
    return;

  }
  ///////////////////////////////////////////////////////////////////////////////
  // MN unaligned loop
  else { // Loop over K - unaligned case
    short tgp_bm = min(BM, params->M - c_row);
    short tgp_bn = min(BN, params->N - c_col);
    short lbk = params->K - params->gemm_k_iterations_aligned * BK;

    bool M_aligned = (tgp_bm == BM);
    bool N_aligned = (tgp_bn == BN);

    short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
    short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);

    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] &&
           rhs_mask
               [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) {
        // Load elements into threadgroup
        if (M_aligned) {
          loader_a.load_unsafe();
        } else {
          loader_a.load_safe(tile_dims_A);
        }

        if (N_aligned) {
          loader_b.load_unsafe();
        } else {
          loader_b.load_safe(tile_dims_B);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);
      }

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    if (!K_aligned) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      if (!has_operand_mask ||
          (lhs_mask
               [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] &&
           rhs_mask
               [(params->K / BM) * mask_strides[5] +
                tid_x * mask_strides[4]])) {
        short2 tile_dims_A_last =
            transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
        short2 tile_dims_B_last =
            transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);

        loader_a.load_safe(tile_dims_A_last);
        loader_b.load_safe(tile_dims_B_last);

        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(As, Bs);
      }
    }

    if (M_aligned && N_aligned) {
      mma_op.store_result(D, params->ldd);
    } else {
      mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}