mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h
@@ -0,0 +1,369 @@
// Copyright © 2023-2024 Apple Inc.

// Row reduction utilities
// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup
// - `threadgroup_reduce` collaborative reduction in the threadgroup such that
//   lid.x == 0 holds the reduced value
// - `thread_reduce` simple loop and reduce the row

/**
 * The thread group collaboratively reduces across the rows with bounds
 * checking. In the end each thread holds a part of the reduction.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* inputs[N_WRITES],
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  Op op;

  // Set up the accumulator registers
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = Op::init;
  }

  // Loop over the reduction size within thread group
  for (int i = 0; i < blocks; i++) {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; i < N_READS; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }

      inputs[j] += lsize_x * N_READS;
    }
  }

  // Separate case for the last set as we close the reduction size
  int index = lid_x * N_READS;
  if (index + N_READS <= extra) {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; i < N_READS; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
    }
  } else {
    for (int j = 0; j < N_WRITES; j++) {
      for (int i = 0; index + i < extra; i++) {
        totals[j] = op(static_cast<U>(inputs[j][i]), totals[j]);
      }
    }
  }
}
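// --- Editor's note: the following is an illustrative CPU sketch in plain
// C++, not part of reduce_row.h. It emulates the access pattern of
// per_thread_row_reduce for a hypothetical Sum op with N_WRITES == 1: each of
// lsize_x threads reads N_READS contiguous values, strides forward by
// lsize_x * N_READS per block, and the bounds-checked tail consumes the
// remaining `extra` elements. ---
#include <cstdio>
#include <vector>

int main() {
  const int N_READS = 4, lsize_x = 32;
  const int reduction_size = 1000;  // row length
  std::vector<float> row_data(reduction_size, 1.0f);

  // The same blocks/extra split the kernels compute before calling in:
  // 7 full blocks of lsize_x * N_READS == 128 elements, then 104 stragglers.
  int blocks = reduction_size / (lsize_x * N_READS);
  int extra = reduction_size - blocks * (lsize_x * N_READS);

  float total = 0.0f;  // fold of all per-thread partials
  for (int lid_x = 0; lid_x < lsize_x; lid_x++) {
    float partial = 0.0f;  // this thread's accumulator (totals[0])
    const float* in = row_data.data() + lid_x * N_READS;
    for (int b = 0; b < blocks; b++) {  // full blocks, no bounds checks
      for (int i = 0; i < N_READS; i++) partial += in[i];
      in += lsize_x * N_READS;
    }
    int index = lid_x * N_READS;  // bounds-checked tail, as in the else branch
    for (int i = 0; i < N_READS && index + i < extra; i++) partial += in[i];
    total += partial;
  }
  std::printf("%g\n", total);  // prints 1000: every element visited once
  return 0;
}
// --- end of editor's note ---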

/**
 * Consecutive rows in a contiguous array.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const constant size_t& reduction_size,
    int blocks,
    int extra,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers
  const device T* inputs[N_WRITES];
  inputs[0] = in + lid_x * N_READS;
  for (int i = 1; i < N_WRITES; i++) {
    inputs[i] = inputs[i - 1] + reduction_size;
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}

/**
 * Consecutive rows in an arbitrarily ordered array.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void per_thread_row_reduce(
    thread U totals[N_WRITES],
    const device T* in,
    const int64_t row_idx,
    int blocks,
    int extra,
    const constant int* shape,
    const constant int64_t* strides,
    const constant int& ndim,
    uint lsize_x,
    uint lid_x) {
  // Set up the input pointers
  const device T* inputs[N_WRITES];
  in += lid_x * N_READS;
  for (int i = 0; i < N_WRITES; i++) {
    inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim);
  }

  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, inputs, blocks, extra, lsize_x, lid_x);
}
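// --- Editor's note: illustration, not part of reduce_row.h. For the
// contiguous overload above with N_READS == 4, N_WRITES == 4,
// reduction_size == 100 and lid_x == 2, the pointers land on four
// consecutive rows, each offset by the thread's starting column
// lid_x * N_READS == 8:
//   inputs[0] = in + 8          // row 0, column 8
//   inputs[1] = in + 8 + 100    // row 1, column 8
//   inputs[2] = in + 8 + 200    // row 2, column 8
//   inputs[3] = in + 8 + 300    // row 3, column 8
// The strided overload computes the same four row starts through
// elem_to_loc(row_idx + i, shape, strides, ndim) instead of a fixed
// reduction_size stride. ---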

/**
 * Reduce within the threadgroup.
 */
template <
    typename T,
    typename U,
    typename Op,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
METAL_FUNC void threadgroup_reduce(
    thread U totals[N_WRITES],
    threadgroup U* shared_vals,
    uint3 lid [[thread_position_in_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;

  // Simdgroup first
  for (int i = 0; i < N_WRITES; i++) {
    totals[i] = op.simd_reduce(totals[i]);
  }

  // Across simdgroups
  if (simd_per_group > 1) {
    if (simd_lane_id == 0) {
      for (int i = 0; i < N_WRITES; i++) {
        shared_vals[simd_group_id * N_WRITES + i] = totals[i];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    U values[N_WRITES];
    for (int i = 0; i < N_WRITES; i++) {
      values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i]
                                           : op.init;
    }

    for (int i = 0; i < N_WRITES; i++) {
      totals[i] = op.simd_reduce(values[i]);
    }
  }
}
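// --- Editor's note: illustrative CPU sketch in plain C++, not part of
// reduce_row.h. It emulates the two levels of threadgroup_reduce for a
// hypothetical Sum op with N_WRITES == 1: 256 threads form 8 simdgroups of
// 32 lanes; each simdgroup reduces its lanes, lane 0 publishes the partial
// to threadgroup memory, and one more simd_reduce folds the 8 partials
// (lanes 8..31 contribute the identity, so only lid.x == 0 is meaningful,
// as the header comment above promises). ---
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int threads = 256, simd_size = 32;
  const int simd_per_group = threads / simd_size;  // 8 simdgroups
  std::vector<float> totals(threads, 1.0f);        // per-thread partials

  // Level 1: each simdgroup reduces its 32 lanes (stand-in for simd_reduce),
  // lane 0 writes the result to shared_vals.
  std::vector<float> shared_vals(simd_per_group);
  for (int g = 0; g < simd_per_group; g++) {
    shared_vals[g] = std::accumulate(totals.begin() + g * simd_size,
                                     totals.begin() + (g + 1) * simd_size,
                                     0.0f);
  }

  // Level 2: lane lid.x picks up shared_vals[lid.x] when it exists and the
  // init value otherwise, then the final simd_reduce produces the total.
  float result = 0.0f;
  for (int lane = 0; lane < simd_size; lane++) {
    result += (lane < simd_per_group) ? shared_vals[lane] : 0.0f;
  }
  std::printf("%g\n", result);  // prints 256
  return 0;
}
// --- end of editor's note ---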

template <typename T, typename U, typename Op, int N_READS = REDUCE_N_READS>
METAL_FUNC void
thread_reduce(thread U& total, const device T* row, int blocks, int extra) {
  Op op;
  for (int i = 0; i < blocks; i++) {
    U vals[N_READS];
    for (int j = 0; j < N_READS; j++) {
      vals[j] = row[j];
    }
    for (int j = 0; j < N_READS; j++) {
      total = op(vals[j], total);
    }
    row += N_READS;
  }
  for (int i = 0; i < extra; i++) {
    total = op(*row++, total);
  }
}
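// --- Editor's note: not part of reduce_row.h. thread_reduce stages N_READS
// elements into the `vals` registers before folding them into `total`, which
// keeps the loads independent of the accumulator and presumably lets the
// compiler issue them as one wide read; the trailing loop then handles the
// last extra = row_size % N_READS elements one at a time. For row_size == 10
// and N_READS == 4 that is blocks == 2 full blocks plus extra == 2. ---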

// Reduction kernels
// - `row_reduce_small` depending on the non-row reductions and row size it
//   either just loops over everything or a simdgroup collaboratively reduces
//   the non-row reductions. In the first case one thread is responsible for
//   one output, in the second one simdgroup is responsible for one output.
// - `row_reduce_simple` simple contiguous row reduction
// - `row_reduce_looped` simply loop and reduce each row for each non-row
//   reduction. One threadgroup is responsible for one output.

template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_small(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int64_t& row_size [[buffer(2)]],
    const constant int64_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 tid [[thread_position_in_grid]],
    uint3 tsize [[threads_per_grid]]) {
  Op op;

  U total_val = Op::init;
  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);

  // Precompute some row reduction numbers
  const device T* row;
  int blocks = IdxT(row_size) / N_READS;
  int extra = IdxT(row_size) % N_READS;

  if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) {
    // Simple loop over non_row_reductions and reduce the row in the thread.
    IdxT out_idx = tid.x + tsize.x * IdxT(tid.y);
    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);

    for (uint r = 0; r < non_row_reductions; r++) {
      row = in + loop.location();
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(reduce_shape, reduce_strides);
    }

    out[out_idx] = total_val;
  } else {
    // Collaboratively reduce over non_row_reductions in the simdgroup. Each
    // thread reduces every 32nd row and then a simple simd reduce.
    IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);
    in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim);

    loop.next(simd_lane_id, reduce_shape, reduce_strides);

    for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) {
      row = in + loop.location();
      thread_reduce<T, U, Op, N_READS>(total_val, row, blocks, extra);
      loop.next(simd_size, reduce_shape, reduce_strides);
    }

    total_val = op.simd_reduce(total_val);

    if (simd_lane_id == 0) {
      out[out_idx] = total_val;
    }
  }
}
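// --- Editor's note: illustrative sketch in plain C++, not part of
// reduce_row.h. The branch above selects between one-thread-per-output and
// one-simdgroup-per-output; this host-side mirror of the condition is only
// meant to make the heuristic concrete. ---
#include <cstdint>
#include <cstdio>

// True: each thread loops over all non-row reductions itself.
// False: a simdgroup cooperates, each lane taking every 32nd reduction.
bool uses_thread_loop(int64_t non_row_reductions, int64_t row_size) {
  return (non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8;
}

int main() {
  std::printf("%d\n", uses_thread_loop(4, 1024));  // 1: few rows, just loop
  std::printf("%d\n", uses_thread_loop(64, 4));    // 0: many tiny rows, simd
  return 0;
}
// --- end of editor's note ---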

template <
    typename T,
    typename U,
    typename Op,
    typename IdxT = int64_t,
    int N_READS = REDUCE_N_READS,
    int N_WRITES = REDUCE_N_WRITES>
[[kernel]] void row_reduce_simple(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant size_t& reduction_size [[buffer(2)]],
    const constant int64_t& out_size [[buffer(3)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  threadgroup U shared_vals[simd_size * N_WRITES];
  U totals[N_WRITES];

  // Move to the row
  IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z));
  if (out_idx + N_WRITES > out_size) {
    out_idx = out_size - N_WRITES;
  }
  in += out_idx * IdxT(reduction_size);
  out += out_idx;

  // Each thread reduces across the row
  int blocks = IdxT(reduction_size) / (lsize.x * N_READS);
  int extra = reduction_size - blocks * (lsize.x * N_READS);
  per_thread_row_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, in, reduction_size, blocks, extra, lsize.x, lid.x);

  // Reduce across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, N_WRITES>(
      totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    for (int i = 0; i < N_WRITES; i++) {
      out[i] = totals[i];
    }
  }
}
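// --- Editor's note: not part of reduce_row.h. row_reduce_simple writes
// N_WRITES outputs per threadgroup, and when out_size is not a multiple of
// N_WRITES the clamp out_idx = out_size - N_WRITES slides the last
// threadgroup backwards instead of letting it run past the buffer: with
// out_size == 10 and N_WRITES == 4 the groups cover rows [0,4), [4,8) and
// [6,10), so rows 6 and 7 are recomputed (with identical results) rather
// than read or written out of bounds. ---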

template <
    typename T,
    typename U,
    typename Op,
    typename IdxT,
    int NDIMS,
    int N_READS = REDUCE_N_READS>
[[kernel]] void row_reduce_looped(
    const device T* in [[buffer(0)]],
    device U* out [[buffer(1)]],
    const constant int64_t& row_size [[buffer(2)]],
    const constant int64_t& non_row_reductions [[buffer(3)]],
    const constant int* shape [[buffer(4)]],
    const constant int64_t* strides [[buffer(5)]],
    const constant int& ndim [[buffer(6)]],
    const constant int* reduce_shape [[buffer(7)]],
    const constant int64_t* reduce_strides [[buffer(8)]],
    const constant int& reduce_ndim [[buffer(9)]],
    uint3 gid [[threadgroup_position_in_grid]],
    uint3 gsize [[threadgroups_per_grid]],
    uint3 lid [[thread_position_in_threadgroup]],
    uint3 lsize [[threads_per_threadgroup]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_per_group [[simdgroups_per_threadgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]]) {
  Op op;
  threadgroup U shared_vals[simd_size];
  U total = Op::init;

  IdxT out_idx = gid.y + gsize.y * IdxT(gid.z);

  // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it
  // needs a small refactor.
  in += elem_to_loc<IdxT>(out_idx, shape, strides, ndim) + lid.x * N_READS;

  LoopedElemToLoc<NDIMS, IdxT, (NDIMS > 2)> loop(reduce_ndim);
  const device T* row;
  int blocks = IdxT(row_size) / (lsize.x * N_READS);
  int extra = row_size - blocks * (lsize.x * N_READS);

  for (IdxT i = 0; i < non_row_reductions; i++) {
    row = in + loop.location();

    // Each thread reduces across the row
    U row_total;
    per_thread_row_reduce<T, U, Op, N_READS, 1>(
        &row_total, &row, blocks, extra, lsize.x, lid.x);

    // Aggregate across rows
    total = op(total, row_total);

    loop.next(reduce_shape, reduce_strides);
  }

  // Reduce across the threadgroup
  threadgroup_reduce<T, U, Op, N_READS, 1>(
      &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id);

  // Write the output
  if (lid.x == 0) {
    out[out_idx] = total;
  }
}
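// --- Editor's note: illustrative CPU reference in plain C++, not part of
// reduce_row.h. This is what row_reduce_looped computes once the threadgroup
// machinery is stripped away, sketched for a hypothetical Sum op over a
// contiguous [outputs][non_row_reductions][row_size] input. ---
#include <cstdio>
#include <vector>

int main() {
  const int outputs = 2, non_row_reductions = 3, row_size = 5;
  std::vector<float> in(outputs * non_row_reductions * row_size, 1.0f);
  std::vector<float> out(outputs);

  for (int o = 0; o < outputs; o++) {  // one threadgroup per output
    float total = 0.0f;                // Op::init for Sum
    for (int r = 0; r < non_row_reductions; r++) {
      const float* row = in.data() + (o * non_row_reductions + r) * row_size;
      float row_total = 0.0f;  // per_thread_row_reduce + threadgroup_reduce
      for (int i = 0; i < row_size; i++) {
        row_total += row[i];
      }
      total += row_total;  // total = op(total, row_total)
    }
    out[o] = total;
  }
  std::printf("%g %g\n", out[0], out[1]);  // prints 15 15
  return 0;
}
// --- end of editor's note ---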