mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
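The diff body below reproduces two of the added files in full. Judging by the hunk sizes, the +415-line hunk corresponds to mlx/include/mlx/backend/metal/kernels/sdpa_vector.h and the +190-line hunk to mlx/include/mlx/backend/metal/kernels/softmax.h from the list above.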
@@ -0,0 +1,415 @@
// Copyright © 2024 Apple Inc.

#include <metal_simdgroup>

using namespace metal;

constant bool has_mask [[function_constant(20)]];
constant bool query_transposed [[function_constant(21)]];
constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];
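These [[function_constant(...)]] declarations let the host specialize each kernel variant when the compute pipeline is built, so the mask/causal/sink branches compile away. Below is a minimal host-side sketch of binding such constants with metal-cpp; the setup and the instantiated kernel name ("sdpa_vector_float16_t_128") are illustrative assumptions, not MLX's actual loading code.

```cpp
#include <Metal/Metal.hpp>

// Hypothetical helper: specialize an sdpa_vector variant via the function
// constants at indices 20-25 declared in the header above.
MTL::Function* make_sdpa_function(MTL::Library* library) {
  bool has_mask = false, query_transposed = false, do_causal = true;
  bool bool_mask = false, float_mask = false, has_sinks = false;

  auto* consts = MTL::FunctionConstantValues::alloc()->init();
  consts->setConstantValue(&has_mask, MTL::DataTypeBool, 20);
  consts->setConstantValue(&query_transposed, MTL::DataTypeBool, 21);
  consts->setConstantValue(&do_causal, MTL::DataTypeBool, 22);
  consts->setConstantValue(&bool_mask, MTL::DataTypeBool, 23);
  consts->setConstantValue(&float_mask, MTL::DataTypeBool, 24);
  consts->setConstantValue(&has_sinks, MTL::DataTypeBool, 25);

  NS::Error* error = nullptr;
  // The real instantiation names come from MLX's kernel macros; this one
  // is made up for the example.
  MTL::Function* fn = library->newFunction(
      NS::String::string("sdpa_vector_float16_t_128", NS::UTF8StringEncoding),
      consts,
      &error);
  consts->release();
  return fn;
}
```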
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int& gqa_factor [[buffer(4)]],
    const constant int& N [[buffer(5)]],
    const constant size_t& k_head_stride [[buffer(6)]],
    const constant size_t& k_seq_stride [[buffer(7)]],
    const constant size_t& v_head_stride [[buffer(8)]],
    const constant size_t& v_seq_stride [[buffer(9)]],
    const constant float& scale [[buffer(10)]],
    const device bool* bmask [[buffer(11), function_constant(bool_mask)]],
    const device T* fmask [[buffer(12), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride
        [[buffer(13), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride
        [[buffer(14), function_constant(has_mask)]],
    const constant int& mask_head_stride
        [[buffer(15), function_constant(has_mask)]],
    const device T* sinks [[buffer(16), function_constant(has_sinks)]],
    const constant int& num_q_heads
        [[buffer(17), function_constant(has_sinks)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);

  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int q_batch_head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int kv_head_idx = q_batch_head_idx / gqa_factor;
  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
  const int q_offset =
      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride +
      simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride +
      simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += q_batch_head_idx * mask_head_stride +
        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += q_batch_head_idx * mask_head_stride +
        simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
  }

  out += o_offset * V + simd_gid * v_per_thread;

  // Read the query and 0 the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = Limits<U>::finite_min;
  U sum_exp_score = 0;
  if (has_sinks && simd_gid == 0) {
    max_score = static_cast<U>(sinks[q_batch_head_idx % num_q_heads]);
    sum_exp_score = 1;
  }

  // For each key
  for (int i = simd_gid; i < N; i += BN) {
    bool use_key = true;
    if (do_causal) {
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    } else if (float_mask) {
      use_key = (fmask[0] >= Limits<T>::finite_min);
    }
    if (use_key) {
      // Read the key
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);
      if (float_mask) {
        score += static_cast<U>(fmask[0]);
      }

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += inner_k_stride;
    values += inner_v_stride;
    if (bool_mask) {
      bmask += BN * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * mask_kv_seq_stride;
    }
  }

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = max_scores[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor);

  // Now we need to aggregate all the outputs
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < v_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
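The accumulator update inside the key loop above is the streaming ("online") softmax recurrence: each simdgroup keeps a running max m, normalizer l, and unnormalized output o, and rescales them whenever a new score s arrives for a value row v:

    m' = \max(m, s), \qquad f = e^{m - m'}
    l' = f\,l + e^{s - m'}
    o' = f\,o + e^{s - m'}\,v

At the end, o / l equals the softmax-weighted sum of the visited values regardless of the order in which keys were processed, which is what the cross-simdgroup combine at the bottom of the kernel exploits.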
template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector_2pass_1(
    const device T* queries [[buffer(0)]],
    const device T* keys [[buffer(1)]],
    const device T* values [[buffer(2)]],
    device float* out [[buffer(3)]],
    device float* sums [[buffer(4)]],
    device float* maxs [[buffer(5)]],
    const constant int& gqa_factor [[buffer(6)]],
    const constant int& N [[buffer(7)]],
    const constant size_t& k_head_stride [[buffer(8)]],
    const constant size_t& k_seq_stride [[buffer(9)]],
    const constant size_t& v_head_stride [[buffer(10)]],
    const constant size_t& v_seq_stride [[buffer(11)]],
    const constant float& scale [[buffer(12)]],
    const device bool* bmask [[buffer(13), function_constant(bool_mask)]],
    const device T* fmask [[buffer(14), function_constant(float_mask)]],
    const constant int& mask_kv_seq_stride
        [[buffer(15), function_constant(has_mask)]],
    const constant int& mask_q_seq_stride
        [[buffer(16), function_constant(has_mask)]],
    const constant int& mask_head_stride
        [[buffer(17), function_constant(has_mask)]],
    const device T* sinks [[buffer(18), function_constant(has_sinks)]],
    const constant int& num_q_heads
        [[buffer(19), function_constant(has_sinks)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 8;
  constexpr int BD = 32;
  constexpr int qk_per_thread = D / BD;
  constexpr int v_per_thread = V / BD;
  int inner_k_stride = BN * int(k_seq_stride);
  int inner_v_stride = BN * int(v_seq_stride);
  constexpr int blocks = 32;

  typedef float U;

  thread U q[qk_per_thread];
  thread U k[qk_per_thread];
  thread U o[v_per_thread];

  threadgroup U outputs[BN * BD];
  threadgroup U max_scores[BN];
  threadgroup U sum_exp_scores[BN];

  // Adjust positions
  const int block_idx = tid.z;
  const int q_batch_head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
  const int q_offset =
      query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
  const int kv_head_idx = q_batch_head_idx / gqa_factor;

  queries += q_offset * D + simd_lid * qk_per_thread;
  keys += kv_head_idx * k_head_stride +
      (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
  values += kv_head_idx * v_head_stride +
      (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;
  out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
  if (bool_mask) {
    bmask += q_batch_head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  if (float_mask) {
    fmask += q_batch_head_idx * mask_head_stride +
        (block_idx * BN + simd_gid) * mask_kv_seq_stride +
        q_seq_idx * mask_q_seq_stride;
  }
  sums += o_offset * blocks + block_idx;
  maxs += o_offset * blocks + block_idx;

  // Read the query and 0 the output accumulator
  for (int i = 0; i < qk_per_thread; i++) {
    q[i] = static_cast<U>(scale) * queries[i];
  }
  for (int i = 0; i < v_per_thread; i++) {
    o[i] = 0;
  }

  U max_score = Limits<U>::finite_min;
  U sum_exp_score = 0;
  if (has_sinks && block_idx == 0 && simd_gid == 0) {
    int q_head_idx = q_batch_head_idx % num_q_heads;
    max_score = static_cast<U>(sinks[q_head_idx]);
    sum_exp_score = 1;
  }

  // For each key
  for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
    bool use_key = true;
    if (do_causal) {
      use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
    } else if (bool_mask) {
      use_key = bmask[0];
    } else if (float_mask) {
      use_key = (fmask[0] >= Limits<T>::finite_min);
    }
    if (use_key) {
      // Read the key
      for (int j = 0; j < qk_per_thread; j++) {
        k[j] = keys[j];
      }

      // Compute the i-th score
      U score = 0;
      for (int j = 0; j < qk_per_thread; j++) {
        score += q[j] * k[j];
      }
      score = simd_sum(score);

      if (float_mask) {
        score += fmask[0];
      }

      // Update the accumulators
      U new_max = max(max_score, score);
      U factor = fast::exp(max_score - new_max);
      U exp_score = fast::exp(score - new_max);

      max_score = new_max;
      sum_exp_score = sum_exp_score * factor + exp_score;

      // Update the output accumulator
      for (int j = 0; j < v_per_thread; j++) {
        o[j] = o[j] * factor + exp_score * values[j];
      }
    }

    // Move the pointers to the next kv
    keys += blocks * inner_k_stride;
    values += blocks * inner_v_stride;
    if (bool_mask) {
      bmask += BN * blocks * mask_kv_seq_stride;
    }
    if (float_mask) {
      fmask += BN * blocks * mask_kv_seq_stride;
    }
  }

  // Each thread has a partial part of the output so we need to combine them.

  // First let's communicate the max and sum_exp
  if (simd_lid == 0) {
    max_scores[simd_gid] = max_score;
    sum_exp_scores[simd_gid] = sum_exp_score;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
  sum_exp_score = simd_sum(sum_exp_score * factor);

  // Write the sum and new max
  if (simd_gid == 0) {
    sums[0] = sum_exp_score;
    maxs[0] = new_max;
  }

  // Now we need to aggregate all the outputs
  for (int i = 0; i < v_per_thread; i++) {
    outputs[simd_lid * BN + simd_gid] =
        o[i] * fast::exp(max_scores[simd_gid] - new_max);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // And write the output
    if (simd_gid == 0) {
      U output = outputs[simd_lid * BN];
      for (int j = 1; j < BN; j++) {
        output += outputs[simd_lid * BN + j];
      }
      out[i] = static_cast<T>(output);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }
}
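sdpa_vector_2pass_1 distributes the same computation over tid.z: the key sequence is strided across blocks = 32 threadgroups (each covering BN = 8 rows per step), and every block writes its unnormalized output to out, its partial normalizer to sums, and its running max to maxs. The second pass, below, merges those per-block triples. A scalar C++ reference of that merge for a single output element (a simplified sketch under the same math, not the Metal kernel):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Merge per-block (o_b, l_b, m_b) triples produced by pass 1:
// o_b = unnormalized output, l_b = partial normalizer, m_b = running max.
// Assumes at least one block.
float merge_blocks(const std::vector<float>& o_b,
                   const std::vector<float>& l_b,
                   const std::vector<float>& m_b) {
  float m = *std::max_element(m_b.begin(), m_b.end()); // global max
  float num = 0.f;
  float den = 0.f;
  for (size_t b = 0; b < o_b.size(); ++b) {
    float f = std::exp(m_b[b] - m); // rescale block b to the global max
    num += f * o_b[b];
    den += f * l_b[b];
  }
  return den == 0.f ? num : num / den; // same zero guard as the kernels
}
```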
template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
    const device float* partials [[buffer(0)]],
    const device float* sums [[buffer(1)]],
    const device float* maxs [[buffer(2)]],
    device T* out [[buffer(3)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 tpg [[threadgroups_per_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int BN = 32;
  constexpr int BD = 32;
  constexpr int elem_per_thread = D / BD;
  constexpr int blocks = 32;

  typedef float U;

  thread U o[elem_per_thread];
  threadgroup U outputs[BN * BD];

  // Adjust positions
  const int head_idx = tid.x;
  const int q_seq_idx = tid.y;
  const int q_offset = head_idx * tpg.y + q_seq_idx;
  partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
  sums += q_offset * blocks;
  maxs += q_offset * blocks;
  out += q_offset * D + simd_gid * elem_per_thread;

  // First everybody reads the max and sum_exp
  U max_score = maxs[simd_lid];
  U new_max = simd_max(max_score);
  U factor = fast::exp(max_score - new_max);
  U sum_exp_score = simd_sum(sums[simd_lid] * factor);

  // Now read the block into registers and then use shared memory to transpose
  // it
  for (int i = 0; i < elem_per_thread; i++) {
    o[i] = partials[i];
  }
  for (int i = 0; i < elem_per_thread; i++) {
    outputs[simd_lid * BD + simd_gid] = o[i];
    threadgroup_barrier(mem_flags::mem_threadgroup);
    o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
    o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // And write the output
  if (simd_lid == 0) {
    for (int i = 0; i < elem_per_thread; i++) {
      out[i] = static_cast<T>(o[i]);
    }
  }
}
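The next hunk, +190 lines, corresponds to mlx/include/mlx/backend/metal/kernels/softmax.h from the file list above: two row-softmax kernels sharing the softmax_exp helper.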
@@ -0,0 +1,190 @@
// Copyright © 2023-2024 Apple Inc.

template <typename T>
inline T softmax_exp(T x) {
  // Softmax doesn't need a high-precision exponential because x is in
  // (-oo, 0] and the result is subsequently divided by sum(exp(x_i)).
  return fast::exp(x);
}
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_single_row(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  int lid = _lid;

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  AccT ld[N_READS];

  in += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      ld[i] = AccT(in[i]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      ld[i] =
          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
    }
  }
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<AccT>::min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Get the max
  AccT maxval = Limits<AccT>::finite_min;
  for (int i = 0; i < N_READS; i++) {
    maxval = (maxval < ld[i]) ? ld[i] : maxval;
  }
  maxval = simd_max(maxval);
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    maxval = simd_max(local_max[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_max[0] = maxval;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = local_max[0];

  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
  AccT normalizer = 0;
  for (int i = 0; i < N_READS; i++) {
    AccT exp_x = softmax_exp(ld[i] - maxval);
    ld[i] = exp_x;
    normalizer += exp_x;
  }
  normalizer = simd_sum(normalizer);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    normalizer = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_normalizer[0] = normalizer;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = 1 / local_normalizer[0];

  // Normalize and write to the output
  out += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      out[i] = T(ld[i] * normalizer);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      if ((lid * N_READS + i) < axis_size) {
        out[i] = T(ld[i] * normalizer);
      }
    }
  }
}
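softmax_single_row keeps the entire row in registers (N_READS elements per thread) and evaluates the numerically stable form, with the row max and the normalizer each reduced first within a simdgroup via simd_max/simd_sum and then across simdgroups through the local_max/local_normalizer threadgroup arrays:

    \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \qquad m = \max_j x_j

Shifting by m keeps every exponent in (-oo, 0], which is also why the cheap fast::exp in softmax_exp above is sufficient.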
template <typename T, typename AccT = T, int N_READS = SOFTMAX_N_READS>
[[kernel]] void softmax_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * size_t(axis_size);

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  // Get the max and the normalizer in one go
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min;
  AccT normalizer = 0;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    AccT vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = AccT(in[offset + i]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        vals[i] =
            (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
      }
    }
    prevmax = maxval;
    for (int i = 0; i < N_READS; i++) {
      maxval = (maxval < vals[i]) ? vals[i] : maxval;
    }
    normalizer *= softmax_exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer += softmax_exp(vals[i] - maxval);
    }
  }
  // Each thread now holds a partial normalizer over its
  // N_READS * ceildiv(axis_size, N_READS * lsize) elements. We combine them:
  // 1. We start by finding the max across simd groups
  // 2. We then rescale the partial normalizers to account for a possible
  //    change in max
  // 3. We sum all normalizers
  prevmax = maxval;
  maxval = simd_max(maxval);
  normalizer *= softmax_exp(prevmax - maxval);
  normalizer = simd_sum(normalizer);

  // Now the normalizer and max value are correct for each simdgroup. We
  // write them to shared memory and combine them.
  prevmax = maxval;
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = simd_max(local_max[simd_lane_id]);
  normalizer *= softmax_exp(prevmax - maxval);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = simd_sum(local_normalizer[simd_lane_id]);
  normalizer = 1 / normalizer;

  // Finally, given the normalizer and max value, we can directly write the
  // softmax output
  out += gid * size_t(axis_size);
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        if (offset + i < axis_size) {
          out[offset + i] =
              T(softmax_exp(in[offset + i] - maxval) * normalizer);
        }
      }
    }
  }
}
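For rows too long to hold in registers, softmax_looped makes one streaming pass that builds the max and normalizer together, rescaling the partial normalizer whenever the running max grows (the normalizer *= softmax_exp(prevmax - maxval) line), then a second pass to write the output. A scalar C++ sketch of the streaming step, one element at a time instead of N_READS-wide (a simplified reference, not the kernel):

```cpp
#include <cmath>
#include <limits>
#include <vector>

// Streaming computation of sum_i exp(x_i - max(x)) in a single pass,
// mirroring the rescale-then-accumulate loop in softmax_looped.
float streaming_normalizer(const std::vector<float>& x) {
  float maxval = -std::numeric_limits<float>::infinity();
  float normalizer = 0.f;
  for (float v : x) {
    float prevmax = maxval;
    maxval = std::max(maxval, v);
    // Rescale the partial sum: exp(prevmax - maxval) is 1 when the max
    // did not change and shrinks the old sum when it did.
    normalizer *= std::exp(prevmax - maxval);
    normalizer += std::exp(v - maxval);
  }
  return normalizer;
}
```

The rescale relies on the identity \sum_i e^{x_i - m'} = e^{m - m'} \sum_i e^{x_i - m}, so the final (maxval, normalizer) pair is exact no matter how the row is chunked across threads and iterations.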