mlx-cpu 0.30.1 (mlx_cpu-0.30.1-py3-none-manylinux_2_35_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlx/__main__.py +27 -0
- mlx/_reprlib_fix.py +16 -0
- mlx/extension.py +88 -0
- mlx/include/mlx/3rdparty/pocketfft.h +3581 -0
- mlx/include/mlx/allocator.h +73 -0
- mlx/include/mlx/array.h +645 -0
- mlx/include/mlx/backend/common/binary.h +97 -0
- mlx/include/mlx/backend/common/broadcasting.h +11 -0
- mlx/include/mlx/backend/common/buffer_cache.h +157 -0
- mlx/include/mlx/backend/common/compiled.h +77 -0
- mlx/include/mlx/backend/common/copy.h +50 -0
- mlx/include/mlx/backend/common/hadamard.h +109 -0
- mlx/include/mlx/backend/common/matmul.h +67 -0
- mlx/include/mlx/backend/common/reduce.h +59 -0
- mlx/include/mlx/backend/common/slicing.h +20 -0
- mlx/include/mlx/backend/common/ternary.h +85 -0
- mlx/include/mlx/backend/common/unary.h +29 -0
- mlx/include/mlx/backend/common/utils.h +205 -0
- mlx/include/mlx/backend/cpu/arange.h +28 -0
- mlx/include/mlx/backend/cpu/available.h +9 -0
- mlx/include/mlx/backend/cpu/binary.h +517 -0
- mlx/include/mlx/backend/cpu/binary_ops.h +98 -0
- mlx/include/mlx/backend/cpu/binary_two.h +166 -0
- mlx/include/mlx/backend/cpu/compiled_preamble.h +12 -0
- mlx/include/mlx/backend/cpu/copy.h +36 -0
- mlx/include/mlx/backend/cpu/encoder.h +67 -0
- mlx/include/mlx/backend/cpu/eval.h +12 -0
- mlx/include/mlx/backend/cpu/gemm.h +26 -0
- mlx/include/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
- mlx/include/mlx/backend/cpu/jit_compiler.h +20 -0
- mlx/include/mlx/backend/cpu/lapack.h +80 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
- mlx/include/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
- mlx/include/mlx/backend/cpu/simd/base_simd.h +295 -0
- mlx/include/mlx/backend/cpu/simd/math.h +193 -0
- mlx/include/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
- mlx/include/mlx/backend/cpu/simd/simd.h +4 -0
- mlx/include/mlx/backend/cpu/simd/type.h +11 -0
- mlx/include/mlx/backend/cpu/slicing.h +21 -0
- mlx/include/mlx/backend/cpu/ternary.h +154 -0
- mlx/include/mlx/backend/cpu/threefry.h +21 -0
- mlx/include/mlx/backend/cpu/unary.h +281 -0
- mlx/include/mlx/backend/cpu/unary_ops.h +180 -0
- mlx/include/mlx/backend/cuda/allocator.h +89 -0
- mlx/include/mlx/backend/cuda/conv/conv.h +126 -0
- mlx/include/mlx/backend/cuda/cublas_utils.h +96 -0
- mlx/include/mlx/backend/cuda/cuda.h +10 -0
- mlx/include/mlx/backend/cuda/cuda_utils.h +89 -0
- mlx/include/mlx/backend/cuda/cudnn_utils.h +171 -0
- mlx/include/mlx/backend/cuda/device/config.h +12 -0
- mlx/include/mlx/backend/cuda/device.h +189 -0
- mlx/include/mlx/backend/cuda/event.h +78 -0
- mlx/include/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
- mlx/include/mlx/backend/cuda/gemms/gemv.h +24 -0
- mlx/include/mlx/backend/cuda/jit_module.h +119 -0
- mlx/include/mlx/backend/cuda/lru_cache.h +189 -0
- mlx/include/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
- mlx/include/mlx/backend/cuda/quantized/cuda_fp4.h +83 -0
- mlx/include/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
- mlx/include/mlx/backend/cuda/quantized/quantized.h +45 -0
- mlx/include/mlx/backend/cuda/utils.h +46 -0
- mlx/include/mlx/backend/cuda/worker.h +55 -0
- mlx/include/mlx/backend/gpu/available.h +9 -0
- mlx/include/mlx/backend/gpu/copy.h +57 -0
- mlx/include/mlx/backend/gpu/eval.h +18 -0
- mlx/include/mlx/backend/gpu/slicing.h +36 -0
- mlx/include/mlx/backend/metal/allocator.h +79 -0
- mlx/include/mlx/backend/metal/binary.h +33 -0
- mlx/include/mlx/backend/metal/device.h +283 -0
- mlx/include/mlx/backend/metal/jit/includes.h +57 -0
- mlx/include/mlx/backend/metal/jit/indexing.h +76 -0
- mlx/include/mlx/backend/metal/kernels/arange.h +9 -0
- mlx/include/mlx/backend/metal/kernels/atomic.h +345 -0
- mlx/include/mlx/backend/metal/kernels/bf16.h +16 -0
- mlx/include/mlx/backend/metal/kernels/bf16_math.h +380 -0
- mlx/include/mlx/backend/metal/kernels/binary.h +199 -0
- mlx/include/mlx/backend/metal/kernels/binary_ops.h +326 -0
- mlx/include/mlx/backend/metal/kernels/binary_two.h +244 -0
- mlx/include/mlx/backend/metal/kernels/cexpf.h +134 -0
- mlx/include/mlx/backend/metal/kernels/complex.h +173 -0
- mlx/include/mlx/backend/metal/kernels/copy.h +276 -0
- mlx/include/mlx/backend/metal/kernels/defines.h +24 -0
- mlx/include/mlx/backend/metal/kernels/erf.h +69 -0
- mlx/include/mlx/backend/metal/kernels/expm1f.h +90 -0
- mlx/include/mlx/backend/metal/kernels/fft/radix.h +328 -0
- mlx/include/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
- mlx/include/mlx/backend/metal/kernels/fft.h +486 -0
- mlx/include/mlx/backend/metal/kernels/fp4.h +59 -0
- mlx/include/mlx/backend/metal/kernels/fp8.h +82 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized.h +1804 -0
- mlx/include/mlx/backend/metal/kernels/fp_quantized_nax.h +1059 -0
- mlx/include/mlx/backend/metal/kernels/gemv_masked.h +827 -0
- mlx/include/mlx/backend/metal/kernels/hadamard.h +182 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather.h +51 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
- mlx/include/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
- mlx/include/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
- mlx/include/mlx/backend/metal/kernels/indexing/masked_scatter.h +38 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
- mlx/include/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
- mlx/include/mlx/backend/metal/kernels/logsumexp.h +140 -0
- mlx/include/mlx/backend/metal/kernels/quantized.h +2502 -0
- mlx/include/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
- mlx/include/mlx/backend/metal/kernels/quantized_utils.h +90 -0
- mlx/include/mlx/backend/metal/kernels/reduce.h +5 -0
- mlx/include/mlx/backend/metal/kernels/reduce_utils.h +6 -0
- mlx/include/mlx/backend/metal/kernels/reduction/ops.h +275 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
- mlx/include/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
- mlx/include/mlx/backend/metal/kernels/scan.h +514 -0
- mlx/include/mlx/backend/metal/kernels/sdpa_vector.h +415 -0
- mlx/include/mlx/backend/metal/kernels/softmax.h +190 -0
- mlx/include/mlx/backend/metal/kernels/sort.h +715 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +476 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
- mlx/include/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
- mlx/include/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
- mlx/include/mlx/backend/metal/kernels/steel/defines.h +7 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +156 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +207 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +132 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/params.h +64 -0
- mlx/include/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
- mlx/include/mlx/backend/metal/kernels/steel/utils.h +42 -0
- mlx/include/mlx/backend/metal/kernels/ternary.h +145 -0
- mlx/include/mlx/backend/metal/kernels/ternary_ops.h +10 -0
- mlx/include/mlx/backend/metal/kernels/unary.h +63 -0
- mlx/include/mlx/backend/metal/kernels/unary_ops.h +454 -0
- mlx/include/mlx/backend/metal/kernels/utils.h +444 -0
- mlx/include/mlx/backend/metal/matmul.h +144 -0
- mlx/include/mlx/backend/metal/metal.h +22 -0
- mlx/include/mlx/backend/metal/reduce.h +41 -0
- mlx/include/mlx/backend/metal/resident.h +32 -0
- mlx/include/mlx/backend/metal/scan.h +17 -0
- mlx/include/mlx/backend/metal/ternary.h +21 -0
- mlx/include/mlx/backend/metal/unary.h +21 -0
- mlx/include/mlx/backend/metal/utils.h +84 -0
- mlx/include/mlx/backend/no_gpu/apple_memory.h +16 -0
- mlx/include/mlx/backend/no_gpu/linux_memory.h +22 -0
- mlx/include/mlx/compile.h +44 -0
- mlx/include/mlx/compile_impl.h +69 -0
- mlx/include/mlx/device.h +31 -0
- mlx/include/mlx/distributed/distributed.h +60 -0
- mlx/include/mlx/distributed/distributed_impl.h +59 -0
- mlx/include/mlx/distributed/jaccl/jaccl.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi.h +12 -0
- mlx/include/mlx/distributed/mpi/mpi_declarations.h +28 -0
- mlx/include/mlx/distributed/nccl/nccl.h +12 -0
- mlx/include/mlx/distributed/ops.h +56 -0
- mlx/include/mlx/distributed/primitives.h +156 -0
- mlx/include/mlx/distributed/reduction_ops.h +38 -0
- mlx/include/mlx/distributed/ring/ring.h +12 -0
- mlx/include/mlx/distributed/utils.h +67 -0
- mlx/include/mlx/dtype.h +115 -0
- mlx/include/mlx/dtype_utils.h +119 -0
- mlx/include/mlx/einsum.h +22 -0
- mlx/include/mlx/event.h +58 -0
- mlx/include/mlx/export.h +136 -0
- mlx/include/mlx/export_impl.h +98 -0
- mlx/include/mlx/fast.h +102 -0
- mlx/include/mlx/fast_primitives.h +427 -0
- mlx/include/mlx/fence.h +39 -0
- mlx/include/mlx/fft.h +167 -0
- mlx/include/mlx/graph_utils.h +66 -0
- mlx/include/mlx/io/gguf.h +20 -0
- mlx/include/mlx/io/load.h +175 -0
- mlx/include/mlx/io.h +61 -0
- mlx/include/mlx/linalg.h +111 -0
- mlx/include/mlx/memory.h +78 -0
- mlx/include/mlx/mlx.h +25 -0
- mlx/include/mlx/ops.h +1627 -0
- mlx/include/mlx/primitives.h +2524 -0
- mlx/include/mlx/random.h +282 -0
- mlx/include/mlx/scheduler.h +188 -0
- mlx/include/mlx/small_vector.h +540 -0
- mlx/include/mlx/stream.h +41 -0
- mlx/include/mlx/threadpool.h +133 -0
- mlx/include/mlx/transforms.h +229 -0
- mlx/include/mlx/transforms_impl.h +86 -0
- mlx/include/mlx/types/bf16.h +187 -0
- mlx/include/mlx/types/complex.h +113 -0
- mlx/include/mlx/types/fp16.h +234 -0
- mlx/include/mlx/types/half_types.h +58 -0
- mlx/include/mlx/types/limits.h +70 -0
- mlx/include/mlx/utils.h +175 -0
- mlx/include/mlx/version.h +20 -0
- mlx/lib/libmlx.so +0 -0
- mlx/py.typed +1 -0
- mlx/share/cmake/MLX/FindNCCL.cmake +54 -0
- mlx/share/cmake/MLX/Findnvpl.cmake +3 -0
- mlx/share/cmake/MLX/MLXConfig.cmake +66 -0
- mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- mlx/share/cmake/MLX/MLXTargets.cmake +106 -0
- mlx/share/cmake/MLX/extension.cmake +50 -0
- mlx/utils.py +325 -0
- mlx_cpu-0.30.1.dist-info/METADATA +142 -0
- mlx_cpu-0.30.1.dist-info/RECORD +231 -0
- mlx_cpu-0.30.1.dist-info/WHEEL +5 -0
- mlx_cpu-0.30.1.dist-info/licenses/LICENSE +21 -0
- mlx_cpu-0.30.1.dist-info/sboms/auditwheel.cdx.json +1 -0
- mlx_cpu-0.30.1.dist-info/top_level.txt +1 -0
- mlx_cpu.libs/libblas-bd8a282c.so.3.10.0 +0 -0
- mlx_cpu.libs/libgfortran-3ec47101.so.5.0.0 +0 -0
- mlx_cpu.libs/liblapack-86b2c207.so.3.10.0 +0 -0
- mlx_cpu.libs/libquadmath-67d31475.so.0.0.0 +0 -0
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h
@@ -0,0 +1,459 @@
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

constant bool has_batch [[function_constant(10)]];
constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* rhs_indices [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Find the block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = rhs_indices[c_row];
  short offset_next = 0;
  int n = 0;
  while (n < tgp_bm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = tgp_bm;
    for (; n < tgp_bm; n++) {
      if (rhs_indices[c_row + n] != index) {
        offset_next = n;
        index_next = rhs_indices[c_row + n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    // Prepare threadgroup mma operation
    thread mma_t mma_op(simd_group_id, simd_lane_id);

    // Prepare threadgroup loading operations
    thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
    thread loader_b_t loader_b(
        B + index * params->batch_stride_b,
        params->ldb,
        Bs,
        simd_group_id,
        simd_lane_id);

    // Prepare iterations
    const int gemm_k_iterations = params->gemm_k_iterations_aligned;

    // Do unaligned K iterations first
    if (!align_K) {
      const int k_last = params->gemm_k_iterations_aligned * BK;
      const int k_remain = params->K - k_last;
      const size_t k_jump_a =
          transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
      const size_t k_jump_b =
          transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

      // Move loader source ahead to end
      loader_a.src += k_jump_a;
      loader_b.src += k_jump_b;

      // Load tile
      const short2 tile_dims_A =
          transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
      const short2 tile_dims_B =
          transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

      loader_a.load_safe(tile_dims_A);
      loader_b.load_safe(tile_dims_B);

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Do matmul
      mma_op.mma(As, Bs);

      // Reset source back to start
      loader_a.src -= k_jump_a;
      loader_b.src -= k_jump_b;
    }

    // Matrix level aligned never check
    if (align_M && align_N) {
      for (int k = 0; k < gemm_k_iterations; k++) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Load elements into threadgroup
        loader_a.load_unsafe();
        loader_b.load_unsafe();

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Multiply and accumulate threadgroup elements
        mma_op.mma(As, Bs);

        // Prepare for next iteration
        loader_a.next();
        loader_b.next();
      }

      // Store results to device memory
      if (offset_next - offset == BM) {
        mma_op.store_result(C, params->ldd);
      } else {
        mma_op.store_result_slice(
            C, params->ldd, short2(0, offset), short2(BN, offset_next));
      }
    } else {
      const short lbk = 0;

      // Tile aligned don't check
      if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
        gemm_kernel::gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            lbk,
            LoopAlignment<true, true, true>{});
        if (offset_next - offset == BM) {
          mma_op.store_result(C, params->ldd);
        } else {
          mma_op.store_result_slice(
              C, params->ldd, short2(0, offset), short2(BN, offset_next));
        }
      }

      // Tile partially aligned check rows
      else if (align_N || tgp_bn == BN) {
        gemm_kernel::gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            lbk,
            LoopAlignment<false, true, true>{});
        mma_op.store_result_slice(
            C, params->ldd, short2(0, offset), short2(BN, offset_next));
      }

      // Tile partially aligned check cols
      else if (align_M || tgp_bm == BM) {
        gemm_kernel::gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            lbk,
            LoopAlignment<true, false, true>{});
        mma_op.store_result_slice(
            C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
      }

      // Nothing aligned so check both rows and cols
      else {
        gemm_kernel::gemm_loop(
            As,
            Bs,
            gemm_k_iterations,
            loader_a,
            loader_b,
            mma_op,
            tgp_bm,
            tgp_bn,
            lbk,
            LoopAlignment<false, false, true>{});
        mma_op.store_result_slice(
            C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next));
      }
    }
  }
}

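// gather_mm_rhs above relies on rhs_indices being sorted: a BM-row output
// tile then touches only a few distinct B matrices, and the run-length scan
// over rhs_indices issues one tiled matmul per run of equal indices, storing
// each run with store_result / store_result_slice. gather_mm below is the
// general form: every tid.z gathers one operand pair through
// (lhs_indices, rhs_indices) and performs a full tiled GEMM on it.
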
template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* lhs_indices [[buffer(2)]],
    const device uint32_t* rhs_indices [[buffer(3)]],
    device T* C [[buffer(4)]],
    const constant GEMMParams* params [[buffer(5)]],
    const constant int* indices_shape [[buffer(6)]],
    const constant int64_t* lhs_strides [[buffer(7)]],
    const constant int64_t* rhs_strides [[buffer(8)]],
    const constant int& batch_ndim_a [[buffer(9)]],
    const constant int* batch_shape_a [[buffer(10)]],
    const constant int64_t* batch_strides_a [[buffer(11)]],
    const constant int& batch_ndim_b [[buffer(12)]],
    const constant int* batch_shape_b [[buffer(13)]],
    const constant int64_t* batch_strides_b [[buffer(14)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  using gemm_kernel = GEMMKernel<
      T,
      T,
      BM,
      BN,
      BK,
      WM,
      WN,
      transpose_a,
      transpose_b,
      true,
      true,
      AccumType>;

  using loader_a_t = typename gemm_kernel::loader_a_t;
  using loader_b_t = typename gemm_kernel::loader_b_t;
  using mma_t = typename gemm_kernel::mma_t;

  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Move A and B to the locations pointed by lhs_indices and rhs_indices.
  uint32_t indx_A, indx_B;
  if (has_batch) {
    ulong2 indices_offsets = elem_to_loc_broadcast(
        tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim);
    indx_A = lhs_indices[indices_offsets.x];
    indx_B = rhs_indices[indices_offsets.y];
  } else {
    indx_A = lhs_indices[params->batch_stride_a * tid.z];
    indx_B = rhs_indices[params->batch_stride_b * tid.z];
  }
  A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a);
  B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b);
  C += params->batch_stride_d * tid.z;

  // Prepare threadgroup memory
  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Just make sure everybody's finished with the indexing math above.
  threadgroup_barrier(mem_flags::mem_none);

  // Find block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;

  // Prepare threadgroup mma operation
  thread mma_t mma_op(simd_group_id, simd_lane_id);

  // Prepare threadgroup loading operations
  thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
  thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);

  // Prepare threadgroup bounds
  const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row));
  const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col));

  // Prepare iterations
  int gemm_k_iterations = params->gemm_k_iterations_aligned;

  // Do unaligned K iterations first
  if (!align_K) {
    const int k_last = params->gemm_k_iterations_aligned * BK;
    const int k_remain = params->K - k_last;
    const size_t k_jump_a =
        transpose_a ? params->lda * size_t(k_last) : size_t(k_last);
    const size_t k_jump_b =
        transpose_b ? size_t(k_last) : params->ldb * size_t(k_last);

    // Move loader source ahead to end
    loader_a.src += k_jump_a;
    loader_b.src += k_jump_b;

    // Load tile
    const short2 tile_dims_A =
        transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm);
    const short2 tile_dims_B =
        transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain);

    loader_a.load_safe(tile_dims_A);
    loader_b.load_safe(tile_dims_B);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Do matmul
    mma_op.mma(As, Bs);

    // Reset source back to start
    loader_a.src -= k_jump_a;
    loader_b.src -= k_jump_b;
  }

  // Matrix level aligned never check
  if (align_M && align_N) {
    for (int k = 0; k < gemm_k_iterations; k++) {
      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Load elements into threadgroup
      loader_a.load_unsafe();
      loader_b.load_unsafe();

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // Multiply and accumulate threadgroup elements
      mma_op.mma(As, Bs);

      // Prepare for next iteration
      loader_a.next();
      loader_b.next();
    }

    // Store results to device memory
    mma_op.store_result(C, params->ldd);
  } else {
    const short lbk = 0;

    // Tile aligned don't check
    if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          lbk,
          LoopAlignment<true, true, true>{});
      mma_op.store_result(C, params->ldd);
    }

    // Tile partially aligned check rows
    else if (align_N || tgp_bn == BN) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          lbk,
          LoopAlignment<false, true, true>{});
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }

    // Tile partially aligned check cols
    else if (align_M || tgp_bm == BM) {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          lbk,
          LoopAlignment<true, false, true>{});
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }

    // Nothing aligned so check both rows and cols
    else {
      gemm_kernel::gemm_loop(
          As,
          Bs,
          gemm_k_iterations,
          loader_a,
          loader_b,
          mma_op,
          tgp_bm,
          tgp_bn,
          lbk,
          LoopAlignment<false, false, true>{});
      mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm));
    }
  }
}
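Together these kernels implement a gathered matmul: gather_mm_rhs computes C[i] = A[i] @ B[rhs_indices[i]] for a stack of B matrices, while gather_mm also gathers the left operand per batch element. The scalar reference below is an illustrative sketch of the gather_mm_rhs semantics only; the row-major (M, K) / (E, K, N) layout, the function name, and the use of float are assumptions made for the example, not something shipped in this wheel.

// Reference only: scalar semantics of gather_mm_rhs under assumed layouts.
// A is (M, K), B is (E, K, N), C is (M, N); row-major, non-transposed.
#include <cstdint>
#include <vector>

void gather_mm_rhs_reference(
    const std::vector<float>& A,
    const std::vector<float>& B,
    const std::vector<uint32_t>& rhs_indices, // length M, assumed sorted
    std::vector<float>& C,
    int M,
    int K,
    int N) {
  for (int i = 0; i < M; i++) {
    // Select the B matrix for this row; the kernel amortizes this pick
    // over whole runs of equal indices instead of doing it per row.
    const float* b = B.data() + size_t(rhs_indices[i]) * K * N;
    for (int j = 0; j < N; j++) {
      float acc = 0.0f;
      for (int k = 0; k < K; k++) {
        acc += A[size_t(i) * K + k] * b[size_t(k) * N + j];
      }
      C[size_t(i) * N + j] = acc;
    }
  }
}

In the kernel, params->batch_stride_b plays the role of the size_t(K) * N jump between consecutive B matrices, and the run-length grouping in the while loop amortizes one B selection over every row of a run.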
mlx/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h
@@ -0,0 +1,132 @@
// Copyright © 2024 Apple Inc.

using namespace mlx::steel;

constant bool align_M [[function_constant(200)]];
constant bool align_N [[function_constant(201)]];
constant bool align_K [[function_constant(202)]];

template <
    typename T,
    int BM,
    int BN,
    int BK,
    int WM,
    int WN,
    bool transpose_a,
    bool transpose_b,
    typename AccumType = float>
[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void
gather_mm_rhs_nax(
    const device T* A [[buffer(0)]],
    const device T* B [[buffer(1)]],
    const device uint32_t* rhs_indices [[buffer(2)]],
    device T* C [[buffer(3)]],
    const constant GEMMParams* params [[buffer(4)]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]]) {
  constexpr short UM = 16;
  constexpr short UN = 32;
  constexpr short UK = 16;
  constexpr short SM = BM / WM;
  constexpr short SN = BN / WN;
  constexpr short SK = 32;
  constexpr short TM = SM / UM;
  constexpr short TN = SN / UN;

  if (params->tiles_n <= static_cast<int>(tid.x) ||
      params->tiles_m <= static_cast<int>(tid.y)) {
    return;
  }

  // Find the block in A, B, C
  const int c_row = tid.y * BM;
  const int c_col = tid.x * BN;
  const size_t c_row_long = size_t(c_row);
  const size_t c_col_long = size_t(c_col);

  A += transpose_a ? c_row_long : c_row_long * params->lda;
  B += transpose_b ? c_col_long * params->ldb : c_col_long;
  C += c_row_long * params->ldd + c_col_long;
  rhs_indices += c_row;

  const short tm = SM * (simd_group_id / WN);
  const short tn = SN * (simd_group_id % WN);

  const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm)));
  const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM);

  const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn)));
  const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN);

  A += transpose_a ? tm : (tm * params->lda);
  B += transpose_b ? (tn * params->ldb) : tn;
  C += tm * params->ldd + tn;
  rhs_indices += tm;

  // Do as many matmuls as necessary
  uint32_t index;
  short offset;
  uint32_t index_next = rhs_indices[0];
  short offset_next = 0;
  int n = 0;
  while (n < sgp_sm) {
    n++;
    offset = offset_next;
    index = index_next;
    offset_next = sgp_sm;
    for (; n < sgp_sm; n++) {
      if (rhs_indices[n] != index) {
        offset_next = n;
        index_next = rhs_indices[n];
        break;
      }
    }
    threadgroup_barrier(mem_flags::mem_none);

    using DSubTile = NAXSubTile<AccumType, UM, UN>;
    NAXTile<AccumType, TM, TN, DSubTile> Ctile;

    dispatch_bool(align_K, [&](auto kAlignedK) {
      dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) {
        dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) {
          auto do_gemm = gemm_loop<
              T,
              SM,
              SN,
              SK,
              BK,
              transpose_a,
              transpose_b,
              kAlignedM.value,
              kAlignedN.value,
              kAlignedK.value,
              UM,
              UN,
              UK,
              AccumType>;
          Ctile = do_gemm(
              A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn);

          if constexpr (kAlignedN.value) {
            if (offset_next - offset == SM) {
              Ctile.store(C, int(params->ldd));
            } else {
              Ctile.store_slice(
                  C,
                  int(params->ldd),
                  short2(0, offset),
                  short2(SN, offset_next));
            }
          } else {
            Ctile.store_slice(
                C,
                int(params->ldd),
                short2(0, offset),
                short2(sgp_sn, offset_next));
          }
        });
      });
    });
  }
}
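The nax variant above replaces the hand-written alignment branches of steel_gemm_gather.h with nested dispatch_bool calls that lift runtime bools into compile-time constants, so the generic lambda can prune the bounds-checked paths with if constexpr (hence kAlignedK.value and friends). Below is a minimal sketch of that dispatch idiom, assuming only what the call sites show; the real MLX helper may differ in name or signature.

// Sketch of a bool-to-constexpr dispatcher in the style of the
// dispatch_bool calls above (hypothetical stand-in, not the MLX helper).
#include <cstdio>
#include <type_traits>
#include <utility>

template <typename F>
void dispatch_bool(bool b, F&& f) {
  // Re-invoke the generic lambda with a tag type whose ::value is a
  // compile-time constant; only one branch is instantiated per call site.
  if (b) {
    std::forward<F>(f)(std::integral_constant<bool, true>{});
  } else {
    std::forward<F>(f)(std::integral_constant<bool, false>{});
  }
}

int main() {
  for (bool aligned : {true, false}) {
    dispatch_bool(aligned, [](auto kAligned) {
      if constexpr (kAligned.value) {
        std::printf("fast path: bounds checks compiled out\n");
      } else {
        std::printf("safe path: bounds checks kept\n");
      }
    });
  }
  return 0;
}

The three nested dispatches instantiate up to eight gemm_loop variants per kernel, which is why the kernel folds align_M and align_N together with the per-simdgroup is_unaligned_sm / is_unaligned_sn flags before dispatching.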