mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,1059 @@
|
|
|
1
|
+
// Copyright © 2025 Apple Inc.
|
|
2
|
+
|
|
3
|
+
#include <metal_simdgroup>
|
|
4
|
+
#include <metal_stdlib>
|
|
5
|
+
|
|
6
|
+
#include "mlx/backend/metal/kernels/fp4.h"
|
|
7
|
+
#include "mlx/backend/metal/kernels/fp8.h"
|
|
8
|
+
|
|
9
|
+
constant bool align_M [[function_constant(200)]];
|
|
10
|
+
constant bool align_N [[function_constant(201)]];
|
|
11
|
+
constant bool align_K [[function_constant(202)]];
|
|
12
|
+
|
|
13
|
+
using namespace metal;
|
|
14
|
+
|
|
15
|
+
#define MLX_MTL_CONST static constant constexpr const
|
|
16
|
+
|
|
17
|
+
MLX_MTL_CONST int SIMD_SIZE = 32;
|
|
18
|
+
MLX_MTL_CONST int QUAD_SIZE = 4;
|
|
19
|
+
|
|
20
|
+
// Returns wsize / 4 as a short (2 for the default 8-bit word).
// presumably the number of 4-bit values stored per `wsize`-bit word -- TODO
// confirm against the callers.
template <int wsize = 8>
inline constexpr short get_pack_factor() {
  constexpr int values_per_word = wsize / 4;
  return values_per_word;
}
|
|
24
|
+
|
|
25
|
+
// Returns wsize / 8 as a short, i.e. the size in bytes of one `wsize`-bit
// word (1 for the default 8-bit word).
template <int wsize = 8>
inline constexpr short get_bytes_per_pack() {
  constexpr int bytes_per_word = wsize / 8;
  return bytes_per_word;
}
|
|
29
|
+
|
|
30
|
+
template <typename T>
|
|
31
|
+
static inline T dequantize_scale(uint8_t s) {
|
|
32
|
+
return T(*(thread fp8_e8m0*)(&s));
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
template <int bits>
|
|
36
|
+
struct Quantize {
|
|
37
|
+
uint8_t operator()(float x) {
|
|
38
|
+
if constexpr (bits == 8) {
|
|
39
|
+
return fp8_e4m3(x).bits;
|
|
40
|
+
} else {
|
|
41
|
+
return fp4_e2m1(x).bits;
|
|
42
|
+
}
|
|
43
|
+
}
|
|
44
|
+
};
|
|
45
|
+
|
|
46
|
+
template <int bits>
|
|
47
|
+
struct Dequantize {
|
|
48
|
+
float operator()(uint8_t x) {
|
|
49
|
+
if constexpr (bits == 8) {
|
|
50
|
+
return float(*(thread fp8_e4m3*)(&x));
|
|
51
|
+
} else {
|
|
52
|
+
return float(*(thread fp4_e2m1*)(&x));
|
|
53
|
+
}
|
|
54
|
+
}
|
|
55
|
+
};
|
|
56
|
+
|
|
57
|
+
template <typename U, int N>
|
|
58
|
+
inline void dequantize(
|
|
59
|
+
const device uint8_t* w,
|
|
60
|
+
U scale,
|
|
61
|
+
threadgroup U* w_local,
|
|
62
|
+
const threadgroup U* lut) {
|
|
63
|
+
for (int i = 0; i < (N / 2); i++) {
|
|
64
|
+
w_local[2 * i] = scale * lut[w[i] & 0xf];
|
|
65
|
+
w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf];
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
template <
|
|
70
|
+
typename T,
|
|
71
|
+
short BROWS,
|
|
72
|
+
short BCOLS,
|
|
73
|
+
short dst_ld,
|
|
74
|
+
short reduction_dim,
|
|
75
|
+
short tgp_size,
|
|
76
|
+
short group_size>
|
|
77
|
+
struct QuantizedBlockLoader {
|
|
78
|
+
static_assert(
|
|
79
|
+
BCOLS % group_size == 0,
|
|
80
|
+
"The group size should be divisible by the columns");
|
|
81
|
+
|
|
82
|
+
MLX_MTL_CONST short pack_factor = get_pack_factor<8>();
|
|
83
|
+
MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack();
|
|
84
|
+
MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor;
|
|
85
|
+
MLX_MTL_CONST short n_reads =
|
|
86
|
+
(BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size;
|
|
87
|
+
MLX_MTL_CONST short n_groups = BCOLS / group_size;
|
|
88
|
+
|
|
89
|
+
static_assert(
|
|
90
|
+
(BCOLS_PACKED / n_reads) == n_groups,
|
|
91
|
+
"Other configurations are not yet supported");
|
|
92
|
+
|
|
93
|
+
const int src_ld;
|
|
94
|
+
const int tile_stride;
|
|
95
|
+
const int group_stride;
|
|
96
|
+
|
|
97
|
+
const short thread_idx;
|
|
98
|
+
const short bi;
|
|
99
|
+
const short bj;
|
|
100
|
+
|
|
101
|
+
const short group_id;
|
|
102
|
+
|
|
103
|
+
threadgroup T* dst;
|
|
104
|
+
const device uint8_t* src;
|
|
105
|
+
const device uint8_t* scales;
|
|
106
|
+
threadgroup T* lut;
|
|
107
|
+
|
|
108
|
+
QuantizedBlockLoader(
|
|
109
|
+
const device uint8_t* src_,
|
|
110
|
+
const device uint8_t* scales_,
|
|
111
|
+
const int src_ld_,
|
|
112
|
+
threadgroup T* dst_,
|
|
113
|
+
threadgroup T* lut_,
|
|
114
|
+
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
115
|
+
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
|
116
|
+
: src_ld(src_ld_),
|
|
117
|
+
tile_stride(
|
|
118
|
+
reduction_dim ? BCOLS_PACKED * bytes_per_pack
|
|
119
|
+
: BROWS * src_ld * bytes_per_pack / pack_factor),
|
|
120
|
+
group_stride(BROWS * src_ld / group_size),
|
|
121
|
+
thread_idx(simd_group_id * 32 + simd_lane_id),
|
|
122
|
+
bi(n_reads * thread_idx / BCOLS_PACKED),
|
|
123
|
+
bj((n_reads * thread_idx) % BCOLS_PACKED),
|
|
124
|
+
group_id((bj * pack_factor) / group_size),
|
|
125
|
+
dst(dst_ + bi * dst_ld + bj * pack_factor),
|
|
126
|
+
src(src_ + bi * src_ld * bytes_per_pack / pack_factor +
|
|
127
|
+
bj * bytes_per_pack),
|
|
128
|
+
scales(scales_ + bi * src_ld / group_size + group_id),
|
|
129
|
+
lut(lut_) {
|
|
130
|
+
if (simd_group_id == 0 && simd_lane_id < 16) {
|
|
131
|
+
lut[simd_lane_id] = static_cast<T>(FP4_LUT[simd_lane_id]);
|
|
132
|
+
}
|
|
133
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
134
|
+
}
|
|
135
|
+
|
|
136
|
+
void load_unsafe() const {
|
|
137
|
+
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
|
138
|
+
return;
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
T scale = dequantize_scale<T>(*scales);
|
|
142
|
+
for (int i = 0; i < n_reads; i++) {
|
|
143
|
+
dequantize<T, pack_factor>(
|
|
144
|
+
src + i * bytes_per_pack, scale, dst + i * pack_factor, lut);
|
|
145
|
+
}
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
void load_safe(short2 src_tile_dim) const {
|
|
149
|
+
if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) {
|
|
150
|
+
return;
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
if (reduction_dim == 1 && bi >= src_tile_dim.x) {
|
|
154
|
+
for (int i = 0; i < n_reads * pack_factor; i++) {
|
|
155
|
+
dst[i] = T(0);
|
|
156
|
+
}
|
|
157
|
+
return;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
if (reduction_dim == 0 && bi >= src_tile_dim.y) {
|
|
161
|
+
for (int i = 0; i < n_reads * pack_factor; i++) {
|
|
162
|
+
dst[i] = T(0);
|
|
163
|
+
}
|
|
164
|
+
return;
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
T scale = dequantize_scale<T>(*scales);
|
|
168
|
+
for (int i = 0; i < n_reads; i++) {
|
|
169
|
+
dequantize<T, pack_factor>(
|
|
170
|
+
(device uint8_t*)(src + i * bytes_per_pack),
|
|
171
|
+
scale,
|
|
172
|
+
dst + i * pack_factor,
|
|
173
|
+
lut);
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
|
|
177
|
+
void next() {
|
|
178
|
+
src += tile_stride;
|
|
179
|
+
if (reduction_dim == 1) {
|
|
180
|
+
// if (group_steps > 1) {
|
|
181
|
+
// group_step_cnt++;
|
|
182
|
+
// if (group_step_cnt == group_steps) {
|
|
183
|
+
// group_step_cnt = 0;
|
|
184
|
+
// scales++;
|
|
185
|
+
// }
|
|
186
|
+
// } else {
|
|
187
|
+
scales += n_groups;
|
|
188
|
+
// }
|
|
189
|
+
} else {
|
|
190
|
+
scales += n_groups * group_stride;
|
|
191
|
+
}
|
|
192
|
+
}
|
|
193
|
+
};
|
|
194
|
+
|
|
195
|
+
using namespace mlx::steel;
|
|
196
|
+
|
|
197
|
+
template <
|
|
198
|
+
typename T,
|
|
199
|
+
const int group_size,
|
|
200
|
+
const int bits,
|
|
201
|
+
const bool aligned_N,
|
|
202
|
+
const int BM = 64,
|
|
203
|
+
const int BK = 64,
|
|
204
|
+
const int BN = 64,
|
|
205
|
+
const int WM = 2,
|
|
206
|
+
const int WN = 2,
|
|
207
|
+
typename Wtype = bfloat>
|
|
208
|
+
METAL_FUNC void fp_qmm_t_impl(
|
|
209
|
+
const device uint32_t* w,
|
|
210
|
+
const device uint8_t* scales,
|
|
211
|
+
const device T* x,
|
|
212
|
+
device T* y,
|
|
213
|
+
threadgroup Wtype* Ws,
|
|
214
|
+
const constant int& K,
|
|
215
|
+
const constant int& N,
|
|
216
|
+
const constant int& M,
|
|
217
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
218
|
+
uint lid [[thread_index_in_threadgroup]],
|
|
219
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
220
|
+
uint simd_lid [[thread_index_in_simdgroup]],
|
|
221
|
+
threadgroup Wtype* lut) {
|
|
222
|
+
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
|
223
|
+
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
|
224
|
+
|
|
225
|
+
(void)lid;
|
|
226
|
+
|
|
227
|
+
constexpr int pack_factor = get_pack_factor<8>();
|
|
228
|
+
constexpr int bytes_per_pack = get_bytes_per_pack();
|
|
229
|
+
|
|
230
|
+
constexpr int BK_padded = (BK + 16 / sizeof(Wtype));
|
|
231
|
+
|
|
232
|
+
// Instantiate Loader
|
|
233
|
+
using loader_w_t = QuantizedBlockLoader<
|
|
234
|
+
Wtype,
|
|
235
|
+
BN,
|
|
236
|
+
BK,
|
|
237
|
+
BK_padded,
|
|
238
|
+
1,
|
|
239
|
+
WM * WN * SIMD_SIZE,
|
|
240
|
+
group_size>;
|
|
241
|
+
|
|
242
|
+
// Set the block
|
|
243
|
+
const int K_w = K * bytes_per_pack / pack_factor;
|
|
244
|
+
const int K_g = K / group_size;
|
|
245
|
+
const int y_row = tid.y * BM;
|
|
246
|
+
const int y_col = tid.x * BN;
|
|
247
|
+
|
|
248
|
+
auto wl = (const device uint8_t*)w;
|
|
249
|
+
|
|
250
|
+
x += y_row * static_cast<int64_t>(K);
|
|
251
|
+
wl += y_col * K_w;
|
|
252
|
+
scales += y_col * K_g;
|
|
253
|
+
y += y_row * static_cast<int64_t>(N) + y_col;
|
|
254
|
+
|
|
255
|
+
// Make the weight loader
|
|
256
|
+
loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid);
|
|
257
|
+
|
|
258
|
+
constexpr short UM = 16;
|
|
259
|
+
constexpr short UN = 32;
|
|
260
|
+
constexpr short UK = 16;
|
|
261
|
+
constexpr short SM = BM / WM;
|
|
262
|
+
constexpr short SN = BN / WN;
|
|
263
|
+
constexpr short SK = 32;
|
|
264
|
+
|
|
265
|
+
constexpr short TM = SM / UM;
|
|
266
|
+
constexpr short TN = SN / UN;
|
|
267
|
+
constexpr short TK = SK / UK;
|
|
268
|
+
|
|
269
|
+
const short tm = SM * (simd_gid / WN);
|
|
270
|
+
const short tn = SN * (simd_gid % WN);
|
|
271
|
+
|
|
272
|
+
constexpr bool transpose_a = false;
|
|
273
|
+
constexpr bool transpose_b = true;
|
|
274
|
+
|
|
275
|
+
const short sgp_sm = min(SM, short(M - (y_row + tm)));
|
|
276
|
+
const bool is_unaligned_sm = (sgp_sm != SM);
|
|
277
|
+
|
|
278
|
+
const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn)));
|
|
279
|
+
|
|
280
|
+
const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col)));
|
|
281
|
+
const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN);
|
|
282
|
+
|
|
283
|
+
using AccumType = float;
|
|
284
|
+
|
|
285
|
+
using ASubTile = NAXSubTile<T, UM, UK>;
|
|
286
|
+
using BSubTile = NAXSubTile<Wtype, UN, UK>;
|
|
287
|
+
using DSubTile = NAXSubTile<AccumType, UM, UN>;
|
|
288
|
+
|
|
289
|
+
NAXTile<AccumType, TM, TN, DSubTile> Dtile;
|
|
290
|
+
|
|
291
|
+
Dtile.clear();
|
|
292
|
+
|
|
293
|
+
x += tm * K;
|
|
294
|
+
|
|
295
|
+
dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) {
|
|
296
|
+
dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) {
|
|
297
|
+
for (int k = 0; k < K; k += BK) {
|
|
298
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
299
|
+
if constexpr (kAlignedN.value) {
|
|
300
|
+
loader_w.load_unsafe();
|
|
301
|
+
} else {
|
|
302
|
+
loader_w.load_safe(short2(BK, tgp_bn));
|
|
303
|
+
}
|
|
304
|
+
|
|
305
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
306
|
+
|
|
307
|
+
STEEL_PRAGMA_NO_UNROLL
|
|
308
|
+
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
|
|
309
|
+
NAXTile<T, TM, TK, ASubTile> Atile;
|
|
310
|
+
NAXTile<Wtype, TN, TK, BSubTile> Btile;
|
|
311
|
+
|
|
312
|
+
volatile int compiler_barrier;
|
|
313
|
+
|
|
314
|
+
if constexpr (kAlignedM.value) {
|
|
315
|
+
Atile.load(x + kk1, K);
|
|
316
|
+
} else {
|
|
317
|
+
Atile.load_safe(x + kk1, K, short2(SK, sgp_sm));
|
|
318
|
+
}
|
|
319
|
+
|
|
320
|
+
Btile.template load<Wtype, BK_padded, 1>(Ws + tn * BK_padded + kk1);
|
|
321
|
+
|
|
322
|
+
tile_matmad_nax(
|
|
323
|
+
Dtile,
|
|
324
|
+
Atile,
|
|
325
|
+
metal::bool_constant<transpose_a>{},
|
|
326
|
+
Btile,
|
|
327
|
+
metal::bool_constant<transpose_b>{});
|
|
328
|
+
|
|
329
|
+
(void)compiler_barrier;
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
x += BK;
|
|
333
|
+
loader_w.next();
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
// Store results to device memory
|
|
337
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
338
|
+
|
|
339
|
+
if constexpr (kAlignedM.value && kAlignedN.value) {
|
|
340
|
+
Dtile.store(y + tm * N + tn, N);
|
|
341
|
+
} else if (kAlignedM.value && sgp_sn == SN) {
|
|
342
|
+
Dtile.store(y + tm * N + tn, N);
|
|
343
|
+
} else {
|
|
344
|
+
Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm));
|
|
345
|
+
}
|
|
346
|
+
});
|
|
347
|
+
});
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
template <
|
|
351
|
+
typename T,
|
|
352
|
+
const int group_size,
|
|
353
|
+
const int bits,
|
|
354
|
+
const int BM = 64,
|
|
355
|
+
const int BK = 64,
|
|
356
|
+
const int BN = 64,
|
|
357
|
+
const int WM = 2,
|
|
358
|
+
const int WN = 2,
|
|
359
|
+
typename Wtype = bfloat>
|
|
360
|
+
METAL_FUNC void fp_qmm_n_impl(
|
|
361
|
+
const device uint32_t* w,
|
|
362
|
+
const device uint8_t* scales,
|
|
363
|
+
const device T* x,
|
|
364
|
+
device T* y,
|
|
365
|
+
threadgroup T* Ws,
|
|
366
|
+
const constant int& K,
|
|
367
|
+
const constant int& N,
|
|
368
|
+
const constant int& M,
|
|
369
|
+
uint3 tid [[threadgroup_position_in_grid]],
|
|
370
|
+
uint lid [[thread_index_in_threadgroup]],
|
|
371
|
+
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
|
372
|
+
uint simd_lid [[thread_index_in_simdgroup]],
|
|
373
|
+
threadgroup Wtype* lut) {
|
|
374
|
+
static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE");
|
|
375
|
+
static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE");
|
|
376
|
+
|
|
377
|
+
(void)lid;
|
|
378
|
+
(void)M;
|
|
379
|
+
|
|
380
|
+
constexpr int pack_factor = get_pack_factor<8>();
|
|
381
|
+
constexpr int bytes_per_pack = get_bytes_per_pack();
|
|
382
|
+
|
|
383
|
+
constexpr int BN_padded = (BN + 16 / sizeof(T));
|
|
384
|
+
|
|
385
|
+
using loader_w_t = QuantizedBlockLoader<
|
|
386
|
+
T,
|
|
387
|
+
BK,
|
|
388
|
+
BN,
|
|
389
|
+
BN_padded,
|
|
390
|
+
0,
|
|
391
|
+
WM * WN * SIMD_SIZE,
|
|
392
|
+
group_size>;
|
|
393
|
+
|
|
394
|
+
// Set the block
|
|
395
|
+
const int K_w = K * bytes_per_pack / pack_factor;
|
|
396
|
+
const int K_g = K / group_size;
|
|
397
|
+
const int y_row = tid.y * BM;
|
|
398
|
+
const int y_col = tid.x * BN;
|
|
399
|
+
|
|
400
|
+
auto wl = (const device uint8_t*)w;
|
|
401
|
+
|
|
402
|
+
x += y_row * static_cast<int64_t>(K);
|
|
403
|
+
wl += y_col * K_w;
|
|
404
|
+
scales += y_col * K_g;
|
|
405
|
+
y += y_row * static_cast<int64_t>(N) + y_col;
|
|
406
|
+
|
|
407
|
+
// Make the x loader and mma operation
|
|
408
|
+
// const short num_els = min(BM, M - y_row);
|
|
409
|
+
// const short num_outs = min(BN, N - y_col);
|
|
410
|
+
loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid);
|
|
411
|
+
|
|
412
|
+
constexpr short UM = 16;
|
|
413
|
+
constexpr short UN = 32;
|
|
414
|
+
constexpr short UK = 16;
|
|
415
|
+
constexpr short SM = BM / WM;
|
|
416
|
+
constexpr short SN = BN / WN;
|
|
417
|
+
constexpr short SK = 32;
|
|
418
|
+
|
|
419
|
+
constexpr short TM = SM / UM;
|
|
420
|
+
constexpr short TN = SN / UN;
|
|
421
|
+
constexpr short TK = SK / UK;
|
|
422
|
+
|
|
423
|
+
const short tm = SM * (simd_gid / WN);
|
|
424
|
+
const short tn = SN * (simd_gid % WN);
|
|
425
|
+
|
|
426
|
+
const short ldb_tgp = BN_padded;
|
|
427
|
+
|
|
428
|
+
constexpr bool transpose_a = false;
|
|
429
|
+
constexpr bool transpose_b = false;
|
|
430
|
+
|
|
431
|
+
using AccumType = float;
|
|
432
|
+
|
|
433
|
+
using ASubTile = NAXSubTile<T, UM, UK>;
|
|
434
|
+
using BSubTile = NAXSubTile<T, UK, UN>;
|
|
435
|
+
using DSubTile = NAXSubTile<AccumType, UM, UN>;
|
|
436
|
+
|
|
437
|
+
NAXTile<AccumType, TM, TN, DSubTile> Dtile;
|
|
438
|
+
|
|
439
|
+
Dtile.clear();
|
|
440
|
+
|
|
441
|
+
x += tm * K;
|
|
442
|
+
|
|
443
|
+
for (int k = 0; k < K; k += BK) {
|
|
444
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
445
|
+
loader_w.load_unsafe();
|
|
446
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
447
|
+
|
|
448
|
+
STEEL_PRAGMA_NO_UNROLL
|
|
449
|
+
for (int kk1 = 0; kk1 < BK; kk1 += SK) {
|
|
450
|
+
NAXTile<T, TM, TK, ASubTile> Atile;
|
|
451
|
+
NAXTile<Wtype, TK, TN, BSubTile> Btile;
|
|
452
|
+
|
|
453
|
+
volatile int compiler_barrier;
|
|
454
|
+
|
|
455
|
+
Atile.load(x + kk1, K);
|
|
456
|
+
Btile.template load<T, BN_padded, 1>(Ws + tn + kk1 * ldb_tgp);
|
|
457
|
+
|
|
458
|
+
tile_matmad_nax(
|
|
459
|
+
Dtile,
|
|
460
|
+
Atile,
|
|
461
|
+
metal::bool_constant<transpose_a>{},
|
|
462
|
+
Btile,
|
|
463
|
+
metal::bool_constant<transpose_b>{});
|
|
464
|
+
|
|
465
|
+
(void)compiler_barrier;
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
x += BK;
|
|
469
|
+
loader_w.next();
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
// Store results to device memory
|
|
473
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
474
|
+
|
|
475
|
+
Dtile.store(y + tm * N + tn, N);
|
|
476
|
+
}
|
|
477
|
+
|
|
478
|
+
template <typename T, typename S>
|
|
479
|
+
METAL_FUNC void adjust_matrix_offsets(
|
|
480
|
+
const device T*& x,
|
|
481
|
+
const device uint32_t*& w,
|
|
482
|
+
const device S*& scales,
|
|
483
|
+
device T*& y,
|
|
484
|
+
int output_stride,
|
|
485
|
+
const constant int& x_batch_ndims,
|
|
486
|
+
const constant int* x_shape,
|
|
487
|
+
const constant int64_t* x_strides,
|
|
488
|
+
const constant int& w_batch_ndims,
|
|
489
|
+
const constant int* w_shape,
|
|
490
|
+
const constant int64_t* w_strides,
|
|
491
|
+
const constant int64_t* s_strides,
|
|
492
|
+
uint3 tid [[threadgroup_position_in_grid]]) {
|
|
493
|
+
// Set the input/output matrices
|
|
494
|
+
uint32_t x_idx = tid.z;
|
|
495
|
+
uint32_t w_idx = tid.z;
|
|
496
|
+
if (x_batch_ndims == 1) {
|
|
497
|
+
x += x_idx * x_strides[0];
|
|
498
|
+
} else {
|
|
499
|
+
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
|
|
500
|
+
}
|
|
501
|
+
if (w_batch_ndims == 1) {
|
|
502
|
+
w += w_idx * w_strides[0];
|
|
503
|
+
scales += w_idx * s_strides[0];
|
|
504
|
+
} else {
|
|
505
|
+
ulong2 idx = elem_to_loc_broadcast(
|
|
506
|
+
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
|
|
507
|
+
w += idx.x;
|
|
508
|
+
scales += idx.y;
|
|
509
|
+
}
|
|
510
|
+
y += tid.z * output_stride;
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
template <typename T, typename S>
|
|
514
|
+
METAL_FUNC void adjust_matrix_offsets(
|
|
515
|
+
const device T*& x,
|
|
516
|
+
const device uint32_t*& w,
|
|
517
|
+
const device S*& scales,
|
|
518
|
+
const device uint32_t* lhs_indices,
|
|
519
|
+
const device uint32_t* rhs_indices,
|
|
520
|
+
device T*& y,
|
|
521
|
+
int output_stride,
|
|
522
|
+
const constant int& batch_ndims,
|
|
523
|
+
const constant int* batch_shape,
|
|
524
|
+
const constant int64_t* lhs_strides,
|
|
525
|
+
const constant int64_t* rhs_strides,
|
|
526
|
+
const constant int& x_batch_ndims,
|
|
527
|
+
const constant int* x_shape,
|
|
528
|
+
const constant int64_t* x_strides,
|
|
529
|
+
const constant int& w_batch_ndims,
|
|
530
|
+
const constant int* w_shape,
|
|
531
|
+
const constant int64_t* w_strides,
|
|
532
|
+
const constant int64_t* s_strides,
|
|
533
|
+
uint3 tid [[threadgroup_position_in_grid]]) {
|
|
534
|
+
// Set the input/output matrices
|
|
535
|
+
uint32_t x_idx;
|
|
536
|
+
uint32_t w_idx;
|
|
537
|
+
if (batch_ndims == 1) {
|
|
538
|
+
x_idx = lhs_indices[tid.z * lhs_strides[0]];
|
|
539
|
+
w_idx = rhs_indices[tid.z * rhs_strides[0]];
|
|
540
|
+
} else {
|
|
541
|
+
ulong2 idx = elem_to_loc_broadcast(
|
|
542
|
+
tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
|
|
543
|
+
x_idx = lhs_indices[idx.x];
|
|
544
|
+
w_idx = rhs_indices[idx.y];
|
|
545
|
+
}
|
|
546
|
+
if (x_batch_ndims == 1) {
|
|
547
|
+
x += x_idx * x_strides[0];
|
|
548
|
+
} else {
|
|
549
|
+
x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
|
|
550
|
+
}
|
|
551
|
+
if (w_batch_ndims == 1) {
|
|
552
|
+
w += w_idx * w_strides[0];
|
|
553
|
+
scales += w_idx * s_strides[0];
|
|
554
|
+
} else {
|
|
555
|
+
ulong2 idx = elem_to_loc_broadcast(
|
|
556
|
+
w_idx, w_shape, w_strides, s_strides, w_batch_ndims);
|
|
557
|
+
w += idx.x;
|
|
558
|
+
scales += idx.y;
|
|
559
|
+
}
|
|
560
|
+
y += tid.z * output_stride;
|
|
561
|
+
}
|
|
562
|
+
|
|
563
|
+
/**
 * Kernel entry point that forwards to fp_qmm_t_impl.
 *
 * Allocates the threadgroup scratch (a BN x padded-BK weight tile and a
 * 16-entry lookup table), optionally shifts x / w / scales / y to the
 * matrices for this threadgroup when `batched` is set, and then delegates
 * all of the work to fp_qmm_t_impl.
 */
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const bool batched,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
[[kernel]] void fp_qmm_t_nax(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  // Row stride of the weight tile: BK plus 16 bytes worth of Wtype elements.
  constexpr int ws_row_stride = BK + 16 / sizeof(Wtype);

  threadgroup Wtype lut[16];
  threadgroup Wtype Ws[BN * ws_row_stride];

  if (batched) {
    // Each output matrix holds M * N elements.
    adjust_matrix_offsets(
        x,
        w,
        scales,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        tid);
  }

  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN, WM, WN, Wtype>(
      w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
|
|
620
|
+
|
|
621
|
+
/**
 * Kernel entry point that forwards to fp_qmm_n_impl.
 *
 * Allocates threadgroup scratch for an x tile (BM x padded-BK), a weight
 * tile (BK x padded-BN) and a 16-entry lookup table, all of element type T.
 * When `batched` is set the operand pointers are first moved to the matrices
 * for this threadgroup; the computation itself is done by fp_qmm_n_impl.
 */
template <
    typename T,
    const int group_size,
    const int bits,
    const bool batched,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
[[kernel]] void fp_qmm_n_nax(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    device T* y,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  // Row strides of the two tiles: the block size plus 16 bytes worth of T.
  constexpr int xs_row_stride = BK + 16 / sizeof(T);
  constexpr int ws_row_stride = BN + 16 / sizeof(T);

  threadgroup T lut[16];
  threadgroup T Xs[BM * xs_row_stride];
  threadgroup T Ws[BK * ws_row_stride];

  if (batched) {
    // Each output matrix holds M * N elements.
    adjust_matrix_offsets(
        x,
        w,
        scales,
        y,
        M * N,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        tid);
  }

  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN, WM, WN, Wtype>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
|
|
680
|
+
|
|
681
|
+
/**
 * Gather kernel entry point that forwards to fp_qmm_t_impl.
 *
 * The x and w matrices used by each threadgroup are chosen through
 * lhs_indices / rhs_indices: adjust_matrix_offsets is always called to move
 * x, w, scales and y accordingly, after which fp_qmm_t_impl does the work
 * using a BN x padded-BK threadgroup weight tile and a 16-entry lookup table.
 */
template <
    typename T,
    const int group_size,
    const int bits,
    const bool aligned_N,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
[[kernel]] void fp_gather_qmm_t_nax(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T* y,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant int64_t* lhs_strides,
    const constant int64_t* rhs_strides,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  // Row stride of the weight tile: BK plus 16 bytes worth of Wtype elements.
  constexpr int ws_row_stride = BK + 16 / sizeof(Wtype);

  threadgroup Wtype lut[16];
  threadgroup Wtype Ws[BN * ws_row_stride];

  // Select the gathered operands; each output matrix holds M * N elements.
  adjust_matrix_offsets(
      x,
      w,
      scales,
      lhs_indices,
      rhs_indices,
      y,
      M * N,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      tid);

  fp_qmm_t_impl<T, group_size, bits, aligned_N, BM, BK, BN, WM, WN, Wtype>(
      w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
|
|
747
|
+
|
|
748
|
+
/**
 * Gather kernel entry point that forwards to fp_qmm_n_impl.
 *
 * adjust_matrix_offsets is always called so that x, w, scales and y point at
 * the matrices picked via lhs_indices / rhs_indices for this threadgroup.
 * fp_qmm_n_impl then does the work using threadgroup tiles for x
 * (BM x padded-BK) and w (BK x padded-BN) plus a 16-entry lookup table, all
 * of element type T.
 */
template <
    typename T,
    const int group_size,
    const int bits,
    const int BM = 64,
    const int BK = 64,
    const int BN = 64,
    const int WM = 2,
    const int WN = 2,
    typename Wtype = bfloat>
[[kernel]] void fp_gather_qmm_n_nax(
    const device uint32_t* w,
    const device uint8_t* scales,
    const device T* x,
    const device uint32_t* lhs_indices,
    const device uint32_t* rhs_indices,
    device T* y,
    const constant int& K,
    const constant int& N,
    const constant int& M,
    const constant int& x_batch_ndims,
    const constant int* x_shape,
    const constant int64_t* x_strides,
    const constant int& w_batch_ndims,
    const constant int* w_shape,
    const constant int64_t* w_strides,
    const constant int64_t* s_strides,
    const constant int& batch_ndims,
    const constant int* batch_shape,
    const constant int64_t* lhs_strides,
    const constant int64_t* rhs_strides,
    uint3 tid [[threadgroup_position_in_grid]],
    uint lid [[thread_index_in_threadgroup]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  // Row strides of the two tiles: the block size plus 16 bytes worth of T.
  constexpr int xs_row_stride = BK + 16 / sizeof(T);
  constexpr int ws_row_stride = BN + 16 / sizeof(T);

  threadgroup T lut[16];
  threadgroup T Xs[BM * xs_row_stride];
  threadgroup T Ws[BK * ws_row_stride];

  // Select the gathered operands; each output matrix holds M * N elements.
  adjust_matrix_offsets(
      x,
      w,
      scales,
      lhs_indices,
      rhs_indices,
      y,
      M * N,
      batch_ndims,
      batch_shape,
      lhs_strides,
      rhs_strides,
      x_batch_ndims,
      x_shape,
      x_strides,
      w_batch_ndims,
      w_shape,
      w_strides,
      s_strides,
      tid);

  fp_qmm_n_impl<T, group_size, bits, BM, BK, BN, WM, WN, Wtype>(
      w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut);
}
|
|
815
|
+
|
|
816
|
+
/**
 * Quantized matmul where every row of x selects its own weight matrix.
 *
 * Each threadgroup produces a BM x BN tile of y starting at row tid.y * BM
 * and column tid.x * BN. `indices[row]` names the weight matrix (and scales)
 * to use for that row of x. The rows of the tile are scanned for runs of
 * equal index; for every run the whole tile is multiplied against that run's
 * weights, and only the rows belonging to the run are written to y.
 *
 * `transpose` selects the layout of the weight tile (BN x BK when true,
 * BK x BN otherwise).
 *
 * NOTE(review): align_M, align_N and align_K are not declared in this block;
 * presumably they are function constants defined elsewhere in the file.
 */
template <
    typename T,
    int group_size,
    const int bits,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose,
    typename Wtype = bfloat>
[[kernel]] void fp_gather_qmm_rhs_nax(
    const device T* x,
    const device uint32_t* w,
    const device uint8_t* scales,
    const device uint32_t* indices,
    device T* y,
    const constant int& M,
    const constant int& N,
    const constant int& K,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]]) {
  constexpr int pack_factor = get_pack_factor<8>();
  constexpr int bytes_per_pack = get_bytes_per_pack();
  // Threadgroup tile row strides: block size plus 16 bytes worth of Wtype.
  constexpr int BK_padded = (BK + 16 / sizeof(Wtype));
  constexpr int BN_padded = (BN + 16 / sizeof(Wtype));

  threadgroup Wtype lut[16];

  using loader_w_t = QuantizedBlockLoader<
      Wtype,
      transpose ? BN : BK,
      transpose ? BK : BN,
      transpose ? BK_padded : BN_padded,
      transpose,
      WM * WN * SIMD_SIZE,
      group_size>;

  threadgroup Wtype Ws[transpose ? BN * BK_padded : BK * BN_padded];

  // Compute the block
  const int K_w = K * bytes_per_pack / pack_factor; // packed bytes per K
  const int K_g = K / group_size; // scale groups per K
  const int N_w = N * bytes_per_pack / pack_factor; // packed bytes per N
  const int N_g = N / group_size; // scale groups per N
  const int K_it = K / BK; // number of full BK steps
  // Per-matrix strides of the packed weights and of the scales. One operand
  // is widened first so the product is formed in 64 bits and cannot overflow
  // int for large weight matrices.
  const size_t stride_w = transpose ? size_t(N) * K_w : size_t(K) * N_w;
  const size_t stride_s = transpose ? size_t(N) * K_g : size_t(K) * N_g;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  const size_t y_row_long = size_t(y_row);
  const size_t y_col_long = size_t(y_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, M - y_row));
  const short tgp_bn = align_N ? BN : short(min(BN, N - y_col));

  // Calculate the final tiles in the case that K is not aligned
  const int k_remain = K - K_it * BK;
  const short2 tile_w =
      transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

  // Move x and output to the correct block
  auto wl = (const device uint8_t*)w;
  x += y_row_long * K;
  y += y_row_long * N + y_col_long;
  wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor;
  scales += transpose ? y_col_long * K_g : y_col / group_size;

  // Sub-tile (U*), per-simdgroup (S*) and tile-count (T*) dimensions.
  constexpr short UM = 16;
  constexpr short UN = 32;
  constexpr short UK = 16;
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;

  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;
  constexpr short TK = SK / UK;

  // Row / column origin of this simdgroup inside the threadgroup tile.
  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  // Rows / columns of this simdgroup's tile that lie inside the matrix.
  const short sgp_sm =
      align_M ? SM : min(SM, short(max(0, (M - (y_row + tm)))));
  const short sgp_sn =
      align_N ? SN : min(SN, short(max(0, (N - (y_col + tn)))));

  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);
  const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN);

  constexpr short BR = transpose ? TN : TK;
  constexpr short BC = transpose ? TK : TN;

  using AccumType = float;

  using ASubTile = NAXSubTile<T, UM, UK>;
  using BSubTile = NAXSubTile<Wtype, transpose ? UN : UK, transpose ? UK : UN>;
  using DSubTile = NAXSubTile<AccumType, UM, UN>;

  // Do as many matmuls as necessary: one per run of equal indices among the
  // tgp_bm rows of this tile. [offset, offset_next) is the current run.
  uint32_t index;
  short offset;
  uint32_t index_next = indices[y_row];
  short offset_next = 0;
  int n = 0;
  while (n < tgp_bm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = tgp_bm;
    // Find where the current run ends (first row with a different index).
    for (; n < tgp_bm; n++) {
      if (indices[y_row + n] != index) {
        offset_next = n;
        index_next = indices[y_row + n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    // Prepare threadgroup mma operation
    NAXTile<AccumType, TM, TN, DSubTile> Dtile;

    Dtile.clear();

    const device T* xn = x + tm * K;

    // Prepare threadgroup loading operations
    thread loader_w_t loader_w(
        wl + index * stride_w,
        scales + index * stride_s,
        transpose ? K : N,
        Ws,
        lut,
        simd_group_id,
        simd_lane_id);

    dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
      dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) {
        // Full BK-wide steps along K.
        for (int k = 0; k < K_it; k++) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          if constexpr (kAlignedN.value) {
            loader_w.load_unsafe();
          } else {
            loader_w.load_safe(
                transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK));
          }

          threadgroup_barrier(mem_flags::mem_threadgroup);

          STEEL_PRAGMA_NO_UNROLL
          for (int kk1 = 0; kk1 < BK; kk1 += SK) {
            NAXTile<T, TM, TK, ASubTile> Atile;
            NAXTile<Wtype, BR, BC, BSubTile> Btile;

            volatile int compiler_barrier;

            if constexpr (kAlignedM.value) {
              Atile.load(xn + kk1, K);
            } else {
              Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm));
            }

            if constexpr (transpose) {
              Btile.template load<Wtype, BK_padded, 1>(
                  Ws + tn * BK_padded + kk1);
            } else {
              Btile.template load<Wtype, BN_padded, 1>(
                  Ws + tn + kk1 * BN_padded);
            }

            tile_matmad_nax(
                Dtile,
                Atile,
                metal::bool_constant<false>{},
                Btile,
                metal::bool_constant<transpose>{});

            (void)compiler_barrier;
          }

          xn += BK;
          loader_w.next();
        }

        // Trailing partial step when K is not a multiple of BK.
        if (!align_K) {
          threadgroup_barrier(mem_flags::mem_threadgroup);
          loader_w.load_safe(tile_w);
          threadgroup_barrier(mem_flags::mem_threadgroup);

          STEEL_PRAGMA_NO_UNROLL
          for (int kk1 = 0; kk1 < BK; kk1 += SK) {
            NAXTile<T, TM, TK, ASubTile> Atile;
            NAXTile<Wtype, BR, BC, BSubTile> Btile;

            volatile int compiler_barrier;

            // Only k_remain columns of x are left in this partial step, so
            // bound the load by what remains after kk1. Bounding by BK (as
            // before) always yields SK and reads x past column K.
            // NOTE(review): relies on load_safe zero-filling the part of the
            // tile outside the given extent -- confirm against NAXTile.
            const short psk = min(int(SK), max(0, (k_remain - kk1)));
            Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm));

            if constexpr (transpose) {
              Btile.template load<Wtype, BK_padded, 1>(
                  Ws + tn * BK_padded + kk1);
            } else {
              Btile.template load<Wtype, BN_padded, 1>(
                  Ws + tn + kk1 * BN_padded);
            }

            tile_matmad_nax(
                Dtile,
                Atile,
                metal::bool_constant<false>{},
                Btile,
                metal::bool_constant<transpose>{});

            (void)compiler_barrier;
          }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Rows of this simdgroup's tile that belong to the current run.
        const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm));
        const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm));

        // Store results to device memory
        if constexpr (kAlignedN.value) {
          if (m_lo_lim == 0 && m_hi_lim == SM) {
            Dtile.store(y + tm * N + tn, N);
          } else {
            Dtile.store_slice(
                y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim));
          }
        } else {
          Dtile.store_slice(
              y + tm * N + tn,
              N,
              short2(0, m_lo_lim),
              short2(sgp_sn, m_hi_lim));
        }
      });
    });
  }
}
|