mlx-cpu 0.30.1: mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
@@ -0,0 +1,2502 @@
// Copyright © 2023-2024 Apple Inc.

#include <metal_simdgroup>
#include <metal_stdlib>

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

MLX_MTL_CONST int SIMD_SIZE = 32;
MLX_MTL_CONST int QUAD_SIZE = 4;

template <int bits, int wsize = 8>
inline constexpr short get_pack_factor() {
  return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits);
}

template <int bits, int wsize = 8>
inline constexpr short get_bytes_per_pack() {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3);
}
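
// load_vector: read `values_per_thread` activations into registers, pre-dividing
// each element by the power of two its packed weight bits are shifted by, so the
// dot product below can use the masked (unshifted) weight bits directly. Returns
// the running sum of the activations, later multiplied by the quantization bias.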
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < values_per_thread; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 5) {
    for (int i = 0; i < values_per_thread; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 32.0f;
      x_thread[i + 2] = x[i + 2] / 4.0f;
      x_thread[i + 3] = x[i + 3] / 128.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 2.0f;
      x_thread[i + 6] = x[i + 6] / 64.0f;
      x_thread[i + 7] = x[i + 7] / 8.0f;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < values_per_thread; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  return sum;
}
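
// load_vector_safe: same as load_vector but only the first N elements are read;
// the remaining registers are zero-filled so the caller can run the full-width
// dot product on a partial block.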
template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U sum = 0;

  if (bits == 2) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 4.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 64.0f;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < N; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];

      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 8.0f;
      x_thread[i + 2] = x[i + 2] / 64.0f;
      x_thread[i + 3] = x[i + 3] / 2.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 128.0f;
      x_thread[i + 6] = x[i + 6] / 4.0f;
      x_thread[i + 7] = x[i + 7] / 32.0f;
    }
  }

  else if (bits == 4) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 16.0f;
      x_thread[i + 2] = x[i + 2] / 256.0f;
      x_thread[i + 3] = x[i + 3] / 4096.0f;
    }
  }

  else if (bits == 5) {
    for (int i = 0; i < N; i += 8) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] +
          x[i + 6] + x[i + 7];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 32.0f;
      x_thread[i + 2] = x[i + 2] / 4.0f;
      x_thread[i + 3] = x[i + 3] / 128.0f;
      x_thread[i + 4] = x[i + 4] / 16.0f;
      x_thread[i + 5] = x[i + 5] / 2.0f;
      x_thread[i + 6] = x[i + 6] / 64.0f;
      x_thread[i + 7] = x[i + 7] / 8.0f;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < N; i += 4) {
      sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
      x_thread[i] = x[i];
      x_thread[i + 1] = x[i + 1] / 64.0f;
      x_thread[i + 2] = x[i + 2] / 16.0f;
      x_thread[i + 3] = x[i + 3] / 4.0f;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      sum += x[i];
      x_thread[i] = x[i];
    }
  }

  for (int i = N; i < values_per_thread; i++) {
    x_thread[i] = 0;
  }

  return sum;
}
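
// qdot: dot product between the pre-scaled activations in registers and a run
// of packed quantized weights. The weight bits are masked but not shifted (the
// shift was folded into the activations by load_vector), and the affine
// correction is applied once at the end as scale * accum + sum * bias.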
template <typename U, int values_per_thread, int bits>
inline U qdot(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (values_per_thread / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 5) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      x_thread += 8 * i;
      w += 5 * i;

      accum += (w[0] & 0x1f) * x_thread[0];
      accum += (w[0] & 0xe0) * x_thread[1];
      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
      accum += (w[1] & 0x7c) * x_thread[2];
      accum += (w[1] & 0x80) * x_thread[3];
      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
      accum += (w[2] & 0xf0) * x_thread[4];
      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
      accum += (w[3] & 0x3e) * x_thread[5];
      accum += (w[3] & 0xc0) * x_thread[6];
      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
      accum += (w[4] & 0xf8) * x_thread[7];
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];

      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);

      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);

      accum += (w[2] & 0xfc) * x_thread[3];
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
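
// qdot_safe: bounds-checked variant of qdot that only consumes the first N
// valid values of the block.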
template <typename U, int values_per_thread, int bits>
inline U qdot_safe(
    const device uint8_t* w,
    const thread U* x_thread,
    U scale,
    U bias,
    U sum,
    int N) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  U accum = 0;

  if (bits == 2) {
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (w[i] & 0x03) +
           x_thread[4 * i + 1] * (w[i] & 0x0c) +
           x_thread[4 * i + 2] * (w[i] & 0x30) +
           x_thread[4 * i + 3] * (w[i] & 0xc0));
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      x_thread += 8 * i;
      w += 3 * i;

      accum += (w[0] & 0x07) * x_thread[0];
      accum += (w[0] & 0x38) * x_thread[1];
      accum += (w[0] & 0xc0) * x_thread[2];
      accum += (w[1] & 0x01) * (x_thread[2] * 256.0f);

      accum += (w[1] & 0x0e) * x_thread[3];
      accum += (w[1] & 0x70) * x_thread[4];
      accum += (w[1] & 0x80) * x_thread[5];
      accum += (w[2] & 0x03) * (x_thread[5] * 256.0f);

      accum += (w[2] & 0x1c) * x_thread[6];
      accum += (w[2] & 0xe0) * x_thread[7];
    }
  }

  else if (bits == 4) {
    const device uint16_t* ws = (const device uint16_t*)w;
    for (int i = 0; i < (N / 4); i++) {
      accum +=
          (x_thread[4 * i] * (ws[i] & 0x000f) +
           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
           x_thread[4 * i + 3] * (ws[i] & 0xf000));
    }
  }

  else if (bits == 5) {
    for (int i = 0; i < (N / 8); i++) {
      x_thread += 8 * i;
      w += 5 * i;

      accum += (w[0] & 0x1f) * x_thread[0];
      accum += (w[0] & 0xe0) * x_thread[1];
      accum += (w[1] & 0x3) * (x_thread[1] * 256.0f);
      accum += (w[1] & 0x7c) * x_thread[2];
      accum += (w[1] & 0x80) * x_thread[3];
      accum += (w[2] & 0xf) * (x_thread[3] * 256.0f);
      accum += (w[2] & 0xf0) * x_thread[4];
      accum += (w[3] & 0x1) * (x_thread[4] * 256.0f);
      accum += (w[3] & 0x3e) * x_thread[5];
      accum += (w[3] & 0xc0) * x_thread[6];
      accum += (w[4] & 0x7) * (x_thread[6] * 256.0f);
      accum += (w[4] & 0xf8) * x_thread[7];
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      x_thread += 4 * i;
      w += 3 * i;

      accum += (w[0] & 0x3f) * x_thread[0];

      accum += (w[0] & 0xc0) * x_thread[1];
      accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f);

      accum += (w[1] & 0xf0) * x_thread[2];
      accum += (w[2] & 0x03) * (x_thread[2] * 256.0f);

      accum += (w[2] & 0xfc) * x_thread[3];
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      accum += x_thread[i] * w[i];
    }
  }

  return scale * accum + sum * bias;
}
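
// qouter: accumulate the outer-product contribution x * dequantize(w) into
// `result`; used by the vector-matrix (qvm) kernel where each thread owns a
// strip of output columns.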
template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  if (bits == 2) {
    U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
    for (int i = 0; i < (values_per_thread / 4); i++) {
      result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias);
      result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias);
      result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias);
      result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias);
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[8 * i] += x * ((w0 & 0x7) * scale + bias);
      result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias);
      result[8 * i + 2] +=
          x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias);
      result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias);
      result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias);
      result[8 * i + 5] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias);
      result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias);
      result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias);
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / 16.0f};
    for (int i = 0; i < (values_per_thread / 2); i++) {
      result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias);
      result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias);
    }
  }

  else if (bits == 5) {
    for (int i = 0; i < (values_per_thread / 8); i++) {
      uint8_t w0 = w[5 * i];
      uint8_t w1 = w[5 * i + 1];
      uint8_t w2 = w[5 * i + 2];
      uint8_t w3 = w[5 * i + 3];
      uint8_t w4 = w[5 * i + 4];
      result[8 * i] += x * ((w0 & 0x1f) * scale + bias);
      result[8 * i + 1] +=
          x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias);
      result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias);
      result[8 * i + 3] +=
          x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias);
      result[8 * i + 4] +=
          x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias);
      result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias);
      result[8 * i + 6] +=
          x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias);
      result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias);
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (values_per_thread / 4); i++) {
      uint8_t w0 = w[3 * i];
      uint8_t w1 = w[3 * i + 1];
      uint8_t w2 = w[3 * i + 2];

      result[4 * i] += x * ((w0 & 0x3f) * scale + bias);
      result[4 * i + 1] +=
          x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias);
      result[4 * i + 2] +=
          x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias);
      result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias);
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < values_per_thread; i++) {
      result[i] += x * (scale * w[i] + bias);
    }
  }
}
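
// dequantize: unpack N packed weights into threadgroup memory as
// scale * q + bias. For the 2- and 4-bit packings the shift is folded into
// per-position scales; for 3-, 5- and 6-bit the bits are explicitly shifted
// and recombined across byte boundaries.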
template <typename U, int N, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  if (bits == 2) {
    U s[4] = {
        scale,
        scale / static_cast<U>(4.0f),
        scale / static_cast<U>(16.0f),
        scale / static_cast<U>(64.0f)};
    for (int i = 0; i < (N / 4); i++) {
      w_local[4 * i] = s[0] * (w[i] & 0x03) + bias;
      w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias;
      w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias;
      w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias;
    }
  }

  else if (bits == 3) {
    for (int i = 0; i < (N / 8); i++) {
      w_local += 8 * i;
      w += 3 * i;

      w_local[0] = (w[0] & 0x7) * scale + bias;
      w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias;
      w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
      w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias;
      w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias;
      w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
      w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
      w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
    }
  }

  else if (bits == 4) {
    U s[2] = {scale, scale / static_cast<U>(16.0f)};
    for (int i = 0; i < (N / 2); i++) {
      w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias;
      w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias;
    }
  }

  else if (bits == 5) {
    for (int i = 0; i < (N / 8); i++) {
      w_local += 8 * i;
      w += 5 * i;

      w_local[0] = (w[0] & 0x1f) * scale + bias;
      w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
      w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
      w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
      w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
      w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
      w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
      w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
    }
  }

  else if (bits == 6) {
    for (int i = 0; i < (N / 4); i++) {
      w_local += 4 * i;
      w += 3 * i;
      w_local[0] = (w[0] & 0x3f) * scale + bias;
      w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
      w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
      w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
    }
  }

  else if (bits == 8) {
    for (int i = 0; i < N; i++) {
      w_local[i] = scale * w[i] + bias;
    }
  }
}
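
// QuantizedBlockLoader: steel-style tile loader that dequantizes a
// BROWS x BCOLS block of packed weights into threadgroup memory, advancing the
// scale/bias pointers once a full quantization group (group_size columns) has
// been consumed along the reduction dimension.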
template <
    typename T,
    short BROWS,
    short BCOLS,
    short dst_ld,
    short reduction_dim,
    short tgp_size,
    short group_size,
    short bits>
struct QuantizedBlockLoader {
  static_assert(
      BCOLS <= group_size,
      "The group size should be larger than the columns");
  static_assert(
      group_size % BCOLS == 0,
      "The group size should be divisible by the columns");
  static_assert(
      bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 ||
          bits == 8,
      "Template undefined for bits not in {2, 3, 4, 5, 6, 8}");

  MLX_MTL_CONST short pack_factor = get_pack_factor<bits, 8>();
  MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack<bits>();
  MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
  MLX_MTL_CONST short n_reads =
      (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
  MLX_MTL_CONST short group_steps = group_size / BCOLS;

  const int src_ld;
  const int tile_stride;
  short group_step_cnt;
  const int group_stride;

  const short thread_idx;
  const short bi;
  const short bj;

  threadgroup T* dst;
  const device uint8_t* src;
  const device T* scales;
  const device T* biases;

  QuantizedBlockLoader(
      const device uint8_t* src_,
      const device T* scales_,
      const device T* biases_,
      const int src_ld_,
      threadgroup T* dst_,
      ushort simd_group_id [[simdgroup_index_in_threadgroup]],
      ushort simd_lane_id [[thread_index_in_simdgroup]])
      : src_ld(src_ld_),
        tile_stride(
            reduction_dim ? BCOLS_PACKED * bytes_per_pack
                          : BROWS * src_ld * bytes_per_pack / pack_factor),
        group_step_cnt(0),
        group_stride(BROWS * src_ld / group_size),
        thread_idx(simd_group_id * 32 + simd_lane_id),
        bi(n_reads * thread_idx / BCOLS_PACKED),
        bj((n_reads * thread_idx) % BCOLS_PACKED),
        dst(dst_ + bi * dst_ld + bj * pack_factor),
        src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
            bj * bytes_per_pack),
        scales(scales_ + bi * src_ld / group_size),
        biases(biases_ + bi * src_ld / group_size) {}

  void load_unsafe() const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          src + i * bytes_per_pack, scale, bias, dst + i * pack_factor);
    }
  }

  void load_safe(short2 src_tile_dim) const {
    if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
      return;
    }

    if (reduction_dim == 1 && bi >= src_tile_dim.x) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    if (reduction_dim == 0 && bi >= src_tile_dim.y) {
      for (int i = 0; i < n_reads * pack_factor; i++) {
        dst[i] = T(0);
      }
      return;
    }

    T scale = *scales;
    T bias = *biases;
    for (int i = 0; i < n_reads; i++) {
      dequantize<T, pack_factor, bits>(
          (device uint8_t*)(src + i * bytes_per_pack),
          scale,
          bias,
          dst + i * pack_factor);
    }
  }

  void next() {
    src += tile_stride;
    if (reduction_dim == 1) {
      if (group_steps > 1) {
        group_step_cnt++;
        if (group_step_cnt == group_steps) {
          group_step_cnt = 0;
          scales++;
          biases++;
        }
      } else {
        scales++;
        biases++;
      }
    } else {
      scales += group_stride;
      biases += group_stride;
    }
  }
};
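
// qmv_quad_impl: quantized matrix-vector product specialized for a small,
// compile-time input size D. Each thread loads D / QUAD_SIZE activations, a
// quadgroup of four threads reduces a full row with quad_sum, and each
// quadgroup accumulates up to eight output rows.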
template <typename T, int group_size, int bits, int D>
METAL_FUNC void qmv_quad_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE;
  constexpr int pack_factor = 32 / bits;
  constexpr int values_per_thread = D / QUAD_SIZE;
  constexpr int packs_per_thread = values_per_thread / pack_factor;
  constexpr int scale_step_per_thread = group_size / values_per_thread;
  constexpr int results_per_quadgroup = 8;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_quadgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid;

  w += out_row * in_vec_size_w + quad_lid * packs_per_thread;
  scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread;
  x += tid.x * in_vec_size + quad_lid * values_per_thread;
  y += tid.x * out_vec_size + out_row;

  U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

  for (int row = 0; row < results_per_quadgroup; row++) {
    auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd);
    const device T* sl = scales + row * in_vec_size_g * quads_per_simd;
    const device T* bl = biases + row * in_vec_size_g * quads_per_simd;

    U s = sl[0];
    U b = bl[0];
    if (row * quads_per_simd + out_row < out_vec_size) {
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }
  }

  for (int row = 0; row < results_per_quadgroup; row++) {
    result[row] = quad_sum(result[row]);
    if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) {
      y[row * quads_per_simd] = static_cast<T>(result[row]);
    }
  }
}
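
// qmv_fast_impl: quantized matrix-vector product for shapes where every block
// along the input and output dimensions is full, so no bounds checks are
// needed. Two simdgroups per threadgroup each produce four output rows,
// reduced with simd_sum.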
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = get_pack_factor<bits, 32>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;

  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
  scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
  x += tid.x * in_vec_size + simd_lid * values_per_thread;
  y += tid.x * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    for (int row = 0; row < results_per_simdgroup; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      const device T* sl = scales + row * in_vec_size_g;
      const device T* bl = biases + row * in_vec_size_g;

      U s = sl[0];
      U b = bl[0];
      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
    }

    ws += block_size * bytes_per_pack / pack_factor;
    scales += block_size / group_size;
    biases += block_size / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
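
// qmv_impl: general quantized matrix-vector product. It follows the same plan
// as qmv_fast_impl but guards the trailing partial block with the *_safe
// helpers, and when out_vec_size is not a multiple of the tile height it
// shifts the last row tile back so edge rows are recomputed instead of read
// out of bounds.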
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int packs_per_thread = 1;
  constexpr int pack_factor = get_pack_factor<bits, 32>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits, 32>();

  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g = in_vec_size / group_size;
  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
      simd_gid * results_per_simdgroup;
  const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);

  if (out_row >= out_vec_size) {
    return;
  }

  // In this case we need to properly guard all our reads because there isn't
  // even 1 tile in the matrix
  if (out_vec_size < (num_simdgroups * results_per_simdgroup)) {
    ws +=
        out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
    scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.x * in_vec_size + simd_lid * values_per_thread;
    y += tid.x * out_vec_size + out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    if (remaining > 0) {
      U sum = load_vector_safe<T, U, values_per_thread, bits>(
          x, x_thread, remaining);

      for (int row = 0; out_row + row < out_vec_size; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }
    }

    for (int row = 0; out_row + row < out_vec_size; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }

  // In this case the last tile is moved back to redo some output values
  else {
    ws += used_out_row * in_vec_size_w +
        simd_lid * packs_per_thread * bytes_per_pack;
    scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
    x += tid.x * in_vec_size + simd_lid * values_per_thread;
    y += tid.x * out_vec_size + used_out_row;

    int k = 0;
    for (; k < in_vec_size - block_size; k += block_size) {
      U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] +=
            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
      }

      ws += block_size * bytes_per_pack / pack_factor;
      scales += block_size / group_size;
      biases += block_size / group_size;
      x += block_size;
    }
    const int remaining = clamp(
        static_cast<int>(in_vec_size - k - simd_lid * values_per_thread),
        0,
        values_per_thread);
    if (remaining > 0) {
      U sum = load_vector_safe<T, U, values_per_thread, bits>(
          x, x_thread, remaining);

      for (int row = 0; row < results_per_simdgroup; row++) {
        auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
        const device T* sl = scales + row * in_vec_size_g;
        const device T* bl = biases + row * in_vec_size_g;

        U s = sl[0];
        U b = bl[0];
        result[row] += qdot_safe<U, values_per_thread, bits>(
            wl, x_thread, s, b, sum, remaining);
      }
    }
    for (int row = 0; row < results_per_simdgroup; row++) {
      result[row] = simd_sum(result[row]);
      if (simd_lid == 0) {
        y[row] = static_cast<T>(result[row]);
      }
    }
  }
}
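
// qvm_impl: quantized vector-matrix product (y = x * W). Each simd lane walks
// one input element per block step, loads a small strip of packed weights, and
// accumulates into per-thread output registers with qouter; the partial sums
// are then reduced across the simdgroup with simd_sum.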
template <typename T, const int group_size, const int bits>
METAL_FUNC void qvm_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    const int in_vec_size,
    const int out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int num_simdgroups = 2;
  constexpr int pack_factor = get_pack_factor<bits, 32>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

  constexpr int tn = 32 / pack_factor;
  constexpr int block_size = SIMD_SIZE;

  using W_T =
      typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
  const device W_T* ws = (const device W_T*)w;

  typedef float U;
  typedef struct {
    W_T wi[tn * bytes_per_pack];
  } vec_w;

  thread vec_w w_local;
  thread U result[tn * pack_factor] = {0};
  thread U scale = 1;
  thread U bias = 0;
  thread U x_local = 0;

  // Adjust positions
  const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor;
  const int out_vec_size_g = out_vec_size / group_size;
  int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid);
  ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w;
  scales += out_col / group_size + simd_lid * out_vec_size_g;
  biases += out_col / group_size + simd_lid * out_vec_size_g;
  x += tid.x * in_vec_size + simd_lid;
  y += tid.x * out_vec_size + out_col;

  if (out_col >= out_vec_size) {
    return;
  }

  // Loop over in_vec in blocks of block_size
  int remaining = in_vec_size % block_size;
  if (remaining == 0) {
    for (int i = 0; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
  } else {
    for (int i = block_size; i < in_vec_size; i += block_size) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);

      qouter<U, tn * pack_factor, bits>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);

      x += block_size;
      scales += block_size * out_vec_size_g;
      biases += block_size * out_vec_size_g;
      ws += block_size * out_vec_size_w;
    }
    if (static_cast<int>(simd_lid) < remaining) {
      x_local = *x;
      scale = *scales;
      bias = *biases;
      w_local = *((device vec_w*)ws);
    } else {
      x_local = 0;
      scale = 0;
      bias = 0;
    }
    qouter<U, tn * pack_factor, bits>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
  }

  // Accumulate in the simdgroup
#pragma clang loop unroll(full)
  for (int k = 0; k < tn * pack_factor; k++) {
    result[k] = simd_sum(result[k]);
  }

  // Store the result
  if (simd_lid == 0) {
#pragma clang loop unroll(full)
    for (int k = 0; k < tn * pack_factor; k++) {
      y[k] = static_cast<T>(result[k]);
    }
  }
}
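
// qmm_t_impl: tiled quantized matmul against a transposed weight matrix.
// X tiles are staged with a steel BlockLoader and W tiles with the
// QuantizedBlockLoader above, then multiplied with steel BlockMMA; the four
// branches below pick bounds-checked or unchecked loads for edge tiles.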
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
METAL_FUNC void qmm_t_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* biases,
    const device T* x,
    device T* y,
    threadgroup T* Xs,
    threadgroup T* Ws,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
  static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");

  (void)lid;

  constexpr int WM = 2;
  constexpr int WN = 2;
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

  constexpr int BK_padded = (BK + 16 / sizeof(T));

  // Instantiate the appropriate BlockMMA and Loader
  using mma_t = mlx::steel::
      BlockMMA<T, T, BM, BN, BK, WM, WN, false, true, BK_padded, BK_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      BN,
      BK,
      BK_padded,
      1,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  // Set the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;

  auto wl = (const device uint8_t*)w;

  x += y_row * static_cast<int64_t>(K);
  wl += y_col * K_w;
  scales += y_col * K_g;
  biases += y_col * K_g;
  y += y_row * static_cast<int64_t>(N) + y_col;

  // Make the x loader and mma operation
  const short num_els = min(BM, M - y_row);
  const short num_outs = min(BN, N - y_col);
  loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
  loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid);
  mma_t mma_op(simd_gid, simd_lid);

  if (num_els < BM) {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_safe(short2(BK, num_els));
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  } else {
    if (!aligned_N && num_outs < BN) {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_safe(short2(BK, num_outs));
        threadgroup_barrier(mem_flags::mem_threadgroup);
        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    } else {
      for (int k = 0; k < K; k += BK) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        loader_x.load_unsafe();
        loader_w.load_unsafe();
        threadgroup_barrier(mem_flags::mem_threadgroup);

        mma_op.mma(Xs, Ws);
        loader_x.next();
        loader_w.next();
      }
    }
  }

  // Store results to device memory
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (num_els < BM || num_outs < BN) {
    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
  } else {
    mma_op.store_result(y, N);
  }
}
|
|
1206
|
+
|
|
1207
|
+
template <
|
|
1208
|
+
typename T,
|
|
1209
|
+
const int group_size,
|
|
1210
|
+
const int bits,
|
|
1211
|
+
const int BM = 32,
|
|
1212
|
+
const int BK = 32,
|
|
1213
|
+
const int BN = 32>
|
|
1214
|
+
METAL_FUNC void qmm_n_impl(
|
|
1215
|
+
const device uint32_t* w,
|
|
1216
|
+
const device T* scales,
|
|
1217
|
+
const device T* biases,
|
|
1218
|
+
const device T* x,
|
|
1219
|
+
device T* y,
|
|
1220
|
+
threadgroup T* Xs,
|
|
1221
|
+
threadgroup T* Ws,
|
|
1222
|
+
const constant int& K,
|
|
1223
|
+
const constant int& N,
|
|
1224
|
+
const constant int& M,
|
|
1225
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
1226
|
+
uint lid [[thread_index_in_threadgroup]],
|
|
1227
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
1228
|
+
uint simd_lid [[thread_index_in_simdgroup]]) {
|
|
1229
|
+
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
|
1230
|
+
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
|
1231
|
+
|
|
1232
|
+
(void)lid;
|
|
1233
|
+
|
|
1234
|
+
constexpr int WM = 2;
|
|
1235
|
+
constexpr int WN = 2;
|
|
1236
|
+
constexpr int pack_factor = get_pack_factor<bits, 8>();
|
|
1237
|
+
constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
|
|
1238
|
+
|
|
1239
|
+
constexpr int BK_padded = (BK + 16 / sizeof(T));
|
|
1240
|
+
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
|
1241
|
+
|
|
1242
|
+
// Instantiate the appropriate BlockMMA and Loader
|
|
1243
|
+
using mma_t = mlx::steel::
|
|
1244
|
+
BlockMMA<T, T, BM, BN, BK, WM, WN, false, false, BK_padded, BN_padded>;
|
|
1245
|
+
using loader_x_t = mlx::steel::
|
|
1246
|
+
BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE, 1, 4>;
|
|
1247
|
+
using loader_w_t = QuantizedBlockLoader<
|
|
1248
|
+
T,
|
|
1249
|
+
BK,
|
|
1250
|
+
BN,
|
|
1251
|
+
BN_padded,
|
|
1252
|
+
0,
|
|
1253
|
+
WM * WN * SIMD_SIZE,
|
|
1254
|
+
group_size,
|
|
1255
|
+
bits>;
|
|
1256
|
+
|
|
1257
|
+
auto wl = (const device uint8_t*)w;
|
|
1258
|
+
|
|
1259
|
+
// Set the block
|
|
1260
|
+
const int y_row = tid.y * BM;
|
|
1261
|
+
const int y_col = tid.x * BN;
|
|
1262
|
+
x += y_row * static_cast<int64_t>(K);
|
|
1263
|
+
wl += y_col * bytes_per_pack / pack_factor;
|
|
1264
|
+
scales += y_col / group_size;
|
|
1265
|
+
biases += y_col / group_size;
|
|
1266
|
+
y += y_row * static_cast<int64_t>(N) + y_col;
|
|
1267
|
+
|
|
1268
|
+
// Make the x loader and mma operation
|
|
1269
|
+
const short num_els = min(BM, M - y_row);
|
|
1270
|
+
loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
|
|
1271
|
+
loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid);
|
|
1272
|
+
mma_t mma_op(simd_gid, simd_lid);
|
|
1273
|
+
|
|
1274
|
+
if (num_els < BM) {
|
|
1275
|
+
if ((K % BK) != 0) {
|
|
1276
|
+
const int k_blocks = K / BK;
|
|
1277
|
+
for (int k = 0; k < k_blocks; k++) {
|
|
1278
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1279
|
+
loader_x.load_safe(short2(BK, num_els));
|
|
1280
|
+
loader_w.load_unsafe();
|
|
1281
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1282
|
+
mma_op.mma(Xs, Ws);
|
|
1283
|
+
loader_x.next();
|
|
1284
|
+
loader_w.next();
|
|
1285
|
+
}
|
|
1286
|
+
const short num_k = K - k_blocks * BK;
|
|
1287
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1288
|
+
loader_x.load_safe(short2(num_k, num_els));
|
|
1289
|
+
loader_w.load_safe(short2(BN, num_k));
|
|
1290
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1291
|
+
mma_op.mma(Xs, Ws);
|
|
1292
|
+
} else {
|
|
1293
|
+
for (int k = 0; k < K; k += BK) {
|
|
1294
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1295
|
+
loader_x.load_safe(short2(BK, num_els));
|
|
1296
|
+
loader_w.load_unsafe();
|
|
1297
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1298
|
+
mma_op.mma(Xs, Ws);
|
|
1299
|
+
loader_x.next();
|
|
1300
|
+
loader_w.next();
|
|
1301
|
+
}
|
|
1302
|
+
}
|
|
1303
|
+
} else {
|
|
1304
|
+
if ((K % BK) != 0) {
|
|
1305
|
+
const int k_blocks = K / BK;
|
|
1306
|
+
for (int k = 0; k < k_blocks; k++) {
|
|
1307
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1308
|
+
loader_x.load_unsafe();
|
|
1309
|
+
loader_w.load_unsafe();
|
|
1310
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1311
|
+
mma_op.mma(Xs, Ws);
|
|
1312
|
+
loader_x.next();
|
|
1313
|
+
loader_w.next();
|
|
1314
|
+
}
|
|
1315
|
+
const short num_k = K - k_blocks * BK;
|
|
1316
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1317
|
+
loader_x.load_safe(short2(num_k, BM));
|
|
1318
|
+
loader_w.load_safe(short2(BN, num_k));
|
|
1319
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1320
|
+
mma_op.mma(Xs, Ws);
|
|
1321
|
+
} else {
|
|
1322
|
+
for (int k = 0; k < K; k += BK) {
|
|
1323
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1324
|
+
loader_x.load_unsafe();
|
|
1325
|
+
loader_w.load_unsafe();
|
|
1326
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1327
|
+
mma_op.mma(Xs, Ws);
|
|
1328
|
+
loader_x.next();
|
|
1329
|
+
loader_w.next();
|
|
1330
|
+
}
|
|
1331
|
+
}
|
|
1332
|
+
}
|
|
1333
|
+
|
|
1334
|
+
// Store results to device memory
|
|
1335
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
1336
|
+
if (num_els < BM) {
|
|
1337
|
+
mma_op.store_result_safe(y, N, short2(BN, num_els));
|
|
1338
|
+
} else {
|
|
1339
|
+
mma_op.store_result(y, N);
|
|
1340
|
+
}
|
|
1341
|
+
}
|
|
1342
|
+
|
|
1343
|
+
template <typename T>
|
|
1344
|
+
METAL_FUNC void adjust_matrix_offsets(
|
|
1345
|
+
const device T*& x,
|
|
1346
|
+
const device uint32_t*& w,
|
|
1347
|
+
const device T*& scales,
|
|
1348
|
+
const device T*& biases,
|
|
1349
|
+
device T*& y,
|
|
1350
|
+
int output_stride,
|
|
1351
|
+
const constant int& x_batch_ndims,
|
|
1352
|
+
const constant int* x_shape,
|
|
1353
|
+
const constant int64_t* x_strides,
|
|
1354
|
+
const constant int& w_batch_ndims,
|
|
1355
|
+
const constant int* w_shape,
|
|
1356
|
+
const constant int64_t* w_strides,
|
|
1357
|
+
const constant int64_t* s_strides,
|
|
1358
|
+
const constant int64_t* b_strides,
|
|
1359
|
+
uint3 tid [[threadgroup_position_in_grid]]) {
|
|
1360
|
+
// Set the input/output matrices
|
|
1361
|
+
uint32_t x_idx = tid.z;
|
|
1362
|
+
uint32_t w_idx = tid.z;
|
|
1363
|
+
if (x_batch_ndims == 1) {
|
|
1364
|
+
x += x_idx * x_strides[0];
|
|
1365
|
+
} else {
|
|
1366
|
+
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
|
|
1367
|
+
}
|
|
1368
|
+
if (w_batch_ndims == 1) {
|
|
1369
|
+
w += w_idx * w_strides[0];
|
|
1370
|
+
scales += w_idx * s_strides[0];
|
|
1371
|
+
biases += w_idx * b_strides[0];
|
|
1372
|
+
} else {
|
|
1373
|
+
ulong3 idx = elem_to_loc_broadcast(
|
|
1374
|
+
w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
|
|
1375
|
+
w += idx.x;
|
|
1376
|
+
scales += idx.y;
|
|
1377
|
+
biases += idx.z;
|
|
1378
|
+
}
|
|
1379
|
+
y += tid.z * output_stride;
|
|
1380
|
+
}
|
|
1381
|
+
|
|
1382
|
+
template <typename T>
|
|
1383
|
+
METAL_FUNC void adjust_matrix_offsets(
|
|
1384
|
+
const device T*& x,
|
|
1385
|
+
const device uint32_t*& w,
|
|
1386
|
+
const device T*& scales,
|
|
1387
|
+
const device T*& biases,
|
|
1388
|
+
const device uint32_t* lhs_indices,
|
|
1389
|
+
const device uint32_t* rhs_indices,
|
|
1390
|
+
device T*& y,
|
|
1391
|
+
int output_stride,
|
|
1392
|
+
const constant int& batch_ndims,
|
|
1393
|
+
const constant int* batch_shape,
|
|
1394
|
+
const constant int64_t* lhs_strides,
|
|
1395
|
+
const constant int64_t* rhs_strides,
|
|
1396
|
+
const constant int& x_batch_ndims,
|
|
1397
|
+
const constant int* x_shape,
|
|
1398
|
+
const constant int64_t* x_strides,
|
|
1399
|
+
const constant int& w_batch_ndims,
|
|
1400
|
+
const constant int* w_shape,
|
|
1401
|
+
const constant int64_t* w_strides,
|
|
1402
|
+
const constant int64_t* s_strides,
|
|
1403
|
+
const constant int64_t* b_strides,
|
|
1404
|
+
uint3 tid [[threadgroup_position_in_grid]]) {
|
|
1405
|
+
// Set the input/output matrices
|
|
1406
|
+
uint32_t x_idx;
|
|
1407
|
+
uint32_t w_idx;
|
|
1408
|
+
if (batch_ndims == 1) {
|
|
1409
|
+
x_idx = lhs_indices[tid.z * lhs_strides[0]];
|
|
1410
|
+
w_idx = rhs_indices[tid.z * rhs_strides[0]];
|
|
1411
|
+
} else {
|
|
1412
|
+
ulong2 idx = elem_to_loc_broadcast(
|
|
1413
|
+
tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
|
|
1414
|
+
x_idx = lhs_indices[idx.x];
|
|
1415
|
+
w_idx = rhs_indices[idx.y];
|
|
1416
|
+
}
|
|
1417
|
+
if (x_batch_ndims == 1) {
|
|
1418
|
+
x += x_idx * x_strides[0];
|
|
1419
|
+
} else {
|
|
1420
|
+
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
|
|
1421
|
+
}
|
|
1422
|
+
if (w_batch_ndims == 1) {
|
|
1423
|
+
w += w_idx * w_strides[0];
|
|
1424
|
+
scales += w_idx * s_strides[0];
|
|
1425
|
+
biases += w_idx * b_strides[0];
|
|
1426
|
+
} else {
|
|
1427
|
+
ulong3 idx = elem_to_loc_broadcast(
|
|
1428
|
+
w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
|
|
1429
|
+
w += idx.x;
|
|
1430
|
+
scales += idx.y;
|
|
1431
|
+
biases += idx.z;
|
|
1432
|
+
}
|
|
1433
|
+
y += tid.z * output_stride;
|
|
1434
|
+
}
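
Both adjust_matrix_offsets overloads reduce to stride arithmetic: a flat batch index is mapped to an element offset through elem_to_loc and elem_to_loc_broadcast, which are defined elsewhere in the mlx headers. As a rough host-side illustration of that kind of mapping, the C++ sketch below walks the shape from the innermost dimension outward; the function name and exact iteration order are assumptions made for the sketch, not a copy of the library's helpers.

#include <cstdint>

// Illustrative only: map a flat element index to a strided memory offset
// by peeling off one dimension at a time, innermost first.
inline int64_t elem_to_loc_sketch(
    uint32_t elem,
    const int* shape,
    const int64_t* strides,
    int ndim) {
  int64_t loc = 0;
  for (int i = ndim - 1; i >= 0; --i) {
    loc += static_cast<int64_t>(elem % shape[i]) * strides[i];
    elem /= shape[i];
  }
  return loc;
}

The broadcast variant used above performs one such walk over a shared shape while accumulating a separate offset per strides array, which is why it returns a ulong2 or ulong3.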

template <typename T, int group_size, int bits, int D, bool batched>
[[kernel]] void affine_qmv_quad(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint quad_gid [[quadgroup_index_in_threadgroup]],
    uint quad_lid [[thread_index_in_quadgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size * M,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_quad_impl<T, group_size, bits, D>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      quad_gid,
      quad_lid);
}

template <typename T, int group_size, int bits, bool batched>
[[kernel]] void affine_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size * M,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_fast_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void affine_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size * M,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

template <typename T, const int group_size, const int bits, bool batched>
[[kernel]] void affine_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    int M = x_shape[x_batch_ndims];
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        out_vec_size * M,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qvm_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

template <typename T, const int group_size, const int bits, int split_k = 32>
[[kernel]] void affine_qvm_split_k(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    const constant int& x_batch_ndims [[buffer(7)]],
    const constant int* x_shape [[buffer(8)]],
    const constant int64_t* x_strides [[buffer(9)]],
    const constant int& w_batch_ndims [[buffer(10)]],
    const constant int* w_shape [[buffer(11)]],
    const constant int64_t* w_strides [[buffer(12)]],
    const constant int64_t* s_strides [[buffer(13)]],
    const constant int64_t* b_strides [[buffer(14)]],
    const constant int& final_block_size [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      y,
      out_vec_size * M,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);

  // When (in_vec_size % split_k != 0) the final block needs to be smaller
  int in_vec_size_adj =
      tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size;

  qvm_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size_adj,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}
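
affine_qvm_split_k gives each tid.z slice one chunk of the reduction dimension and shrinks the last chunk to final_block_size when in_vec_size is not a multiple of split_k; the per-chunk partial outputs are combined outside this kernel. A minimal host-side C++ sketch of that split-k pattern (illustrative only, not the kernel's launch or accumulation logic):

#include <algorithm>
#include <cstddef>
#include <vector>

// Split a dot product over the reduction dimension into split_k chunks,
// compute one partial per chunk, then combine the partials in a second pass.
double split_k_dot(
    const std::vector<double>& a,
    const std::vector<double>& b,
    int split_k) {
  const std::size_t n = a.size();
  const std::size_t chunk = (n + split_k - 1) / split_k;
  std::vector<double> partials(split_k, 0.0);
  for (int k = 0; k < split_k; ++k) {
    const std::size_t start = k * chunk;
    const std::size_t stop = std::min(n, start + chunk);  // last chunk may be short
    for (std::size_t i = start; i < stop; ++i) {
      partials[k] += a[i] * b[i];
    }
  }
  double acc = 0.0;
  for (double p : partials) {
    acc += p;  // on the GPU this combine is a separate reduction pass
  }
  return acc;
}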

template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void affine_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void affine_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& K [[buffer(5)]],
    const constant int& N [[buffer(6)]],
    const constant int& M [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  if (batched) {
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }

  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <typename T, int group_size, int bits>
[[kernel]] void affine_gather_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& x_batch_ndims [[buffer(9)]],
    const constant int* x_shape [[buffer(10)]],
    const constant int64_t* x_strides [[buffer(11)]],
    const constant int& w_batch_ndims [[buffer(12)]],
    const constant int* w_shape [[buffer(13)]],
    const constant int64_t* w_strides [[buffer(14)]],
    const constant int64_t* s_strides [[buffer(15)]],
    const constant int64_t* b_strides [[buffer(16)]],
    const constant int& batch_ndims [[buffer(17)]],
    const constant int* batch_shape [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      out_vec_size * M,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qmv_fast_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

template <typename T, int group_size, int bits>
[[kernel]] void affine_gather_qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& x_batch_ndims [[buffer(9)]],
    const constant int* x_shape [[buffer(10)]],
    const constant int64_t* x_strides [[buffer(11)]],
    const constant int& w_batch_ndims [[buffer(12)]],
    const constant int* w_shape [[buffer(13)]],
    const constant int64_t* w_strides [[buffer(14)]],
    const constant int64_t* s_strides [[buffer(15)]],
    const constant int64_t* b_strides [[buffer(16)]],
    const constant int& batch_ndims [[buffer(17)]],
    const constant int* batch_shape [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      out_vec_size * M,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qmv_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

template <typename T, int group_size, int bits>
[[kernel]] void affine_gather_qvm(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& in_vec_size [[buffer(7)]],
    const constant int& out_vec_size [[buffer(8)]],
    const constant int& x_batch_ndims [[buffer(9)]],
    const constant int* x_shape [[buffer(10)]],
    const constant int64_t* x_strides [[buffer(11)]],
    const constant int& w_batch_ndims [[buffer(12)]],
    const constant int* w_shape [[buffer(13)]],
    const constant int64_t* w_strides [[buffer(14)]],
    const constant int64_t* s_strides [[buffer(15)]],
    const constant int64_t* b_strides [[buffer(16)]],
    const constant int& batch_ndims [[buffer(17)]],
    const constant int* batch_shape [[buffer(18)]],
    const constant int64_t* lhs_strides [[buffer(19)]],
    const constant int64_t* rhs_strides [[buffer(20)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  int M = x_shape[x_batch_ndims];
  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      out_vec_size * M,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qvm_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      in_vec_size,
      out_vec_size,
      tid,
      simd_gid,
      simd_lid);
}

template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void affine_gather_qmm_t(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    const constant int& N [[buffer(8)]],
    const constant int& M [[buffer(9)]],
    const constant int& x_batch_ndims [[buffer(10)]],
    const constant int* x_shape [[buffer(11)]],
    const constant int64_t* x_strides [[buffer(12)]],
    const constant int& w_batch_ndims [[buffer(13)]],
    const constant int* w_shape [[buffer(14)]],
    const constant int64_t* w_strides [[buffer(15)]],
    const constant int64_t* s_strides [[buffer(16)]],
    const constant int64_t* b_strides [[buffer(17)]],
    const constant int& batch_ndims [[buffer(18)]],
    const constant int* batch_shape [[buffer(19)]],
    const constant int64_t* lhs_strides [[buffer(20)]],
    const constant int64_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BN * BK_padded];

  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      M * N,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 32,
    const int BK = 32,
    const int BN = 32>
[[kernel]] void affine_gather_qmm_n(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    const device uint32_t* lhs_indices [[buffer(4)]],
    const device uint32_t* rhs_indices [[buffer(5)]],
    device T* y [[buffer(6)]],
    const constant int& K [[buffer(7)]],
    const constant int& N [[buffer(8)]],
    const constant int& M [[buffer(9)]],
    const constant int& x_batch_ndims [[buffer(10)]],
    const constant int* x_shape [[buffer(11)]],
    const constant int64_t* x_strides [[buffer(12)]],
    const constant int& w_batch_ndims [[buffer(13)]],
    const constant int* w_shape [[buffer(14)]],
    const constant int64_t* w_strides [[buffer(15)]],
    const constant int64_t* s_strides [[buffer(16)]],
    const constant int64_t* b_strides [[buffer(17)]],
    const constant int& batch_ndims [[buffer(18)]],
    const constant int* batch_shape [[buffer(19)]],
    const constant int64_t* lhs_strides [[buffer(20)]],
    const constant int64_t* rhs_strides [[buffer(21)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  (void)lid;

  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[BK * BN_padded];

  adjust_matrix_offsets<T>(
      x,
      w,
      scales,
      biases,
      lhs_indices,
      rhs_indices,
      y,
      M * N,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      b_strides,
      tid);
  qmm_n_impl<T, group_size, bits, BM, BK, BN>(
      w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}

template <
    typename T,
    int group_size,
    int bits,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose>
[[kernel]] void affine_gather_qmm_rhs(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
    const device T* scales [[buffer(2)]],
    const device T* biases [[buffer(3)]],
    const device uint32_t* indices [[buffer(4)]],
    device T* y [[buffer(5)]],
    const constant int& M [[buffer(6)]],
    const constant int& N [[buffer(7)]],
    const constant int& K [[buffer(8)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
  constexpr int BK_padded = (BK + 16 / sizeof(T));
  constexpr int BN_padded = (BN + 16 / sizeof(T));

  using mma_t = mlx::steel::BlockMMA<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      false,
      transpose,
      BK_padded,
      transpose ? BK_padded : BN_padded>;
  using loader_x_t =
      mlx::steel::BlockLoader<T, BM, BK, BK_padded, 1, WM * WN * SIMD_SIZE>;
  using loader_w_t = QuantizedBlockLoader<
      T,
      transpose ? BN : BK,
      transpose ? BK : BN,
      transpose ? BK_padded : BN_padded,
      transpose,
      WM * WN * SIMD_SIZE,
      group_size,
      bits>;

  threadgroup T Xs[BM * BK_padded];
  threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded];

  // Compute the block
  const int K_w = K * bytes_per_pack / pack_factor;
  const int K_g = K / group_size;
  const int N_w = N * bytes_per_pack / pack_factor;
  const int N_g = N / group_size;
  const int K_it = K / BK;
  const size_t stride_w = transpose ? N * K_w : K * N_w;
  const size_t stride_s = transpose ? N * K_g : K * N_g;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  const size_t y_row_long = size_t(y_row);
  const size_t y_col_long = size_t(y_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));

  // Calculate the final tiles in the case that K is not aligned
  const int k_remain = K - K_it * BK;
  const short2 tile_x = short2(k_remain, tgp_bm);
  const short2 tile_w =
      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

  // Move x and output to the correct block
  auto wl = (const device uint8_t*)w;
  x += y_row_long * K;
  y += y_row_long * N + y_col_long;
  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
  scales += transpose ? y_col_long * K_g : y_col / group_size;
  biases += transpose ? y_col_long * K_g : y_col / group_size;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = indices[y_row];
  short offset_next = 0;
  int n = 0;
  while (n < tgp_bm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = tgp_bm;
    for (; n < tgp_bm; n++) {
      if (indices[y_row + n] != index) {
        offset_next = n;
        index_next = indices[y_row + n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    // Prepare threadgroup loading operations
    thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id);
    thread loader_w_t loader_w(
        wl + index * stride_w,
        scales + index * stride_s,
        biases + index * stride_s,
        transpose ? K : N,
        Ws,
        simd_group_id,
        simd_lane_id);

    // Matrices are all aligned check nothing
    if (align_M && align_N) {
      gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
      if (!align_K) {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
      }

      // Store results to device memory
      if (offset_next - offset == BM) {
        mma_op.store_result(y, N);
      } else {
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(BN, offset_next));
      }
    } else {
      // Tile aligned so check outside of the hot loop
      if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
        gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }

        // Store results to device memory
        if (offset_next - offset == BM) {
          mma_op.store_result(y, N);
        } else {
          mma_op.store_result_slice(
              y, N, short2(0, offset), short2(BN, offset_next));
        }
      }

      // Tile partially aligned check rows
      else if (align_N || tgp_bn == BN) {
        gemm_loop_unaligned<false, true, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(BN, offset_next));
      }

      // Tile partially aligned check cols
      else if (align_M || tgp_bm == BM) {
        gemm_loop_unaligned<true, false, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(tgp_bn, offset_next));
      }

      // Nothing aligned so check both rows and cols
      else {
        gemm_loop_unaligned<false, false, transpose>(
            Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK);
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          gemm_loop_finalize(
              Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w);
        }
        mma_op.store_result_slice(
            y, N, short2(0, offset), short2(tgp_bn, offset_next));
      }
    }
  }
}

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_quantize(
    const device T* w [[buffer(0)]],
    device uint8_t* out [[buffer(1)]],
    device T* scales [[buffer(2)]],
    device T* biases [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr float eps = 1e-7;
  constexpr int simd_size = 32;
  constexpr float n_bins = (1 << bits) - 1;
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();
  constexpr int values_per_reduce = group_size / simd_size;
  constexpr int writes_per_reduce = pack_factor / values_per_reduce;
  constexpr int writes_per_pack =
      writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor;
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;

  static_assert(
      group_size % simd_size == 0,
      "Group size must be divisible by simd size.");

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t in_index = offset * values_per_reduce;
  size_t out_index = power_of_2_bits
      ? offset * writes_per_pack
      : offset * bytes_per_pack / writes_per_reduce;

  float w_thread[values_per_reduce];
  float w_min = Limits<T>::max;
  float w_max = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    float val = w[in_index + i];
    w_thread[i] = val;
    w_min = min(w_min, val);
    w_max = max(w_max, val);
  }

  w_min = simd_min(w_min);
  w_max = simd_max(w_max);

  float scale = max((w_max - w_min) / n_bins, eps);
  bool side = abs(w_min) > abs(w_max);
  scale = side ? scale : -scale;
  float edge = side ? w_min : w_max;
  float q0 = round(edge / scale);
  bool at_zero = q0 == 0.0f;
  scale = at_zero ? scale : edge / q0;
  float bias = at_zero ? 0 : edge;

  // Write out the scales and biases
  size_t gindex = in_index / group_size;
  if (in_index % group_size == 0) {
    scales[gindex] = static_cast<T>(scale);
    biases[gindex] = static_cast<T>(bias);
  }

  using OutType = metal::conditional_t<bits == 5, uint64_t, uint32_t>;
  OutType output = 0;

#pragma clang loop unroll(full)
  for (int i = 0; i < values_per_reduce; i++) {
    uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins);
    if (bits == 8) {
      output = val;
    } else {
      output |= val << (bits * (i % pack_factor));
    }

    if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) {
      out[out_index + i / pack_factor] = output;
      output = 0;
    } else {
#pragma clang loop unroll(full)
      for (int j = 1; j < writes_per_reduce; j++) {
        uint8_t sval = simd_shuffle_down(val, j);
        output |= static_cast<OutType>(sval)
            << (bits * (j * values_per_reduce + i));
      }
    }
  }
  if (bits == 3 || bits == 6) {
    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
    }
  } else if (bits == 5) {
    if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) {
      out[out_index] = output & 0xff;
      out[out_index + 1] = (output & 0xff00) >> 8;
      out[out_index + 2] = (output & 0xff0000) >> 16;
      out[out_index + 3] = (output & 0xff000000) >> 24;
      out[out_index + 4] = (output & 0xff00000000) >> 32;
    }
  } else {
    if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) {
      out[out_index / writes_per_reduce] = output;
    }
  }
}
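
affine_quantize chooses one scale and bias per group of group_size weights so that round((w - bias) / scale) falls in [0, 2^bits - 1], then packs the resulting codes into bytes. The C++ sketch below is a simplified host-side analogue of that scale/bias selection and rounding; it deliberately skips the simdgroup reductions and the sub-byte packing done by the kernel, and it is an illustration rather than the library's reference path.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Quantize one group of weights with the affine map q = round((w - bias) / scale).
void quantize_group_sketch(
    const std::vector<float>& w,
    int bits,
    std::vector<uint8_t>& q,
    float& scale,
    float& bias) {
  const float eps = 1e-7f;
  const float n_bins = static_cast<float>((1 << bits) - 1);
  const float w_min = *std::min_element(w.begin(), w.end());
  const float w_max = *std::max_element(w.begin(), w.end());
  scale = std::max((w_max - w_min) / n_bins, eps);
  // Anchor the grid at the larger-magnitude edge so that edge is exactly representable.
  const bool side = std::fabs(w_min) > std::fabs(w_max);
  scale = side ? scale : -scale;
  const float edge = side ? w_min : w_max;
  const float q0 = std::round(edge / scale);
  scale = (q0 == 0.0f) ? scale : edge / q0;
  bias = (q0 == 0.0f) ? 0.0f : edge;
  q.resize(w.size());
  for (std::size_t i = 0; i < w.size(); ++i) {
    q[i] = static_cast<uint8_t>(
        std::min(std::round((w[i] - bias) / scale), n_bins));
  }
}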

template <typename T, const int group_size, const int bits>
[[kernel]] void affine_dequantize(
    const device uint8_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  constexpr int pack_factor = get_pack_factor<bits, 8>();
  constexpr int bytes_per_pack = get_bytes_per_pack<bits>();

  size_t offset = index.x + grid_dim.x * size_t(index.y);
  size_t oindex = offset * pack_factor;
  size_t gindex = oindex / group_size;
  T scale = scales[gindex];
  T bias = biases[gindex];

  out += oindex;

  if (bits == 3) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x7) * scale + bias;
    out[1] = ((w[0] & 0x38) >> 3) * scale + bias;
    out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias;
    out[3] = ((w[1] & 0xe) >> 1) * scale + bias;
    out[4] = ((w[1] & 0x70) >> 4) * scale + bias;
    out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias;
    out[6] = ((w[2] & 0x1c) >> 2) * scale + bias;
    out[7] = ((w[2] & 0xe0) >> 5) * scale + bias;
  } else if (bits == 5) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x1f) * scale + bias;
    out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias;
    out[2] = ((w[1] & 0x7c) >> 2) * scale + bias;
    out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias;
    out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias;
    out[5] = ((w[3] & 0x3e) >> 1) * scale + bias;
    out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias;
    out[7] = ((w[4] & 0xf8) >> 3) * scale + bias;
  } else if (bits == 6) {
    w += offset * bytes_per_pack;
    out[0] = (w[0] & 0x3f) * scale + bias;
    out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias;
    out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias;
    out[3] = ((w[2] >> 2) & 0x3f) * scale + bias;
  } else {
    uint val = w[offset];
#pragma clang loop unroll(full)
    for (int i = 0; i < pack_factor; i++) {
      uint8_t d;
      if (bits == 2) {
        d = (val >> (bits * i)) & 0x03;
      } else if (bits == 4) {
        d = (val >> (bits * i)) & 0x0f;
      } else if (bits == 8) {
        d = val;
      }
      out[i] = scale * d + bias;
    }
  }
}
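
affine_dequantize reverses the map: each code q contributes q * scale + bias, with q unpacked from the packed bytes of its group. For the power-of-two widths handled by the final else branch above (2, 4, and 8 bits, where each byte holds a whole number of codes), a small host-side C++ sketch of the unpack-and-rescale step:

#include <cstdint>
#include <vector>

// Unpack 8 / bits codes from one packed byte and apply out = q * scale + bias.
std::vector<float> dequantize_byte_sketch(
    uint8_t packed,
    int bits,
    float scale,
    float bias) {
  const int pack_factor = 8 / bits;  // codes stored in one byte
  const uint8_t mask = static_cast<uint8_t>((1u << bits) - 1);  // e.g. 0x0f for 4 bits
  std::vector<float> out(pack_factor);
  for (int i = 0; i < pack_factor; ++i) {
    const uint8_t q = (packed >> (bits * i)) & mask;
    out[i] = scale * q + bias;
  }
  return out;
}

The 3-, 5-, and 6-bit branches do the same thing but have to stitch codes together across byte boundaries, which is why they are written out field by field.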