mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/hadamard.h
@@ -0,0 +1,182 @@
// Copyright © 2024 Apple Inc.
#include <metal_common>
#include <metal_compute>

#include "mlx/backend/metal/kernels/steel/defines.h"

using namespace metal;

// Thread local Hadamard transform for 2^R
template <short R>
METAL_FUNC void radix_func(thread float* x) {
  constexpr short logR = __builtin_ctz(R);
  short h = 1;
  STEEL_PRAGMA_UNROLL
  for (short s = 0; s < logR; s++) {
    STEEL_PRAGMA_UNROLL
    for (short i = 0; i < R / 2; i++) {
      short k = i & (h - 1);
      short j = ((i - k) << 1) + k;
      float a = x[j];
      float b = x[j + h];
      x[j] = a + b;
      x[j + h] = a - b;
    }
    h <<= 1;
  }
}

template <typename T, int N, int max_radix, int read_width, int stride = 1>
[[kernel]] void hadamard_n(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const float& scale,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Compute a Hadamard transform of size N = 2^k
  //
  // Equivalent to:
  //   from scipy.linalg import hadamard
  //   y = hadamard(len(x)) @ x

  constexpr short num_threads = N / max_radix;
  constexpr short logN = __builtin_ctz(N);
  constexpr short logR = __builtin_ctz(max_radix);
  constexpr short num_steps = logN / logR;
  constexpr short logFinal = logN % logR;
  constexpr short final_radix = 1 << (logFinal);

  int batch_idx = elem.y * N * stride + elem.z;
  short i = elem.x;

  threadgroup T buf[N];

  // Read values from device
  if (stride == 1) {
    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < max_radix / read_width; j++) {
      short index = j * read_width * num_threads + i * read_width;
      STEEL_PRAGMA_UNROLL
      for (short r = 0; r < read_width; r++) {
        buf[index + r] = in[batch_idx + index + r];
      }
    }
  } else {
    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < max_radix; j++) {
      buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride];
    }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  float x[max_radix];
  short h = 1;

  STEEL_PRAGMA_UNROLL
  for (short s = 0; s < num_steps; s++) {
    short k = i & (h - 1);
    short j = ((i - k) << logR) + k;

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < max_radix; r++) {
      x[r] = buf[j + h * r];
    }

    radix_func<max_radix>(x);

    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < max_radix; r++) {
      buf[j + h * r] = T(x[r]);
    }

    h <<= logR;
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Do the final radix
  // e.g. max_radix = 16
  // N = 1024 = 16 * 16 * 4
  if (final_radix > 1) {
    // Each thread does multiple butterflies
    STEEL_PRAGMA_UNROLL
    for (int t = 0; t < max_radix / final_radix; t++) {
      short index = i + t * num_threads;
      short k = index & (h - 1);
      short j = ((index - k) << logFinal) + k;
      STEEL_PRAGMA_UNROLL
      for (short r = 0; r < final_radix; r++) {
        x[r] = buf[j + h * r];
      }

      radix_func<final_radix>(x);

      STEEL_PRAGMA_UNROLL
      for (short r = 0; r < final_radix; r++) {
        buf[j + h * r] = T(x[r]);
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }

  // Write values to device
  if (stride == 1) {
    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < max_radix / read_width; j++) {
      short index = j * read_width * num_threads + i * read_width;
      STEEL_PRAGMA_UNROLL
      for (short r = 0; r < read_width; r++) {
        out[batch_idx + index + r] = T(buf[index + r] * scale);
      }
    }
  } else {
    STEEL_PRAGMA_UNROLL
    for (short j = 0; j < max_radix; j++) {
      out[batch_idx + (j * num_threads + i) * stride] =
          buf[j * num_threads + i];
    }
  }
}

template <typename T, int N, int M, int read_width>
[[kernel]] void hadamard_m(
    const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
    constant const float& scale,
    uint3 elem [[thread_position_in_grid]],
    uint3 grid [[threads_per_grid]]) {
  // Compute a Hadamard transform of size M
  // using a naive O(M^2) codelet.
  //
  // This kernel is the second stage in the computation
  // of a Hadamard transform of size M*N where N = 2^k.

  int index = elem.x * grid.y + elem.y;
  short i = index % (N / read_width);
  int batch_idx = index / (N / read_width) * M * N;

  float x[read_width][M];
  STEEL_PRAGMA_UNROLL
  for (short c = 0; c < M; c++) {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < read_width; r++) {
      x[r][c] = in[batch_idx + c * N + i * read_width + r];
    }
  }

  STEEL_PRAGMA_UNROLL
  for (short r = 0; r < read_width; r++) {
    // This function is JIT compiled for M
    // using the Hadamard matrix strings in `metal/hadamard.cpp`
    hadamard_radix_m(x[r]);
  }

  // Write back to device
  STEEL_PRAGMA_UNROLL
  for (short c = 0; c < M; c++) {
    STEEL_PRAGMA_UNROLL
    for (short r = 0; r < read_width; r++) {
      out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale);
    }
  }
}
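The kernel's own comment states that hadamard_n is equivalent to `y = hadamard(len(x)) @ x` with scipy's Hadamard matrix. As a minimal NumPy sketch of that equivalence (illustrative only, not part of the package; the name fwht is made up here), the same a+b / a-b butterflies the kernel unrolls look like this:

import numpy as np

def fwht(x):
    # Unnormalized fast Walsh-Hadamard transform of a length 2^k vector,
    # built from the same radix-2 butterflies as radix_func above.
    x = np.array(x, dtype=np.float64)
    h = 1
    while h < len(x):
        for j in range(0, len(x), 2 * h):
            a = x[j:j + h].copy()
            b = x[j + h:j + 2 * h].copy()
            x[j:j + h] = a + b          # butterfly: sum
            x[j + h:j + 2 * h] = a - b  # butterfly: difference
        h *= 2
    return x

# Check against the explicit Sylvester Hadamard matrix
# (the same matrix scipy.linalg.hadamard(16) would return).
x = np.random.randn(16)
H = np.array([[1.0]])
for _ in range(4):
    H = np.kron(np.array([[1.0, 1.0], [1.0, -1.0]]), H)
assert np.allclose(fwht(x), H @ x)

Each pass applies the 2x2 [[1, 1], [1, -1]] butterfly across one bit of the index, which is why the transform costs O(N log N) instead of the O(N^2) dense matrix multiply.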
mlx/include/mlx/backend/metal/kernels/indexing/gather.h
@@ -0,0 +1,51 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/indexing/indexing.h"

template <typename T, typename IdxT, int NIDX, int IDX_NDIM, typename LocT>
METAL_FUNC void gather_impl(
    const device T* src [[buffer(0)]],
    device T* out [[buffer(1)]],
    const constant int* src_shape [[buffer(2)]],
    const constant int64_t* src_strides [[buffer(3)]],
    const constant size_t& src_ndim [[buffer(4)]],
    const constant int* slice_sizes [[buffer(5)]],
    const constant int* axes [[buffer(6)]],
    const thread Indices<IdxT, NIDX>& indices,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  LocT src_idx = 0;
  for (int i = 0; i < NIDX; ++i) {
    LocT idx_loc;
    if (IDX_NDIM == 0) {
      idx_loc = 0;
    } else if (IDX_NDIM == 1) {
      idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
    } else {
      idx_loc = index.x * static_cast<LocT>(indices.strides[indices.ndim * i]);
      idx_loc += indices.row_contiguous[i]
          ? index.y
          : elem_to_loc<LocT>(
                index.y,
                &indices.shapes[indices.ndim * i + 1],
                &indices.strides[indices.ndim * i + 1],
                indices.ndim - 1);
    }
    auto ax = axes[i];
    auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);
    src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
  }

  auto src_offset =
      elem_to_loc<LocT>(index.z, slice_sizes, src_strides, src_ndim);

  LocT out_idx = index.z;
  if (IDX_NDIM == 1) {
    out_idx += static_cast<LocT>(grid_dim.z) * index.x;
  } else if (IDX_NDIM >= 2) {
    out_idx += grid_dim.z * (index.x * static_cast<LocT>(grid_dim.y) + index.y);
  }
  out[out_idx] = src[src_offset + src_idx];
}
mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h
@@ -0,0 +1,44 @@
// Copyright © 2025 Apple Inc.

#pragma once

template <typename T, typename IdxT, typename LocT, bool SrcC, bool IdxC>
[[kernel]] void gather_axis(
    const device T* src [[buffer(0)]],
    const device IdxT* indices [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant int* shape [[buffer(3)]],
    const constant int64_t* src_strides [[buffer(4)]],
    const constant int64_t* idx_strides [[buffer(5)]],
    const constant size_t& ndim [[buffer(6)]],
    const constant int& axis [[buffer(7)]],
    const constant int& axis_size [[buffer(8)]],
    const constant size_t& src_ax_stride [[buffer(9)]],
    const constant size_t& idx_ax_stride [[buffer(10)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);
  LocT out_idx = elem_idx * grid_dim.y + index.x;

  LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
  if (IdxC) {
    idx_loc += out_idx;
  } else {
    idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
  }

  auto idx_val = indices[idx_loc];
  if (is_signed_v<IdxT>) {
    idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val;
  }

  LocT src_idx = idx_val * static_cast<LocT>(src_ax_stride);
  if (SrcC) {
    src_idx += elem_idx * axis_size + index.x;
  } else {
    src_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, src_strides, ndim);
  }

  out_idx += index.y * static_cast<LocT>(grid_dim.x);
  out[out_idx] = src[src_idx];
}
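For orientation, gather_axis behaves like a take-along-axis where negative indices wrap around the gathered axis (the is_signed_v<IdxT> branch above). A rough NumPy model, with gather_axis_ref as a hypothetical name for illustration, not the package API:

import numpy as np

def gather_axis_ref(src, indices, axis):
    idx = np.asarray(indices)
    if np.issubdtype(idx.dtype, np.signedinteger):
        idx = np.where(idx < 0, idx + src.shape[axis], idx)  # wrap negative indices
    return np.take_along_axis(src, idx, axis=axis)

src = np.arange(12).reshape(3, 4)
indices = np.array([[0, -1], [2, 1], [-4, 3]])
print(gather_axis_ref(src, indices, axis=1))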
mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h
@@ -0,0 +1,24 @@
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/indexing/indexing.h"

template <typename T, typename IdxT, typename LocT, int N>
[[kernel]] void gather_front(
    const device T* src,
    const device IdxT* indices,
    device T* out,
    const constant int64_t& stride,
    const constant int& size,
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto idx = offset_neg_idx(indices[index.y], size);
  LocT src_idx = static_cast<LocT>(stride) * idx;
  LocT out_idx = static_cast<LocT>(stride) * index.y;

  int s_idx = N * index.x;
  for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) {
    out[out_idx + s_idx] = src[src_idx + s_idx];
  }
}
mlx/include/mlx/backend/metal/kernels/indexing/indexing.h
@@ -0,0 +1,23 @@
// Copyright © 2023-2024 Apple Inc.

#pragma once

#include <metal_stdlib>

template <typename IdxT, int NIDX>
struct Indices {
  const array<const device IdxT*, NIDX> buffers;
  const constant int* shapes;
  const constant int64_t* strides;
  const constant bool* row_contiguous;
  const int ndim;
};

template <typename IdxT>
METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) {
  if (is_unsigned_v<IdxT>) {
    return idx;
  } else {
    return (idx < 0) ? idx + size : idx;
  }
}
mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h
@@ -0,0 +1,38 @@
// Copyright © 2025 Apple Inc.

#pragma once

template <typename T, bool src_contiguous>
[[kernel]] void masked_assign_impl(
    const device bool* mask [[buffer(0)]],
    const device uint* scatter_offsets [[buffer(1)]],
    const device T* src [[buffer(2)]],
    device T* out [[buffer(3)]],
    const constant int* src_shapes [[buffer(4)]],
    const constant int64_t* src_strides [[buffer(5)]],
    const constant int& src_ndim [[buffer(6)]],
    const constant int64_t& src_batch_size [[buffer(7)]],
    const constant int64_t& mask_batch_size [[buffer(8)]],
    uint idx [[thread_position_in_grid]]) {
  const bool mask_value = mask[idx];
  if (!mask_value) {
    return;
  }

  const uint src_index = scatter_offsets[idx];
  if (src_index >= src_batch_size) {
    return;
  }

  const uint batch_idx = idx / mask_batch_size;

  if (src_contiguous) {
    out[idx] = src[batch_idx * src_batch_size + src_index];
  } else {
    out[idx] = src[elem_to_loc<uint>(
        batch_idx * src_batch_size + src_index,
        src_shapes,
        src_strides,
        src_ndim)];
  }
}
mlx/include/mlx/backend/metal/kernels/indexing/scatter.h
@@ -0,0 +1,59 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "mlx/backend/metal/kernels/indexing/indexing.h"

template <
    typename T,
    typename IdxT,
    typename Op,
    int NIDX,
    bool UPD_ROW_CONTIG,
    int NWORK,
    typename LocT>
METAL_FUNC void scatter_impl(
    const device T* updates,
    device mlx_atomic<T>* out,
    const constant int* upd_shape,
    const constant int64_t* upd_strides,
    const constant size_t& upd_ndim,
    const constant size_t& upd_size,
    const constant int* out_shape,
    const constant int64_t* out_strides,
    const constant size_t& out_ndim,
    const constant int* axes,
    const constant size_t& idx_size,
    const thread Indices<IdxT, NIDX>& indices,
    uint2 gid [[thread_position_in_grid]]) {
  Op op;

  auto ind_idx = gid.y * NWORK;
  LocT out_offset = 0;
  if (upd_size > 1) {
    out_offset = elem_to_loc<LocT>(
        gid.x, upd_shape + indices.ndim, out_strides, out_ndim);
  }

  for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) {
    LocT out_idx = out_offset;
    for (int i = 0; i < NIDX; ++i) {
      auto idx_loc = indices.row_contiguous[i]
          ? ind_idx
          : elem_to_loc<LocT>(
                ind_idx,
                &indices.shapes[indices.ndim * i],
                &indices.strides[indices.ndim * i],
                indices.ndim);
      auto ax = axes[i];
      auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]);
      out_idx +=
          static_cast<LocT>(idx_val) * static_cast<LocT>(out_strides[ax]);
    }
    auto upd_idx = ind_idx * static_cast<LocT>(upd_size) + gid.x;
    if constexpr (!UPD_ROW_CONTIG) {
      upd_idx = elem_to_loc<LocT>(upd_idx, upd_shape, upd_strides, upd_ndim);
    }
    op.atomic_update(out, updates[upd_idx], out_idx);
  }
}
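scatter_impl resolves each index tuple to an output location and applies the update through op.atomic_update, so duplicate indices accumulate instead of racing. A hedged NumPy analogue of the sum-update case (illustrative only, not the MLX API):

import numpy as np

out = np.zeros(8)
indices = np.array([1, 3, 3, -1])   # negative indices wrap, as in offset_neg_idx
updates = np.array([10.0, 20.0, 30.0, 40.0])

indices = np.where(indices < 0, indices + out.shape[0], indices)
np.add.at(out, indices, updates)    # unordered accumulation, like the atomic adds
print(out)                          # [ 0. 10.  0. 50.  0.  0.  0. 40.]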
mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h
@@ -0,0 +1,52 @@
// Copyright © 2025 Apple Inc.

#pragma once

template <
    typename T,
    typename IdxT,
    typename LocT,
    typename Op,
    bool UpdC,
    bool IdxC>
[[kernel]] void scatter_axis(
    const device T* upd [[buffer(0)]],
    const device IdxT* indices [[buffer(1)]],
    device mlx_atomic<T>* out [[buffer(2)]],
    const constant int* shape [[buffer(3)]],
    const constant int64_t* upd_strides [[buffer(4)]],
    const constant int64_t* idx_strides [[buffer(5)]],
    const constant size_t& ndim [[buffer(6)]],
    const constant int& axis [[buffer(7)]],
    const constant int& out_axis_size [[buffer(8)]],
    const constant size_t& upd_ax_stride [[buffer(9)]],
    const constant size_t& idx_ax_stride [[buffer(10)]],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  Op op;

  LocT elem_idx = index.z * static_cast<LocT>(grid_dim.x);

  LocT idx_loc = index.y * static_cast<LocT>(idx_ax_stride);
  if (IdxC) {
    idx_loc += elem_idx * grid_dim.y + index.x;
  } else {
    idx_loc += elem_to_loc<LocT>(elem_idx + index.x, shape, idx_strides, ndim);
  }

  auto idx_val = indices[idx_loc];
  if (is_signed_v<IdxT>) {
    idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val;
  }

  LocT upd_idx = index.y * static_cast<LocT>(upd_ax_stride);
  if (UpdC) {
    upd_idx += elem_idx * grid_dim.y + index.x;
  } else {
    upd_idx += elem_to_loc<LocT>(elem_idx + index.x, shape, upd_strides, ndim);
  }

  LocT out_idx = elem_idx * static_cast<LocT>(out_axis_size) +
      idx_val * grid_dim.x + index.x;
  op.atomic_update(out, upd[upd_idx], out_idx);
}
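scatter_axis is the along-axis counterpart: each update element carries its own index along one axis of the output, and the write goes through the atomic update so accumulating ops combine duplicates. With a plain assignment op it behaves roughly like np.put_along_axis after wrapping negative indices; a small sketch under those assumptions (not the package API):

import numpy as np

out = np.zeros((2, 4))
indices = np.array([[0, -1], [2, 1]])     # one index per update element
updates = np.array([[1.0, 2.0], [3.0, 4.0]])

idx = np.where(indices < 0, indices + out.shape[1], indices)  # wrap negatives
np.put_along_axis(out, idx, updates, axis=1)
print(out)                                # [[1. 0. 0. 2.] [0. 4. 3. 0.]]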
mlx/include/mlx/backend/metal/kernels/logsumexp.h
@@ -0,0 +1,140 @@
// Copyright © 2025 Apple Inc.

template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint _lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  int lid = _lid;

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  AccT ld[N_READS];

  in += gid * size_t(axis_size) + lid * N_READS;
  if (lid * N_READS + N_READS <= axis_size) {
    for (int i = 0; i < N_READS; i++) {
      ld[i] = AccT(in[i]);
    }
  } else {
    for (int i = 0; i < N_READS; i++) {
      ld[i] =
          ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits<AccT>::min;
    }
  }
  if (simd_group_id == 0) {
    local_max[simd_lane_id] = Limits<AccT>::min;
    local_normalizer[simd_lane_id] = 0;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // Get the max
  AccT maxval = Limits<AccT>::finite_min;
  for (int i = 0; i < N_READS; i++) {
    maxval = (maxval < ld[i]) ? ld[i] : maxval;
  }
  maxval = simd_max(maxval);
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    maxval = simd_max(local_max[simd_lane_id]);
    if (simd_lane_id == 0) {
      local_max[0] = maxval;
    }
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = local_max[0];

  // Compute exp(x_i - maxval) and store the partial sums in local_normalizer
  AccT normalizer = 0;
  for (int i = 0; i < N_READS; i++) {
    normalizer += fast::exp(ld[i] - maxval);
  }
  normalizer = simd_sum(normalizer);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  if (simd_group_id == 0) {
    normalizer = simd_sum(local_normalizer[simd_lane_id]);
    if (simd_lane_id == 0) {
      out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
    }
  }
}

template <typename T, typename AccT = float, int N_READS = 4>
[[kernel]] void logsumexp_looped(
    const device T* in,
    device T* out,
    constant int& axis_size,
    uint gid [[threadgroup_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  in += gid * size_t(axis_size);

  constexpr int SIMD_SIZE = 32;

  threadgroup AccT local_max[SIMD_SIZE];
  threadgroup AccT local_normalizer[SIMD_SIZE];

  // Get the max and the normalizer in one go
  AccT prevmax;
  AccT maxval = Limits<AccT>::finite_min;
  AccT normalizer = 0;
  for (int r = 0; r < static_cast<int>(ceildiv(axis_size, N_READS * lsize));
       r++) {
    int offset = r * lsize * N_READS + lid * N_READS;
    AccT vals[N_READS];
    if (offset + N_READS <= axis_size) {
      for (int i = 0; i < N_READS; i++) {
        vals[i] = AccT(in[offset + i]);
      }
    } else {
      for (int i = 0; i < N_READS; i++) {
        vals[i] =
            (offset + i < axis_size) ? AccT(in[offset + i]) : Limits<AccT>::min;
      }
    }
    prevmax = maxval;
    for (int i = 0; i < N_READS; i++) {
      maxval = (maxval < vals[i]) ? vals[i] : maxval;
    }
    normalizer *= fast::exp(prevmax - maxval);
    for (int i = 0; i < N_READS; i++) {
      normalizer += fast::exp(vals[i] - maxval);
    }
  }
  prevmax = maxval;
  maxval = simd_max(maxval);
  normalizer *= fast::exp(prevmax - maxval);
  normalizer = simd_sum(normalizer);

  prevmax = maxval;
  if (simd_lane_id == 0) {
    local_max[simd_group_id] = maxval;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  maxval = simd_max(local_max[simd_lane_id]);
  normalizer *= fast::exp(prevmax - maxval);
  if (simd_lane_id == 0) {
    local_normalizer[simd_group_id] = normalizer;
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  normalizer = simd_sum(local_normalizer[simd_lane_id]);

  if (lid == 0) {
    out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval);
  }
}
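logsumexp_looped never materializes the whole row: it keeps a running maximum and a running sum of exponentials, rescaling the sum by exp(prevmax - maxval) whenever the maximum grows, and only takes log(normalizer) + maxval at the end. A minimal Python sketch of that streaming recurrence (function name assumed, for illustration only):

import math

def streaming_logsumexp(chunks):
    maxval = -math.inf
    normalizer = 0.0
    for chunk in chunks:
        prevmax = maxval
        maxval = max(maxval, max(chunk))
        normalizer *= math.exp(prevmax - maxval)   # rescale the old partial sum
        normalizer += sum(math.exp(v - maxval) for v in chunk)
    return maxval if math.isinf(maxval) else math.log(normalizer) + maxval

vals = [[0.5, 2.0, -1.0], [3.5, 0.1], [2.2]]
flat = [v for c in vals for v in c]
reference = math.log(sum(math.exp(v) for v in flat))
assert abs(streaming_logsumexp(vals) - reference) < 1e-12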