mlx-cpu 0.30.1__py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
// Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
// Column reduction for reductions small enough to be handled by a single
// threadgroup per output row.
//
// Each thread owns n_reads consecutive columns starting at `column` and folds
// rows lid.y, lid.y + lsize.y, lid.y + 2 * lsize.y, ... of the reduced axes
// into `totals`. When lsize.y > 1 the per-row-lane partials are combined
// through threadgroup memory by the lid.y == 0 threads, which then write the
// n_reads outputs (fewer at the right edge of the column range).
//
// Threads whose first column lies at or past reduction_stride do NOT return
// early. The Metal Shading Language requires every thread of a threadgroup to
// execute a threadgroup_barrier, and an early return would make the barrier
// below non-uniform. Such threads skip all loads and the final store and only
// contribute Op::init to the shared reduction.
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
[[kernel]] void col_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  constexpr int n_reads = 4;
  Op op;
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;

  U totals[n_reads];
  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads;
  // Whether this thread owns at least one valid column. Out-of-range threads
  // must still reach the barrier below, so record this instead of returning.
  bool active = column < reduction_stride;
  // Whether all n_reads columns are valid (no per-element bounds check).
  bool safe = column + n_reads <= reduction_stride;

  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);

  if (active) {
    IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
    in += in_idx + column;

    IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size);
    loop.next(lid.y, reduce_shape, reduce_strides);
    for (IdxT r = lid.y; r < total_rows; r += lsize.y) {
      row = in + loop.location();
      if (safe) {
        for (int i = 0; i < n_reads; i++) {
          totals[i] = op(static_cast<U>(row[i]), totals[i]);
        }
      } else {
        // Right edge of the column range: pad missing columns with the
        // identity so the fold below stays uniform.
        U vals[n_reads];
        for (int i = 0; i < n_reads; i++) {
          vals[i] =
              (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
        }
        for (int i = 0; i < n_reads; i++) {
          totals[i] = op(vals[i], totals[i]);
        }
      }
      loop.next(lsize.y, reduce_shape, reduce_strides);
    }
  }

  // lsize is uniform across the threadgroup, so this branch (and the barrier
  // inside it) is taken by all threads or by none.
  if (lsize.y > 1) {
    // Sized for lsize.x <= 32 and lsize.y <= 8; assumed to be guaranteed by the
    // host-side launch configuration (NOTE(review): confirm).
    threadgroup U shared_vals[32 * 8 * n_reads];
    for (int i = 0; i < n_reads; i++) {
      shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (lid.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = shared_vals[lid.x * n_reads + i];
      }
      for (uint j = 1; j < lsize.y; j++) {
        for (int i = 0; i < n_reads; i++) {
          totals[i] =
              op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i],
                 totals[i]);
        }
      }
    }
  }

  // Only row-lane 0 of threads that own valid columns writes results.
  if (lid.y == 0 && active) {
    out += out_idx * IdxT(reduction_stride) + column;
    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
|
|
95
|
+
|
|
96
|
+
// Column reduction for columns with many rows to fold.
//
// Threadgroup (gid.x, gid.y) selects one output position (out_pos), and each
// thread owns the single column lid.x. The reduced rows are interleaved over
// all (gid.z, lid.y) pairs: this thread visits rows first_row,
// first_row + row_step, ... with row_step = lsize.y * gsize.z. The lsize.y
// per-thread partials are then combined through threadgroup memory. Lane
// lid.y == 0 stores the result for slice gid.z at
// out[gid.z * out_size + out_pos * reduction_stride + lid.x].
//
// NOTE(review): combining the gsize.z slices is presumably done by a
// follow-up reduction on the host side; confirm with the caller. There is no
// bounds check of lid.x against reduction_stride, and the shared buffer holds
// 32 * 32 values. This assumes the launch has lsize.x * lsize.y <= 1024 and
// only valid columns in lsize.x.
template <typename T, typename U, typename Op, typename IdxT, int NDIMS>
[[kernel]] void col_reduce_longcolumn(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant size_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]]) {
  Op op;

  // Start of this thread's column in the input.
  IdxT out_pos = gid.x + gsize.x * IdxT(gid.y);
  IdxT base_offset = elem_to_loc<IdxT>(out_pos, shape, strides, ndim);
  const device T* column_ptr = in + (base_offset + lid.x);

  // Row schedule: rows are dealt round-robin to every (gid.z, lid.y) lane.
  uint first_row = gid.z * lsize.y + lid.y;
  uint row_step = lsize.y * gsize.z;
  IdxT row_count = IdxT(non_col_reductions) * IdxT(reduction_size);

  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> rows(reduce_ndim);
  rows.next(first_row, reduce_shape, reduce_strides);

  U acc = Op::init;
  for (IdxT r = first_row; r < row_count; r += row_step) {
    const device T* elem = column_ptr + rows.location();
    acc = op(static_cast<U>(*elem), acc);
    rows.next(row_step, reduce_shape, reduce_strides);
  }

  // Fold the lsize.y partials of each column; row-lane 0 keeps the result.
  threadgroup U partials[32 * 32];
  partials[lid.y * lsize.x + lid.x] = acc;
  threadgroup_barrier(mem_flags::mem_threadgroup);

  // No barrier follows, so the other row-lanes can leave now.
  if (lid.y != 0) {
    return;
  }
  for (uint k = 1; k < lsize.y; k++) {
    acc = op(acc, partials[k * lsize.x + lid.x]);
  }
  out[gid.z * IdxT(out_size) + out_pos * IdxT(reduction_stride) + lid.x] = acc;
}
|
|
143
|
+
|
|
144
|
+
/**
 * Looped column reduction. Each threadgroup handles BN adjacent columns of one
 * output position and proceeds as follows:
 *   1. Each thread keeps n_reads = (BM * BN) / tgp_size running totals for
 *      consecutive columns (equal to BN / n_simdgroups when BM == 32).
 *   2. The threadgroup loads a BM x BN tile (BM rows, BN columns) into
 *      registers and accumulates it into the running totals.
 *   3. It moves ahead by BM rows until the column axis and the non-column
 *      reductions are exhausted.
 *   4. If BM == 32, the totals are transposed through threadgroup memory and
 *      combined with a simd reduction. Each simdgroup produces
 *      BN / n_simdgroups outputs. Otherwise the totals are written to
 *      threadgroup memory, and the threads of the first tile row
 *      (offset.y == 0) accumulate them with a loop.
 *   5. The results are written to the output.
 */
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int BM,
    int BN>
[[kernel]] void col_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // offset.x: first of this thread's n_reads columns within the BN tile;
  // offset.y: this thread's row within the BM tile.
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  IdxT column = BN * gid.x + offset.x;
  // Whether all n_reads columns are in range (no per-element bounds check).
  bool safe = column + n_reads <= reduction_stride;

  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
  in += in_idx + column;

  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
  loop.next(offset.y, reduce_shape, reduce_strides);
  for (IdxT r = offset.y; r < total; r += BM) {
    row = in + loop.location();

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      // Right edge of the column range: pad missing columns with the identity.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM. Each thread writes
  // its partial output to threadgroup memory, and then each simdgroup does
  // BN / n_simdgroups accumulations.
  if (BM == 32) {
    constexpr int n_outputs = BN / n_simdgroups;
    static_assert(
        BM != 32 || n_outputs == n_reads,
        "The tile should be selected such that n_outputs == n_reads");
    for (int i = 0; i < n_reads; i++) {
      shared_vals[offset.y * BN + offset.x + i] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Lane j of each simdgroup reads tile row j; simd_reduce folds the rows.
    short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
    for (int i = 0; i < n_outputs; i++) {
      totals[i] =
          op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
    }

    // Write the output.
    if (simd_lane_id == 0) {
      IdxT out_column = BN * gid.x + out_offset.x;
      out += out_idx * IdxT(reduction_stride) + out_column;
      if (out_column + n_outputs <= reduction_stride) {
        for (int i = 0; i < n_outputs; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; out_column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }

  // Each thread holds n_reads partial results. We write them all out to
  // threadgroup memory, and the threads with offset.y == 0 aggregate the
  // columns and write the outputs.
  else {
    short x_block = offset.x / n_reads;
    for (int i = 0; i < n_reads; i++) {
      shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (offset.y == 0) {
      for (int i = 0; i < n_reads; i++) {
        for (int j = 1; j < BM; j++) {
          totals[i] =
              op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]);
        }
      }
    }

    // Write the output.
    if (offset.y == 0) {
      out += out_idx * IdxT(reduction_stride) + column;
      if (safe) {
        for (int i = 0; i < n_reads; i++) {
          out[i] = totals[i];
        }
      } else {
        for (int i = 0; column + i < reduction_stride; i++) {
          out[i] = totals[i];
        }
      }
    }
  }
}
|
|
293
|
+
|
|
294
|
+
// First pass of a two-pass column reduction. It uses the same BM x BN tiling
// and simd reduction as col_reduce_looped with BM == 32, but splits the rows of
// each output into outer_blocks (32) interleaved slices.
//
// full_idx encodes both the output position (out_idx = full_idx % out_size)
// and the slice (block_idx = full_idx / out_size). Slice block_idx visits rows
// offset.y + block_idx * BM, advancing by outer_blocks * BM. Each
// (slice, output) pair writes its own partial row at
// out[full_idx * reduction_stride + column].
//
// NOTE(review): the partial rows are presumably combined by a second reduction
// pass launched by the host; confirm against the caller.
template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int BM,
    int BN>
[[kernel]] void col_reduce_2pass(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& reduction_stride [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    const constant size_t& non_col_reductions [[buffer(10)]],
    const constant size_t& out_size [[buffer(11)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  constexpr int n_simdgroups = 8;
  constexpr short tgp_size = n_simdgroups * simd_size;
  constexpr short n_reads = (BM * BN) / tgp_size;
  constexpr short n_read_blocks = BN / n_reads;
  constexpr int n_outputs = BN / n_simdgroups;
  constexpr short outer_blocks = 32;
  static_assert(BM == 32, "BM should be equal to 32");
  // totals[] has n_reads slots but is re-used for n_outputs simd-reduced
  // values below; guard against an out-of-bounds register array access.
  static_assert(
      n_outputs == n_reads,
      "The tile should be selected such that n_outputs == n_reads");

  threadgroup U shared_vals[BN * BM];
  U totals[n_reads];
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;

  for (int i = 0; i < n_reads; i++) {
    totals[i] = Op::init;
  }

  // offset.x: first of this thread's n_reads columns within the BN tile;
  // offset.y: this thread's row within the BM tile.
  short lid = simd_group_id * simd_size + simd_lane_id;
  short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks);
  IdxT column = BN * gid.x + offset.x;
  // Whether all n_reads columns are in range (no per-element bounds check).
  bool safe = column + n_reads <= reduction_stride;

  IdxT full_idx = gid.y + gsize.y * IdxT(gid.z);
  IdxT block_idx = full_idx / IdxT(out_size);
  IdxT out_idx = full_idx % IdxT(out_size);
  IdxT in_idx = elem_to_loc<IdxT>(out_idx, shape, strides, ndim);
  in += in_idx + column;

  IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size);
  loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides);
  for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) {
    row = in + loop.location();

    if (safe) {
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(static_cast<U>(row[i]), totals[i]);
      }
    } else {
      // Right edge of the column range: pad missing columns with the identity.
      U vals[n_reads];
      for (int i = 0; i < n_reads; i++) {
        vals[i] =
            (column + i < reduction_stride) ? static_cast<U>(row[i]) : op.init;
      }
      for (int i = 0; i < n_reads; i++) {
        totals[i] = op(vals[i], totals[i]);
      }
    }

    loop.next(outer_blocks * BM, reduce_shape, reduce_strides);
  }

  // We can use a simd reduction to accumulate across BM. Each thread writes
  // its partial output to threadgroup memory, and then each simdgroup does
  // BN / n_simdgroups accumulations.
  for (int i = 0; i < n_reads; i++) {
    shared_vals[offset.y * BN + offset.x + i] = totals[i];
  }
  threadgroup_barrier(mem_flags::mem_threadgroup);
  // Lane j of each simdgroup reads tile row j; simd_reduce folds the rows.
  short2 out_offset(simd_group_id * n_outputs, simd_lane_id);
  for (int i = 0; i < n_outputs; i++) {
    totals[i] =
        op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]);
  }

  // Write the output.
  if (simd_lane_id == 0) {
    IdxT out_column = BN * gid.x + out_offset.x;
    out += full_idx * IdxT(reduction_stride) + out_column;
    if (out_column + n_outputs <= reduction_stride) {
      for (int i = 0; i < n_outputs; i++) {
        out[i] = totals[i];
      }
    } else {
      for (int i = 0; out_column + i < reduction_stride; i++) {
        out[i] = totals[i];
      }
    }
  }
}
|