numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,3973 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief SIMD-accelerated Batched Dot Products for Sapphire Rapids.
|
|
3
|
+
* @file include/numkong/dots/sapphireamx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date December 27, 2025
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/dots.h
|
|
8
|
+
*
|
|
9
|
+
* This file contains tiled matrix-multiplication kernels optimized for Intel AMX instructions,
|
|
10
|
+
* leveraging the new TMM registers on Intel Sapphire Rapids CPUs. Those are much larger than ZMM:
|
|
11
|
+
*
|
|
12
|
+
* - BF16 tiles: 16 rows × 32 elements = 512 BF16 values = 1KB per tile
|
|
13
|
+
* - INT8 tiles: 16 rows × 64 elements = 1024 INT8 values = 1KB per tile
|
|
14
|
+
*
|
|
15
|
+
* We typically use 4 registers for the 2 × 2 tile output for the matrix C accumulators, leaving
|
|
16
|
+
* 4 other registers for parts of A and B matrices:
|
|
17
|
+
*
|
|
18
|
+
* - TMM0, TMM1: A matrix tiles (row blocks i and i+16)
|
|
19
|
+
* - TMM2, TMM3: B matrix tiles (column blocks j and j+16)
|
|
20
|
+
* - TMM4-7: C accumulator tiles (2 × 2 output grid)
|
|
21
|
+
*
|
|
22
|
+
* In most synthetic benchmarks there seems to be no major difference between aggregating into 1 or 4
|
|
23
|
+
* output tiles, implying the CPU's ability to internally pipeline the accumulation; so using 2 × 2 for
|
|
24
|
+
* outputs is more of a memory-bandwidth-saving measure.
|
|
25
|
+
*
|
|
26
|
+
* Lacking High Bandwidth Memory, the performance in GEMM-like BLAS workloads is dominated by memory
|
|
27
|
+
* bandwidth. Latency hiding is also extremely hard, heavily affecting performance numbers. For reference,
|
|
28
|
+
* Intel MKL SGEMM for FP32 inputs yields around 250 GigaOPS per core on Intel Sapphire Rapids, leveraging
|
|
29
|
+
* AVX-512. At the same time, for AMX:
|
|
30
|
+
*
|
|
31
|
+
* - BF16 peak: ≈ 3 TeraOPS per core in theory, ≈ 500 GigaOPS per core in practice
|
|
32
|
+
* - INT8 peak: ≈ 6 TeraOPS per core in theory, ≈ 1000 GigaOPS per core in practice
|
|
33
|
+
*
|
|
34
|
+
* Several optimizations are used across the file:
|
|
35
|
+
*
|
|
36
|
+
* - Pre-pack B matrix once for repeated inference (avoids runtime reordering)
|
|
37
|
+
* - Morton Z-curve tile ordering improves L2 cache hit rate by 5-25%
|
|
38
|
+
* - Use streaming stores for large C matrices to avoid cache pollution
|
|
39
|
+
*
|
|
40
|
+
* @section amx_instructions Intel AMX Instructions (Sapphire Rapids+)
|
|
41
|
+
*
|
|
42
|
+
* Tile configuration and data movement:
|
|
43
|
+
*
|
|
44
|
+
* Intrinsic Instruction Notes
|
|
45
|
+
* _tile_loadconfig LDTILECFG (mem64) Configure tile palette
|
|
46
|
+
* _tile_loadd TILELOADD (TMM, mem, stride) Load tile from memory
|
|
47
|
+
* _tile_stored TILESTORED (mem, TMM, stride) Store tile to memory
|
|
48
|
+
* _tile_zero TILEZERO (TMM) Zero a tile register
|
|
49
|
+
*
|
|
50
|
+
* BF16 matrix multiply (AMX-BF16):
|
|
51
|
+
*
|
|
52
|
+
* Intrinsic Instruction Operation
|
|
53
|
+
* _tile_dpbf16ps TDPBF16PS (TMM, TMM, TMM) C += A × B (bf16 → f32)
|
|
54
|
+
*
|
|
55
|
+
* INT8 matrix multiply (AMX-INT8):
|
|
56
|
+
*
|
|
57
|
+
* Intrinsic Instruction Operation
|
|
58
|
+
* _tile_dpbssd TDPBSSD (TMM, TMM, TMM) C += A × B (i8 × i8 → i32)
|
|
59
|
+
* _tile_dpbsud TDPBSUD (TMM, TMM, TMM) C += A × B (i8 × u8 → i32)
|
|
60
|
+
* _tile_dpbusd TDPBUSD (TMM, TMM, TMM) C += A × B (u8 × i8 → i32)
|
|
61
|
+
* _tile_dpbuud TDPBUUD (TMM, TMM, TMM) C += A × B (u8 × u8 → u32)
|
|
62
|
+
*
|
|
63
|
+
* AMX performance characteristics:
|
|
64
|
+
* - TDPBF16PS: 16 × 16 × 32 = 8192 BF16 MACs per instruction
|
|
65
|
+
* - TDPBSSD: 16 × 16 × 64 = 16384 INT8 MACs per instruction
|
|
66
|
+
* - Tile load latency is ~20-30 cycles; software pipelining essential
|
|
67
|
+
* - PDEP/PEXT used for Morton Z-curve encoding (BMI2): 2-3cy @ p1
|
|
68
|
+
*/
|
|
69
|
+
#ifndef NK_DOTS_SAPPHIREAMX_H
|
|
70
|
+
#define NK_DOTS_SAPPHIREAMX_H
|
|
71
|
+
|
|
72
|
+
#if NK_TARGET_X86_
|
|
73
|
+
#if NK_TARGET_SAPPHIREAMX
|
|
74
|
+
|
|
75
|
+
#include "numkong/cast/icelake.h" // For FP8 ↔ BF16 conversions
|
|
76
|
+
#include "numkong/dots/serial.h" // For nk_dots_reduce_sumsq_bf16_
|
|
77
|
+
|
|
78
|
+
#if defined(__cplusplus)
|
|
79
|
+
extern "C" {
|
|
80
|
+
#endif
|
|
81
|
+
|
|
82
|
+
#if defined(__clang__)
|
|
83
|
+
#pragma clang attribute push( \
|
|
84
|
+
__attribute__((target( \
|
|
85
|
+
"avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8"))), \
|
|
86
|
+
apply_to = function)
|
|
87
|
+
#elif defined(__GNUC__)
|
|
88
|
+
#pragma GCC push_options
|
|
89
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
|
|
90
|
+
"bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8")
|
|
91
|
+
#endif
|
|
92
|
+
|
|
93
|
+
/* AMX-specific packed buffer header (64-byte aligned).
|
|
94
|
+
* Different from nk_dots_amx_packed_header_t as AMX uses tile-based layout.
|
|
95
|
+
*/
|
|
96
|
+
typedef struct {
|
|
97
|
+
nk_u32_t full_column_tiles; // Number of full column tiles (16 rows each)
|
|
98
|
+
nk_u32_t full_depth_tiles; // Number of depth tiles (32 columns for BF16, 64 for I8)
|
|
99
|
+
nk_u32_t column_remainder_count; // Remaining rows after full tiles (0-15)
|
|
100
|
+
nk_u32_t column_edge_offset; // Byte offset to edge data region
|
|
101
|
+
nk_u32_t norms_byte_offset; // Byte offset to per-column norms (for angular/euclidean)
|
|
102
|
+
nk_u32_t reserved[11]; // Padding to 64 bytes
|
|
103
|
+
} nk_dots_amx_packed_header_t;
|
|
104
|
+
|
|
105
|
+
/* Composable tile structures for AMX operations.
|
|
106
|
+
* These enable reusable primitives and cross-correlation (A × Aᵀ) use cases.
|
|
107
|
+
*/
|
|
108
|
+
|
|
109
|
+
/* BF16 A tile: 16 rows × 32 depth-elements, row-major layout.
|
|
110
|
+
* Loaded from source matrix, used as left operand in AMX multiply.
|
|
111
|
+
*/
|
|
112
|
+
typedef struct {
|
|
113
|
+
NK_ALIGN64 nk_bf16_t data[16][32]; // 16 rows × 32 columns = 1KB
|
|
114
|
+
} nk_dots_bf16_a16x32_sapphireamx_t;
|
|
115
|
+
|
|
116
|
+
/* BF16 B tile: 32 depth × 16 columns, pair-interleaved for TDPBF16PS.
|
|
117
|
+
* Access pattern: data[depth/2][column][depth%2] for logical B[depth, column].
|
|
118
|
+
* Pre-packed from column-major or transposed source.
|
|
119
|
+
*/
|
|
120
|
+
typedef struct {
|
|
121
|
+
NK_ALIGN64 nk_bf16_t data[16][16][2]; // 16 depth-groups × 16 columns × 2 = 1KB
|
|
122
|
+
} nk_dots_bf16_b32x16_sapphireamx_t;
|
|
123
|
+
|
|
124
|
+
/* BF16 output state: 16 × 16 F32 accumulator tile.
|
|
125
|
+
* Holds partial sums during depth-dimension accumulation.
|
|
126
|
+
*/
|
|
127
|
+
typedef struct {
|
|
128
|
+
NK_ALIGN64 nk_f32_t data[16][16]; // 16 × 16 = 1KB
|
|
129
|
+
} nk_dots_bf16_state_sapphireamx_t;
|
|
130
|
+
|
|
131
|
+
/* INT8 A tile: 16 rows × 64 depth-elements, row-major layout.
|
|
132
|
+
*/
|
|
133
|
+
typedef struct {
|
|
134
|
+
NK_ALIGN64 nk_i8_t data[16][64]; // 16 rows × 64 columns = 1KB
|
|
135
|
+
} nk_dots_i8_a16x64_sapphireamx_t;
|
|
136
|
+
|
|
137
|
+
/* INT8 B tile: 64 depth × 16 columns, quad-interleaved for TDPBSSD.
|
|
138
|
+
* Access pattern: data[depth/4][column][depth%4] for logical B[depth, column].
|
|
139
|
+
*/
|
|
140
|
+
typedef struct {
|
|
141
|
+
NK_ALIGN64 nk_i8_t data[16][16][4]; // 16 depth-groups × 16 columns × 4 = 1KB
|
|
142
|
+
} nk_dots_i8_b64x16_sapphireamx_t;
|
|
143
|
+
|
|
144
|
+
/* INT8 output state: 16 × 16 I32 accumulator tile.
|
|
145
|
+
*/
|
|
146
|
+
typedef struct {
|
|
147
|
+
NK_ALIGN64 nk_i32_t data[16][16]; // 16 × 16 = 1KB
|
|
148
|
+
} nk_dots_i8_state_sapphireamx_t;
|
|
149
|
+
|
|
150
|
+
/* BF16 2 × 2 output state: 32 × 32 F32 output (4 accumulator tiles).
|
|
151
|
+
* Used for GEMM's 2 × 2 output blocking pattern.
|
|
152
|
+
*/
|
|
153
|
+
typedef struct {
|
|
154
|
+
nk_dots_bf16_state_sapphireamx_t c[2][2]; // 4KB total
|
|
155
|
+
} nk_dots_bf16_state2x2_sapphireamx_t;
|
|
156
|
+
|
|
157
|
+
/* INT8 2 × 2 output state: 32 × 32 I32 output (4 accumulator tiles).
|
|
158
|
+
*/
|
|
159
|
+
typedef struct {
|
|
160
|
+
nk_dots_i8_state_sapphireamx_t c[2][2]; // 4KB total
|
|
161
|
+
} nk_dots_i8_state2x2_sapphireamx_t;
|
|
162
|
+
|
|
163
|
+
/* UINT8 A tile: 16 rows × 64 depth-elements, row-major layout.
|
|
164
|
+
* Same layout as I8, different interpretation of signed vs unsigned.
|
|
165
|
+
*/
|
|
166
|
+
typedef struct {
|
|
167
|
+
NK_ALIGN64 nk_u8_t data[16][64]; // 16 rows × 64 columns = 1KB
|
|
168
|
+
} nk_dots_u8_a16x64_sapphireamx_t;
|
|
169
|
+
|
|
170
|
+
/* UINT8 B tile: 64 depth × 16 columns, quad-interleaved for TDPBUUD.
|
|
171
|
+
*/
|
|
172
|
+
typedef struct {
|
|
173
|
+
NK_ALIGN64 nk_u8_t data[16][16][4]; // 16 depth-groups × 16 columns × 4 = 1KB
|
|
174
|
+
} nk_dots_u8_b64x16_sapphireamx_t;
|
|
175
|
+
|
|
176
|
+
/* UINT8 output state: 16 × 16 U32 accumulator tile.
|
|
177
|
+
*/
|
|
178
|
+
typedef struct {
|
|
179
|
+
NK_ALIGN64 nk_u32_t data[16][16]; // 16 × 16 = 1KB
|
|
180
|
+
} nk_dots_u8_state_sapphireamx_t;
|
|
181
|
+
|
|
182
|
+
/* UINT8 2 × 2 output state: 32 × 32 U32 output (4 accumulator tiles).
|
|
183
|
+
*/
|
|
184
|
+
typedef struct {
|
|
185
|
+
nk_dots_u8_state_sapphireamx_t c[2][2]; // 4KB total
|
|
186
|
+
} nk_dots_u8_state2x2_sapphireamx_t;
|
|
187
|
+
|
|
188
|
+
/* Morton Z-curve encoding for cache-friendly tile traversal */
|
|
189
|
+
NK_INTERNAL nk_u64_t nk_morton_encode_sapphireamx_(nk_u32_t tile_row, nk_u32_t tile_col) {
|
|
190
|
+
return _pdep_u64(tile_row, 0x5555555555555555ULL) | _pdep_u64(tile_col, 0xAAAAAAAAAAAAAAAAULL);
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
/* Configure AMX tile registers */
|
|
194
|
+
NK_INTERNAL void nk_amx_tile_configure_sapphireamx_(void) {
|
|
195
|
+
NK_ALIGN64 nk_u8_t tile_config[64] = {0};
|
|
196
|
+
tile_config[0] = 1; // palette 1 (standard tile configuration)
|
|
197
|
+
|
|
198
|
+
nk_u16_t *bytes_per_row = (nk_u16_t *)&tile_config[16];
|
|
199
|
+
nk_u8_t *rows_per_tile = &tile_config[48];
|
|
200
|
+
|
|
201
|
+
for (int tile_idx = 0; tile_idx < 8; tile_idx++) {
|
|
202
|
+
rows_per_tile[tile_idx] = 16; // 16 rows per tile
|
|
203
|
+
bytes_per_row[tile_idx] = 64; // 64 bytes per row (1KB total)
|
|
204
|
+
}
|
|
205
|
+
_tile_loadconfig(tile_config);
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
/** @brief Compiler memory barrier to ensure stores complete before AMX tile loads */
|
|
209
|
+
#if defined(_MSC_VER)
|
|
210
|
+
NK_INTERNAL void nk_compiler_barrier_sapphireamx_(void) { _ReadWriteBarrier(); }
|
|
211
|
+
#else
|
|
212
|
+
NK_INTERNAL void nk_compiler_barrier_sapphireamx_(void) { __asm__ volatile("" ::: "memory"); }
|
|
213
|
+
#endif
|
|
214
|
+
|
|
215
|
+
/* Initialize BF16 output state to zero */
|
|
216
|
+
NK_INTERNAL void nk_dots_bf16_init_sapphireamx_(nk_dots_bf16_state_sapphireamx_t *state) {
|
|
217
|
+
__m512 zero = _mm512_setzero_ps();
|
|
218
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_ps(state->data[row_idx], zero); }
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
/* Load A tile from row-major source with masking for edge tiles */
|
|
222
|
+
NK_INTERNAL void nk_dots_bf16_load_a_sapphireamx_( //
|
|
223
|
+
nk_dots_bf16_a16x32_sapphireamx_t *a_tile, //
|
|
224
|
+
nk_bf16_t const *src, nk_size_t src_stride_elements, //
|
|
225
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
226
|
+
|
|
227
|
+
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
228
|
+
__m512i zero = _mm512_setzero_si512();
|
|
229
|
+
|
|
230
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
231
|
+
if (row_idx < valid_rows) {
|
|
232
|
+
__m512i row = _mm512_maskz_loadu_epi16(column_mask, src + row_idx * src_stride_elements);
|
|
233
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], row);
|
|
234
|
+
}
|
|
235
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
|
|
236
|
+
}
|
|
237
|
+
nk_compiler_barrier_sapphireamx_();
|
|
238
|
+
}
|
|
239
|
+
|
|
240
|
+
/* Store state to output matrix with masking for edge tiles */
|
|
241
|
+
NK_INTERNAL void nk_dots_bf16_store_sapphireamx_( //
|
|
242
|
+
nk_dots_bf16_state_sapphireamx_t const *state, //
|
|
243
|
+
nk_f32_t *dst, nk_size_t dst_stride_elements, //
|
|
244
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
245
|
+
|
|
246
|
+
__mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
|
|
247
|
+
|
|
248
|
+
for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
|
|
249
|
+
__m512 row = _mm512_load_ps(state->data[row_idx]);
|
|
250
|
+
_mm512_mask_storeu_ps(dst + row_idx * dst_stride_elements, column_mask, row);
|
|
251
|
+
}
|
|
252
|
+
}
|
|
253
|
+
|
|
254
|
+
/* Accumulate 3 A x B tile pairs into state using AMX TDPBF16PS */
|
|
255
|
+
NK_INTERNAL void nk_dots_bf16_update_sapphireamx_( //
|
|
256
|
+
nk_dots_bf16_state_sapphireamx_t *state, //
|
|
257
|
+
nk_dots_bf16_a16x32_sapphireamx_t const *a_tile_0, //
|
|
258
|
+
nk_dots_bf16_a16x32_sapphireamx_t const *a_tile_1, //
|
|
259
|
+
nk_dots_bf16_a16x32_sapphireamx_t const *a_tile_2, //
|
|
260
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_0, //
|
|
261
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_1, //
|
|
262
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_2) {
|
|
263
|
+
|
|
264
|
+
// Load all tiles into registers
|
|
265
|
+
_tile_loadd(0, state->data, 64); // C accumulator
|
|
266
|
+
_tile_loadd(1, a_tile_0->data, 64); // A0
|
|
267
|
+
_tile_loadd(2, a_tile_1->data, 64); // A1
|
|
268
|
+
_tile_loadd(3, a_tile_2->data, 64); // A2
|
|
269
|
+
_tile_loadd(4, b_tile_0->data, 64); // B0
|
|
270
|
+
_tile_loadd(5, b_tile_1->data, 64); // B1
|
|
271
|
+
_tile_loadd(6, b_tile_2->data, 64); // B2
|
|
272
|
+
|
|
273
|
+
// Accumulate: C += A0 × B0 + A1 × B1 + A2 × B2
|
|
274
|
+
_tile_dpbf16ps(0, 1, 4); // C += A0 × B0
|
|
275
|
+
_tile_dpbf16ps(0, 2, 5); // C += A1 × B1
|
|
276
|
+
_tile_dpbf16ps(0, 3, 6); // C += A2 × B2
|
|
277
|
+
|
|
278
|
+
// Store result
|
|
279
|
+
_tile_stored(0, state->data, 64);
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
/* Initialize INT8 output state to zero */
|
|
283
|
+
NK_INTERNAL void nk_dots_i8_init_sapphireamx_(nk_dots_i8_state_sapphireamx_t *state) {
|
|
284
|
+
__m512i zero = _mm512_setzero_si512();
|
|
285
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) { _mm512_store_si512((__m512i *)state->data[row_idx], zero); }
|
|
286
|
+
}
|
|
287
|
+
|
|
288
|
+
/* Load A tile from row-major source with masking for edge tiles */
|
|
289
|
+
NK_INTERNAL void nk_dots_i8_load_a_sapphireamx_( //
|
|
290
|
+
nk_dots_i8_a16x64_sapphireamx_t *a_tile, //
|
|
291
|
+
nk_i8_t const *src, nk_size_t src_stride, //
|
|
292
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
293
|
+
|
|
294
|
+
__mmask64 column_mask = (valid_cols >= 64) ? 0xFFFFFFFFFFFFFFFFULL : ((__mmask64)1 << valid_cols) - 1;
|
|
295
|
+
__m512i zero = _mm512_setzero_si512();
|
|
296
|
+
|
|
297
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
298
|
+
if (row_idx < valid_rows) {
|
|
299
|
+
__m512i row = _mm512_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
|
|
300
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], row);
|
|
301
|
+
}
|
|
302
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
|
|
303
|
+
}
|
|
304
|
+
nk_compiler_barrier_sapphireamx_();
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
/* Store state to output matrix with masking for edge tiles */
|
|
308
|
+
NK_INTERNAL void nk_dots_i8_store_sapphireamx_( //
|
|
309
|
+
nk_dots_i8_state_sapphireamx_t const *state, //
|
|
310
|
+
nk_i32_t *dst, nk_size_t dst_stride_elements, //
|
|
311
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
312
|
+
|
|
313
|
+
__mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
|
|
314
|
+
|
|
315
|
+
for (nk_size_t row_idx = 0; row_idx < valid_rows; row_idx++) {
|
|
316
|
+
__m512i row = _mm512_load_si512((__m512i const *)state->data[row_idx]);
|
|
317
|
+
_mm512_mask_storeu_epi32(dst + row_idx * dst_stride_elements, column_mask, row);
|
|
318
|
+
}
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
/* Accumulate 3 A x B tile pairs into state using AMX TDPBSSD */
|
|
322
|
+
NK_INTERNAL void nk_dots_i8_update_sapphireamx_( //
|
|
323
|
+
nk_dots_i8_state_sapphireamx_t *state, //
|
|
324
|
+
nk_dots_i8_a16x64_sapphireamx_t const *a_tile_0, //
|
|
325
|
+
nk_dots_i8_a16x64_sapphireamx_t const *a_tile_1, //
|
|
326
|
+
nk_dots_i8_a16x64_sapphireamx_t const *a_tile_2, //
|
|
327
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_0, //
|
|
328
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_1, //
|
|
329
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_2) {
|
|
330
|
+
|
|
331
|
+
// Load all tiles into registers
|
|
332
|
+
_tile_loadd(0, state->data, 64); // C accumulator
|
|
333
|
+
_tile_loadd(1, a_tile_0->data, 64); // A0
|
|
334
|
+
_tile_loadd(2, a_tile_1->data, 64); // A1
|
|
335
|
+
_tile_loadd(3, a_tile_2->data, 64); // A2
|
|
336
|
+
_tile_loadd(4, b_tile_0->data, 64); // B0
|
|
337
|
+
_tile_loadd(5, b_tile_1->data, 64); // B1
|
|
338
|
+
_tile_loadd(6, b_tile_2->data, 64); // B2
|
|
339
|
+
|
|
340
|
+
// Accumulate: C += A0 × B0 + A1 × B1 + A2 × B2
|
|
341
|
+
_tile_dpbssd(0, 1, 4); // C += A0 × B0
|
|
342
|
+
_tile_dpbssd(0, 2, 5); // C += A1 × B1
|
|
343
|
+
_tile_dpbssd(0, 3, 6); // C += A2 × B2
|
|
344
|
+
|
|
345
|
+
// Store result
|
|
346
|
+
_tile_stored(0, state->data, 64);
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
/* Store BF16 2x2 state to output matrix with masking for edge tiles */
|
|
350
|
+
NK_INTERNAL void nk_dots_bf16_output2x2_sapphireamx_( //
|
|
351
|
+
nk_dots_bf16_state2x2_sapphireamx_t const *state, //
|
|
352
|
+
nk_f32_t *dst, nk_size_t dst_stride_elements, //
|
|
353
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
354
|
+
|
|
355
|
+
// Rows 0-15
|
|
356
|
+
nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
|
|
357
|
+
nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
|
|
358
|
+
nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
|
|
359
|
+
|
|
360
|
+
if (rows_upper > 0 && cols_left > 0)
|
|
361
|
+
nk_dots_bf16_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
|
|
362
|
+
if (rows_upper > 0 && cols_right > 0)
|
|
363
|
+
nk_dots_bf16_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
|
|
364
|
+
|
|
365
|
+
// Rows 16-31
|
|
366
|
+
if (valid_rows > 16) {
|
|
367
|
+
nk_size_t const rows_lower = valid_rows - 16;
|
|
368
|
+
nk_f32_t *dst_lower = dst + 16 * dst_stride_elements;
|
|
369
|
+
if (cols_left > 0)
|
|
370
|
+
nk_dots_bf16_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
|
|
371
|
+
if (cols_right > 0)
|
|
372
|
+
nk_dots_bf16_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower,
|
|
373
|
+
cols_right);
|
|
374
|
+
}
|
|
375
|
+
}
|
|
376
|
+
|
|
377
|
+
/* Store INT8 2x2 state to output matrix with masking for edge tiles */
|
|
378
|
+
NK_INTERNAL void nk_dots_i8_output2x2_sapphireamx_( //
|
|
379
|
+
nk_dots_i8_state2x2_sapphireamx_t const *state, //
|
|
380
|
+
nk_i32_t *dst, nk_size_t dst_stride_elements, //
|
|
381
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
382
|
+
|
|
383
|
+
nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
|
|
384
|
+
nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
|
|
385
|
+
nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
|
|
386
|
+
|
|
387
|
+
if (rows_upper > 0 && cols_left > 0)
|
|
388
|
+
nk_dots_i8_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
|
|
389
|
+
if (rows_upper > 0 && cols_right > 0)
|
|
390
|
+
nk_dots_i8_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
|
|
391
|
+
|
|
392
|
+
if (valid_rows > 16) {
|
|
393
|
+
nk_size_t const rows_lower = valid_rows - 16;
|
|
394
|
+
nk_i32_t *dst_lower = dst + 16 * dst_stride_elements;
|
|
395
|
+
if (cols_left > 0)
|
|
396
|
+
nk_dots_i8_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
|
|
397
|
+
if (cols_right > 0)
|
|
398
|
+
nk_dots_i8_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower, cols_right);
|
|
399
|
+
}
|
|
400
|
+
}
|
|
401
|
+
|
|
402
|
+
/* Initialize UINT8 output state to zero */
|
|
403
|
+
NK_INTERNAL void nk_dots_u8_init_sapphireamx_(nk_dots_u8_state_sapphireamx_t *state) {
|
|
404
|
+
nk_dots_i8_init_sapphireamx_((nk_dots_i8_state_sapphireamx_t *)state);
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
/* Load U8 A tile from row-major source with masking for edge tiles */
|
|
408
|
+
NK_INTERNAL void nk_dots_u8_load_a_sapphireamx_( //
|
|
409
|
+
nk_dots_u8_a16x64_sapphireamx_t *a_tile, //
|
|
410
|
+
nk_u8_t const *src, nk_size_t src_stride, //
|
|
411
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
412
|
+
nk_dots_i8_load_a_sapphireamx_( //
|
|
413
|
+
(nk_dots_i8_a16x64_sapphireamx_t *)a_tile, //
|
|
414
|
+
(nk_i8_t const *)src, src_stride, valid_rows, valid_cols);
|
|
415
|
+
}
|
|
416
|
+
|
|
417
|
+
/* Store U8 state to output matrix with masking for edge tiles */
|
|
418
|
+
NK_INTERNAL void nk_dots_u8_store_sapphireamx_( //
|
|
419
|
+
nk_dots_u8_state_sapphireamx_t const *state, //
|
|
420
|
+
nk_u32_t *dst, nk_size_t dst_stride_elements, //
|
|
421
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
422
|
+
nk_dots_i8_store_sapphireamx_( //
|
|
423
|
+
(nk_dots_i8_state_sapphireamx_t const *)state, //
|
|
424
|
+
(nk_i32_t *)dst, dst_stride_elements, valid_rows, valid_cols);
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
/* Store UINT8 2x2 state to output matrix with masking for edge tiles */
|
|
428
|
+
NK_INTERNAL void nk_dots_u8_output2x2_sapphireamx_( //
|
|
429
|
+
nk_dots_u8_state2x2_sapphireamx_t const *state, //
|
|
430
|
+
nk_u32_t *dst, nk_size_t dst_stride_elements, //
|
|
431
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
432
|
+
nk_dots_i8_output2x2_sapphireamx_( //
|
|
433
|
+
(nk_dots_i8_state2x2_sapphireamx_t const *)state, //
|
|
434
|
+
(nk_i32_t *)dst, dst_stride_elements, valid_rows, valid_cols);
|
|
435
|
+
}
|
|
436
|
+
|
|
437
|
+
/* Pack U8 A transposed into B format */
|
|
438
|
+
NK_INTERNAL void nk_dots_pack_u8_transposed_sapphireamx_( //
|
|
439
|
+
nk_dots_u8_a16x64_sapphireamx_t const *a_tile, //
|
|
440
|
+
nk_dots_u8_b64x16_sapphireamx_t *b_tile) {
|
|
441
|
+
|
|
442
|
+
// Load all 16 rows - each row is 64 UINT8 = 64 bytes = 1 ZMM
|
|
443
|
+
// Treat as 16 × 32-bit elements per row (each 32-bit = quad of UINT8)
|
|
444
|
+
__m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
|
|
445
|
+
__m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
|
|
446
|
+
__m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
|
|
447
|
+
__m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
|
|
448
|
+
__m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
|
|
449
|
+
__m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
|
|
450
|
+
__m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
|
|
451
|
+
__m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
|
|
452
|
+
__m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
|
|
453
|
+
__m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
|
|
454
|
+
__m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
|
|
455
|
+
__m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
|
|
456
|
+
__m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
|
|
457
|
+
__m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
|
|
458
|
+
__m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
|
|
459
|
+
__m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);
|
|
460
|
+
|
|
461
|
+
// 16×16 transpose of 32-bit elements using hierarchical unpacks
|
|
462
|
+
// Stage 1: Unpack adjacent row pairs at 32-bit granularity
|
|
463
|
+
__m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
|
|
464
|
+
__m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
|
|
465
|
+
__m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
|
|
466
|
+
__m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
|
|
467
|
+
__m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
|
|
468
|
+
__m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
|
|
469
|
+
__m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
|
|
470
|
+
__m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
|
|
471
|
+
__m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
|
|
472
|
+
__m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
|
|
473
|
+
__m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
|
|
474
|
+
__m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
|
|
475
|
+
__m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
|
|
476
|
+
__m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
|
|
477
|
+
__m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
|
|
478
|
+
__m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);
|
|
479
|
+
|
|
480
|
+
// Stage 2: Unpack at 64-bit granularity
|
|
481
|
+
__m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
|
|
482
|
+
__m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
|
|
483
|
+
__m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
|
|
484
|
+
__m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
|
|
485
|
+
__m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
|
|
486
|
+
__m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
|
|
487
|
+
__m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
|
|
488
|
+
__m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
|
|
489
|
+
__m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
|
|
490
|
+
__m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
|
|
491
|
+
__m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
|
|
492
|
+
__m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
|
|
493
|
+
__m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
|
|
494
|
+
__m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
|
|
495
|
+
__m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
|
|
496
|
+
__m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);
|
|
497
|
+
|
|
498
|
+
// Stage 3: Shuffle 128-bit lanes
|
|
499
|
+
__m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88);
|
|
500
|
+
__m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD);
|
|
501
|
+
__m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
|
|
502
|
+
__m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
|
|
503
|
+
__m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
|
|
504
|
+
__m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
|
|
505
|
+
__m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
|
|
506
|
+
__m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
|
|
507
|
+
__m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
|
|
508
|
+
__m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
|
|
509
|
+
__m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
|
|
510
|
+
__m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
|
|
511
|
+
__m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
|
|
512
|
+
__m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
|
|
513
|
+
__m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
|
|
514
|
+
__m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);
|
|
515
|
+
|
|
516
|
+
// Stage 4: Final 256-bit shuffle to complete transpose
|
|
517
|
+
__m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
|
|
518
|
+
__m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
|
|
519
|
+
__m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
|
|
520
|
+
__m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
|
|
521
|
+
__m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
|
|
522
|
+
__m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
|
|
523
|
+
__m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
|
|
524
|
+
__m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
|
|
525
|
+
__m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
|
|
526
|
+
__m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
|
|
527
|
+
__m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
|
|
528
|
+
__m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
|
|
529
|
+
__m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
|
|
530
|
+
__m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
|
|
531
|
+
__m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
|
|
532
|
+
__m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);
|
|
533
|
+
|
|
534
|
+
// Store transposed results - each output row is one depth_group
|
|
535
|
+
// Output layout: B.data[depth_group][column][quad] = 16 columns × 4 UINT8 = 64 bytes
|
|
536
|
+
_mm512_store_si512(&b_tile->data[0][0][0], out00);
|
|
537
|
+
_mm512_store_si512(&b_tile->data[1][0][0], out01);
|
|
538
|
+
_mm512_store_si512(&b_tile->data[2][0][0], out02);
|
|
539
|
+
_mm512_store_si512(&b_tile->data[3][0][0], out03);
|
|
540
|
+
_mm512_store_si512(&b_tile->data[4][0][0], out08);
|
|
541
|
+
_mm512_store_si512(&b_tile->data[5][0][0], out09);
|
|
542
|
+
_mm512_store_si512(&b_tile->data[6][0][0], out10);
|
|
543
|
+
_mm512_store_si512(&b_tile->data[7][0][0], out11);
|
|
544
|
+
_mm512_store_si512(&b_tile->data[8][0][0], out04);
|
|
545
|
+
_mm512_store_si512(&b_tile->data[9][0][0], out05);
|
|
546
|
+
_mm512_store_si512(&b_tile->data[10][0][0], out06);
|
|
547
|
+
_mm512_store_si512(&b_tile->data[11][0][0], out07);
|
|
548
|
+
_mm512_store_si512(&b_tile->data[12][0][0], out12);
|
|
549
|
+
_mm512_store_si512(&b_tile->data[13][0][0], out13);
|
|
550
|
+
_mm512_store_si512(&b_tile->data[14][0][0], out14);
|
|
551
|
+
_mm512_store_si512(&b_tile->data[15][0][0], out15);
|
|
552
|
+
|
|
553
|
+
nk_compiler_barrier_sapphireamx_();
|
|
554
|
+
}
|
|
555
|
+
|
|
556
|
+
/* Accumulate 3 A x B tile pairs into state using AMX TDPBUUD */
|
|
557
|
+
NK_INTERNAL void nk_dots_u8_update_sapphireamx_( //
|
|
558
|
+
nk_dots_u8_state_sapphireamx_t *state, //
|
|
559
|
+
nk_dots_u8_a16x64_sapphireamx_t const *a_tile_0, //
|
|
560
|
+
nk_dots_u8_a16x64_sapphireamx_t const *a_tile_1, //
|
|
561
|
+
nk_dots_u8_a16x64_sapphireamx_t const *a_tile_2, //
|
|
562
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_0, //
|
|
563
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_1, //
|
|
564
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_2) {
|
|
565
|
+
|
|
566
|
+
// Load all tiles into registers
|
|
567
|
+
_tile_loadd(0, state->data, 64); // C accumulator
|
|
568
|
+
_tile_loadd(1, a_tile_0->data, 64); // A0
|
|
569
|
+
_tile_loadd(2, a_tile_1->data, 64); // A1
|
|
570
|
+
_tile_loadd(3, a_tile_2->data, 64); // A2
|
|
571
|
+
_tile_loadd(4, b_tile_0->data, 64); // B0
|
|
572
|
+
_tile_loadd(5, b_tile_1->data, 64); // B1
|
|
573
|
+
_tile_loadd(6, b_tile_2->data, 64); // B2
|
|
574
|
+
|
|
575
|
+
// Accumulate: C += A0 × B0 + A1 × B1 + A2 × B2
|
|
576
|
+
_tile_dpbuud(0, 1, 4); // C += A0 × B0
|
|
577
|
+
_tile_dpbuud(0, 2, 5); // C += A1 × B1
|
|
578
|
+
_tile_dpbuud(0, 3, 6); // C += A2 × B2
|
|
579
|
+
|
|
580
|
+
// Store result
|
|
581
|
+
_tile_stored(0, state->data, 64);
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
/* Load E4M3 A tile with FP8 to BF16 conversion */
|
|
585
|
+
NK_INTERNAL void nk_dots_e4m3_load_a_sapphireamx_( //
|
|
586
|
+
nk_dots_bf16_a16x32_sapphireamx_t *a_tile, //
|
|
587
|
+
nk_e4m3_t const *src, nk_size_t src_stride, //
|
|
588
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
589
|
+
|
|
590
|
+
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
591
|
+
__m512i zero = _mm512_setzero_si512();
|
|
592
|
+
|
|
593
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
594
|
+
if (row_idx < valid_rows) {
|
|
595
|
+
// Load 32 E4M3 bytes with masking
|
|
596
|
+
__m256i e4m3_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
|
|
597
|
+
// Convert to 32 BF16 values
|
|
598
|
+
__m512i bf16_row = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row);
|
|
599
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
|
|
600
|
+
}
|
|
601
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
|
|
602
|
+
}
|
|
603
|
+
nk_compiler_barrier_sapphireamx_();
|
|
604
|
+
}
|
|
605
|
+
|
|
606
|
+
/* Load E5M2 A tile with FP8 to BF16 conversion */
|
|
607
|
+
NK_INTERNAL void nk_dots_e5m2_load_a_sapphireamx_( //
|
|
608
|
+
nk_dots_bf16_a16x32_sapphireamx_t *a_tile, //
|
|
609
|
+
nk_e5m2_t const *src, nk_size_t src_stride, //
|
|
610
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
611
|
+
|
|
612
|
+
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
613
|
+
__m512i zero = _mm512_setzero_si512();
|
|
614
|
+
|
|
615
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
616
|
+
if (row_idx < valid_rows) {
|
|
617
|
+
__m256i e5m2_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
|
|
618
|
+
__m512i bf16_row = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row);
|
|
619
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
|
|
620
|
+
}
|
|
621
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
|
|
622
|
+
}
|
|
623
|
+
nk_compiler_barrier_sapphireamx_();
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
/* Pack A transposed into B format for BF16 */
|
|
627
|
+
NK_INTERNAL void nk_dots_pack_bf16_transposed_sapphireamx_( //
|
|
628
|
+
nk_dots_bf16_a16x32_sapphireamx_t const *a_tile, //
|
|
629
|
+
nk_dots_bf16_b32x16_sapphireamx_t *b_tile) {
|
|
630
|
+
|
|
631
|
+
// Load all 16 rows - each row is 32 BF16 = 64 bytes = 1 ZMM
|
|
632
|
+
// Treat as 16 × 32-bit elements per row (each 32-bit = pair of BF16)
|
|
633
|
+
__m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
|
|
634
|
+
__m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
|
|
635
|
+
__m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
|
|
636
|
+
__m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
|
|
637
|
+
__m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
|
|
638
|
+
__m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
|
|
639
|
+
__m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
|
|
640
|
+
__m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
|
|
641
|
+
__m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
|
|
642
|
+
__m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
|
|
643
|
+
__m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
|
|
644
|
+
__m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
|
|
645
|
+
__m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
|
|
646
|
+
__m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
|
|
647
|
+
__m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
|
|
648
|
+
__m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);
|
|
649
|
+
|
|
650
|
+
// 16×16 transpose of 32-bit elements using hierarchical unpacks
|
|
651
|
+
// Stage 1: Unpack adjacent row pairs at 32-bit granularity
|
|
652
|
+
__m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
|
|
653
|
+
__m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
|
|
654
|
+
__m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
|
|
655
|
+
__m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
|
|
656
|
+
__m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
|
|
657
|
+
__m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
|
|
658
|
+
__m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
|
|
659
|
+
__m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
|
|
660
|
+
__m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
|
|
661
|
+
__m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
|
|
662
|
+
__m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
|
|
663
|
+
__m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
|
|
664
|
+
__m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
|
|
665
|
+
__m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
|
|
666
|
+
__m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
|
|
667
|
+
__m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);
|
|
668
|
+
|
|
669
|
+
// Stage 2: Unpack at 64-bit granularity
|
|
670
|
+
__m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
|
|
671
|
+
__m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
|
|
672
|
+
__m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
|
|
673
|
+
__m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
|
|
674
|
+
__m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
|
|
675
|
+
__m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
|
|
676
|
+
__m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
|
|
677
|
+
__m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
|
|
678
|
+
__m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
|
|
679
|
+
__m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
|
|
680
|
+
__m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
|
|
681
|
+
__m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
|
|
682
|
+
__m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
|
|
683
|
+
__m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
|
|
684
|
+
__m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
|
|
685
|
+
__m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);
|
|
686
|
+
|
|
687
|
+
// Stage 3: Shuffle 128-bit lanes using permute2x128 equivalent for 512-bit
|
|
688
|
+
// Use shuffle_i32x4 to move 128-bit chunks
|
|
689
|
+
__m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88); // lanes 0,2 from each
|
|
690
|
+
__m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD); // lanes 1,3 from each
|
|
691
|
+
__m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
|
|
692
|
+
__m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
|
|
693
|
+
__m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
|
|
694
|
+
__m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
|
|
695
|
+
__m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
|
|
696
|
+
__m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
|
|
697
|
+
__m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
|
|
698
|
+
__m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
|
|
699
|
+
__m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
|
|
700
|
+
__m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
|
|
701
|
+
__m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
|
|
702
|
+
__m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
|
|
703
|
+
__m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
|
|
704
|
+
__m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);
|
|
705
|
+
|
|
706
|
+
// Stage 4: Final 256-bit shuffle to complete transpose
|
|
707
|
+
__m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
|
|
708
|
+
__m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
|
|
709
|
+
__m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
|
|
710
|
+
__m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
|
|
711
|
+
__m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
|
|
712
|
+
__m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
|
|
713
|
+
__m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
|
|
714
|
+
__m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
|
|
715
|
+
__m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
|
|
716
|
+
__m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
|
|
717
|
+
__m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
|
|
718
|
+
__m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
|
|
719
|
+
__m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
|
|
720
|
+
__m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
|
|
721
|
+
__m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
|
|
722
|
+
__m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);
|
|
723
|
+
|
|
724
|
+
// Store transposed results - each output row is one depth_group
|
|
725
|
+
// Output layout: B.data[depth_group][column][pair] = 16 columns × 2 BF16 = 64 bytes
|
|
726
|
+
_mm512_store_si512(&b_tile->data[0][0][0], out00);
|
|
727
|
+
_mm512_store_si512(&b_tile->data[1][0][0], out01);
|
|
728
|
+
_mm512_store_si512(&b_tile->data[2][0][0], out02);
|
|
729
|
+
_mm512_store_si512(&b_tile->data[3][0][0], out03);
|
|
730
|
+
_mm512_store_si512(&b_tile->data[4][0][0], out08);
|
|
731
|
+
_mm512_store_si512(&b_tile->data[5][0][0], out09);
|
|
732
|
+
_mm512_store_si512(&b_tile->data[6][0][0], out10);
|
|
733
|
+
_mm512_store_si512(&b_tile->data[7][0][0], out11);
|
|
734
|
+
_mm512_store_si512(&b_tile->data[8][0][0], out04);
|
|
735
|
+
_mm512_store_si512(&b_tile->data[9][0][0], out05);
|
|
736
|
+
_mm512_store_si512(&b_tile->data[10][0][0], out06);
|
|
737
|
+
_mm512_store_si512(&b_tile->data[11][0][0], out07);
|
|
738
|
+
_mm512_store_si512(&b_tile->data[12][0][0], out12);
|
|
739
|
+
_mm512_store_si512(&b_tile->data[13][0][0], out13);
|
|
740
|
+
_mm512_store_si512(&b_tile->data[14][0][0], out14);
|
|
741
|
+
_mm512_store_si512(&b_tile->data[15][0][0], out15);
|
|
742
|
+
|
|
743
|
+
nk_compiler_barrier_sapphireamx_();
|
|
744
|
+
}
|
|
745
|
+
|
|
746
|
+
/* Pack A transposed into B format for INT8 */
|
|
747
|
+
NK_INTERNAL void nk_dots_pack_i8_transposed_sapphireamx_( //
|
|
748
|
+
nk_dots_i8_a16x64_sapphireamx_t const *a_tile, //
|
|
749
|
+
nk_dots_i8_b64x16_sapphireamx_t *b_tile) {
|
|
750
|
+
|
|
751
|
+
// Load all 16 rows - each row is 64 INT8 = 64 bytes = 1 ZMM
|
|
752
|
+
// Treat as 16 × 32-bit elements per row (each 32-bit = quad of INT8)
|
|
753
|
+
__m512i row00 = _mm512_load_si512(&a_tile->data[0][0]);
|
|
754
|
+
__m512i row01 = _mm512_load_si512(&a_tile->data[1][0]);
|
|
755
|
+
__m512i row02 = _mm512_load_si512(&a_tile->data[2][0]);
|
|
756
|
+
__m512i row03 = _mm512_load_si512(&a_tile->data[3][0]);
|
|
757
|
+
__m512i row04 = _mm512_load_si512(&a_tile->data[4][0]);
|
|
758
|
+
__m512i row05 = _mm512_load_si512(&a_tile->data[5][0]);
|
|
759
|
+
__m512i row06 = _mm512_load_si512(&a_tile->data[6][0]);
|
|
760
|
+
__m512i row07 = _mm512_load_si512(&a_tile->data[7][0]);
|
|
761
|
+
__m512i row08 = _mm512_load_si512(&a_tile->data[8][0]);
|
|
762
|
+
__m512i row09 = _mm512_load_si512(&a_tile->data[9][0]);
|
|
763
|
+
__m512i row10 = _mm512_load_si512(&a_tile->data[10][0]);
|
|
764
|
+
__m512i row11 = _mm512_load_si512(&a_tile->data[11][0]);
|
|
765
|
+
__m512i row12 = _mm512_load_si512(&a_tile->data[12][0]);
|
|
766
|
+
__m512i row13 = _mm512_load_si512(&a_tile->data[13][0]);
|
|
767
|
+
__m512i row14 = _mm512_load_si512(&a_tile->data[14][0]);
|
|
768
|
+
__m512i row15 = _mm512_load_si512(&a_tile->data[15][0]);
|
|
769
|
+
|
|
770
|
+
// 16×16 transpose of 32-bit elements using hierarchical unpacks
|
|
771
|
+
// Stage 1: Unpack adjacent row pairs at 32-bit granularity
|
|
772
|
+
__m512i t01_lo = _mm512_unpacklo_epi32(row00, row01);
|
|
773
|
+
__m512i t01_hi = _mm512_unpackhi_epi32(row00, row01);
|
|
774
|
+
__m512i t23_lo = _mm512_unpacklo_epi32(row02, row03);
|
|
775
|
+
__m512i t23_hi = _mm512_unpackhi_epi32(row02, row03);
|
|
776
|
+
__m512i t45_lo = _mm512_unpacklo_epi32(row04, row05);
|
|
777
|
+
__m512i t45_hi = _mm512_unpackhi_epi32(row04, row05);
|
|
778
|
+
__m512i t67_lo = _mm512_unpacklo_epi32(row06, row07);
|
|
779
|
+
__m512i t67_hi = _mm512_unpackhi_epi32(row06, row07);
|
|
780
|
+
__m512i t89_lo = _mm512_unpacklo_epi32(row08, row09);
|
|
781
|
+
__m512i t89_hi = _mm512_unpackhi_epi32(row08, row09);
|
|
782
|
+
__m512i tab_lo = _mm512_unpacklo_epi32(row10, row11);
|
|
783
|
+
__m512i tab_hi = _mm512_unpackhi_epi32(row10, row11);
|
|
784
|
+
__m512i tcd_lo = _mm512_unpacklo_epi32(row12, row13);
|
|
785
|
+
__m512i tcd_hi = _mm512_unpackhi_epi32(row12, row13);
|
|
786
|
+
__m512i tef_lo = _mm512_unpacklo_epi32(row14, row15);
|
|
787
|
+
__m512i tef_hi = _mm512_unpackhi_epi32(row14, row15);
|
|
788
|
+
|
|
789
|
+
// Stage 2: Unpack at 64-bit granularity
|
|
790
|
+
__m512i u0123_ll = _mm512_unpacklo_epi64(t01_lo, t23_lo);
|
|
791
|
+
__m512i u0123_lh = _mm512_unpackhi_epi64(t01_lo, t23_lo);
|
|
792
|
+
__m512i u0123_hl = _mm512_unpacklo_epi64(t01_hi, t23_hi);
|
|
793
|
+
__m512i u0123_hh = _mm512_unpackhi_epi64(t01_hi, t23_hi);
|
|
794
|
+
__m512i u4567_ll = _mm512_unpacklo_epi64(t45_lo, t67_lo);
|
|
795
|
+
__m512i u4567_lh = _mm512_unpackhi_epi64(t45_lo, t67_lo);
|
|
796
|
+
__m512i u4567_hl = _mm512_unpacklo_epi64(t45_hi, t67_hi);
|
|
797
|
+
__m512i u4567_hh = _mm512_unpackhi_epi64(t45_hi, t67_hi);
|
|
798
|
+
__m512i u89ab_ll = _mm512_unpacklo_epi64(t89_lo, tab_lo);
|
|
799
|
+
__m512i u89ab_lh = _mm512_unpackhi_epi64(t89_lo, tab_lo);
|
|
800
|
+
__m512i u89ab_hl = _mm512_unpacklo_epi64(t89_hi, tab_hi);
|
|
801
|
+
__m512i u89ab_hh = _mm512_unpackhi_epi64(t89_hi, tab_hi);
|
|
802
|
+
__m512i ucdef_ll = _mm512_unpacklo_epi64(tcd_lo, tef_lo);
|
|
803
|
+
__m512i ucdef_lh = _mm512_unpackhi_epi64(tcd_lo, tef_lo);
|
|
804
|
+
__m512i ucdef_hl = _mm512_unpacklo_epi64(tcd_hi, tef_hi);
|
|
805
|
+
__m512i ucdef_hh = _mm512_unpackhi_epi64(tcd_hi, tef_hi);
|
|
806
|
+
|
|
807
|
+
// Stage 3: Shuffle 128-bit lanes
|
|
808
|
+
__m512i v0_a = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0x88);
|
|
809
|
+
__m512i v0_b = _mm512_shuffle_i32x4(u0123_ll, u4567_ll, 0xDD);
|
|
810
|
+
__m512i v1_a = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0x88);
|
|
811
|
+
__m512i v1_b = _mm512_shuffle_i32x4(u0123_lh, u4567_lh, 0xDD);
|
|
812
|
+
__m512i v2_a = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0x88);
|
|
813
|
+
__m512i v2_b = _mm512_shuffle_i32x4(u0123_hl, u4567_hl, 0xDD);
|
|
814
|
+
__m512i v3_a = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0x88);
|
|
815
|
+
__m512i v3_b = _mm512_shuffle_i32x4(u0123_hh, u4567_hh, 0xDD);
|
|
816
|
+
__m512i v4_a = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0x88);
|
|
817
|
+
__m512i v4_b = _mm512_shuffle_i32x4(u89ab_ll, ucdef_ll, 0xDD);
|
|
818
|
+
__m512i v5_a = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0x88);
|
|
819
|
+
__m512i v5_b = _mm512_shuffle_i32x4(u89ab_lh, ucdef_lh, 0xDD);
|
|
820
|
+
__m512i v6_a = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0x88);
|
|
821
|
+
__m512i v6_b = _mm512_shuffle_i32x4(u89ab_hl, ucdef_hl, 0xDD);
|
|
822
|
+
__m512i v7_a = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0x88);
|
|
823
|
+
__m512i v7_b = _mm512_shuffle_i32x4(u89ab_hh, ucdef_hh, 0xDD);
|
|
824
|
+
|
|
825
|
+
// Stage 4: Final 256-bit shuffle to complete transpose
|
|
826
|
+
__m512i out00 = _mm512_shuffle_i32x4(v0_a, v4_a, 0x88);
|
|
827
|
+
__m512i out01 = _mm512_shuffle_i32x4(v1_a, v5_a, 0x88);
|
|
828
|
+
__m512i out02 = _mm512_shuffle_i32x4(v2_a, v6_a, 0x88);
|
|
829
|
+
__m512i out03 = _mm512_shuffle_i32x4(v3_a, v7_a, 0x88);
|
|
830
|
+
__m512i out04 = _mm512_shuffle_i32x4(v0_a, v4_a, 0xDD);
|
|
831
|
+
__m512i out05 = _mm512_shuffle_i32x4(v1_a, v5_a, 0xDD);
|
|
832
|
+
__m512i out06 = _mm512_shuffle_i32x4(v2_a, v6_a, 0xDD);
|
|
833
|
+
__m512i out07 = _mm512_shuffle_i32x4(v3_a, v7_a, 0xDD);
|
|
834
|
+
__m512i out08 = _mm512_shuffle_i32x4(v0_b, v4_b, 0x88);
|
|
835
|
+
__m512i out09 = _mm512_shuffle_i32x4(v1_b, v5_b, 0x88);
|
|
836
|
+
__m512i out10 = _mm512_shuffle_i32x4(v2_b, v6_b, 0x88);
|
|
837
|
+
__m512i out11 = _mm512_shuffle_i32x4(v3_b, v7_b, 0x88);
|
|
838
|
+
__m512i out12 = _mm512_shuffle_i32x4(v0_b, v4_b, 0xDD);
|
|
839
|
+
__m512i out13 = _mm512_shuffle_i32x4(v1_b, v5_b, 0xDD);
|
|
840
|
+
__m512i out14 = _mm512_shuffle_i32x4(v2_b, v6_b, 0xDD);
|
|
841
|
+
__m512i out15 = _mm512_shuffle_i32x4(v3_b, v7_b, 0xDD);
|
|
842
|
+
|
|
843
|
+
// Store transposed results - each output row is one depth_group
|
|
844
|
+
// Output layout: B.data[depth_group][column][quad] = 16 columns × 4 INT8 = 64 bytes
|
|
845
|
+
_mm512_store_si512(&b_tile->data[0][0][0], out00);
|
|
846
|
+
_mm512_store_si512(&b_tile->data[1][0][0], out01);
|
|
847
|
+
_mm512_store_si512(&b_tile->data[2][0][0], out02);
|
|
848
|
+
_mm512_store_si512(&b_tile->data[3][0][0], out03);
|
|
849
|
+
_mm512_store_si512(&b_tile->data[4][0][0], out08);
|
|
850
|
+
_mm512_store_si512(&b_tile->data[5][0][0], out09);
|
|
851
|
+
_mm512_store_si512(&b_tile->data[6][0][0], out10);
|
|
852
|
+
_mm512_store_si512(&b_tile->data[7][0][0], out11);
|
|
853
|
+
_mm512_store_si512(&b_tile->data[8][0][0], out04);
|
|
854
|
+
_mm512_store_si512(&b_tile->data[9][0][0], out05);
|
|
855
|
+
_mm512_store_si512(&b_tile->data[10][0][0], out06);
|
|
856
|
+
_mm512_store_si512(&b_tile->data[11][0][0], out07);
|
|
857
|
+
_mm512_store_si512(&b_tile->data[12][0][0], out12);
|
|
858
|
+
_mm512_store_si512(&b_tile->data[13][0][0], out13);
|
|
859
|
+
_mm512_store_si512(&b_tile->data[14][0][0], out14);
|
|
860
|
+
_mm512_store_si512(&b_tile->data[15][0][0], out15);
|
|
861
|
+
|
|
862
|
+
nk_compiler_barrier_sapphireamx_();
|
|
863
|
+
}
|
|
864
|
+
|
|
865
|
+
#pragma region Half Precision Floats
|
|
866
|
+
|
|
867
|
+
/* Number of bytes needed to pack a BF16 B matrix of `column_count` columns and `depth` depth.
 * Layout, in order:
 *   - header                (sizeof(nk_dots_amx_packed_header_t))
 *   - full tiles            one 16-column x 32-depth BF16 tile (1 KB) per full column group and
 *                           per depth tile; the depth remainder is zero-padded to a whole tile
 *   - column edge           leftover (< 16) columns stored row-major over the full depth
 *   - per-column norms      4 bytes per column (f32 or u32) */
NK_PUBLIC nk_size_t nk_dots_packed_size_bf16_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const columns_per_tile = 16;
    nk_size_t const depth_per_tile = 32;
    nk_size_t const bytes_per_tile = columns_per_tile * depth_per_tile * sizeof(nk_bf16_t);

    nk_size_t const column_tiles = column_count / columns_per_tile;
    nk_size_t const leftover_columns = column_count % columns_per_tile;
    nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, depth_per_tile);

    nk_size_t const header_bytes = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const tile_region_bytes = column_tiles * depth_tiles * bytes_per_tile;
    nk_size_t const edge_region_bytes = leftover_columns * depth * sizeof(nk_bf16_t); // 0 when no leftover
    nk_size_t const norm_region_bytes = column_count * sizeof(nk_f32_t);

    return header_bytes + tile_region_bytes + edge_region_bytes + norm_region_bytes;
}
|
|
890
|
+
|
|
891
|
+
NK_PUBLIC void nk_dots_pack_bf16_sapphireamx( //
|
|
892
|
+
nk_bf16_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
893
|
+
nk_size_t b_stride, void *b_packed) {
|
|
894
|
+
|
|
895
|
+
// AMX BF16 tile dimensions: 16 rows × 32 columns (512 BF16 elements = 1KB)
|
|
896
|
+
nk_size_t const tmm_rows = 16;
|
|
897
|
+
nk_size_t const tmm_cols = 32;
|
|
898
|
+
nk_size_t const tile_elements = 512;
|
|
899
|
+
nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);
|
|
900
|
+
nk_size_t const b_stride_elements = b_stride / sizeof(nk_bf16_t);
|
|
901
|
+
|
|
902
|
+
// Compute layout dimensions
|
|
903
|
+
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
904
|
+
nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
|
|
905
|
+
nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
|
|
906
|
+
nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
|
|
907
|
+
|
|
908
|
+
// Write header with layout metadata
|
|
909
|
+
nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
|
|
910
|
+
header->full_column_tiles = (nk_u32_t)column_tiles_count;
|
|
911
|
+
header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
|
|
912
|
+
header->column_remainder_count = (nk_u32_t)column_remainder_count;
|
|
913
|
+
|
|
914
|
+
// Compute memory region offsets
|
|
915
|
+
nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
|
|
916
|
+
nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
|
|
917
|
+
header->column_edge_offset = (nk_u32_t)column_edge_offset;
|
|
918
|
+
|
|
919
|
+
// Pointers to packed data regions
|
|
920
|
+
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
|
|
921
|
+
nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
|
|
922
|
+
|
|
923
|
+
// Zero-initialize all tiles (handles depth remainder padding)
|
|
924
|
+
for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
|
|
925
|
+
|
|
926
|
+
// Pack tiles using LINEAR ordering: tile_index = column_tile × depth_tiles_count + depth_tile
|
|
927
|
+
// This provides sequential memory access when streaming along depth dimension,
|
|
928
|
+
// which is critical for cache efficiency in the compute kernel.
|
|
929
|
+
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
930
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
931
|
+
|
|
932
|
+
// Linear tile index: all depth-tiles for one column-tile are contiguous
|
|
933
|
+
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
934
|
+
nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
935
|
+
|
|
936
|
+
// Source coordinates in original B matrix
|
|
937
|
+
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
938
|
+
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
939
|
+
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
940
|
+
: (depth - src_column_start);
|
|
941
|
+
|
|
942
|
+
// Pack with pair-interleaving as required by TDPBF16PS instruction.
|
|
943
|
+
// AMX expects: [col0_row0, col1_row0, col0_row1, col1_row1, col2_row0, col3_row0, ...]
|
|
944
|
+
// Formula: packed_idx = (column / 2) × 32 + row × 2 + (column % 2)
|
|
945
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
946
|
+
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
947
|
+
nk_size_t const src_idx = (src_row_start + row_idx) * b_stride_elements + src_column_start +
|
|
948
|
+
column_idx;
|
|
949
|
+
nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
|
|
950
|
+
tile_output[dst_idx] = b[src_idx];
|
|
951
|
+
}
|
|
952
|
+
}
|
|
953
|
+
}
|
|
954
|
+
}
|
|
955
|
+
|
|
956
|
+
// Pack column-remainder rows in simple row-major format (for AVX-512 fallback)
|
|
957
|
+
if (column_remainder_count > 0) {
|
|
958
|
+
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
959
|
+
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
960
|
+
for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
|
|
961
|
+
column_edge_ptr[row_idx * depth + column_idx] =
|
|
962
|
+
b[(remainder_start_row + row_idx) * b_stride_elements + column_idx];
|
|
963
|
+
}
|
|
964
|
+
}
|
|
965
|
+
}
|
|
966
|
+
|
|
967
|
+
// Compute and store per-column norms for angular/euclidean distance
|
|
968
|
+
nk_size_t norms_offset = column_edge_offset +
|
|
969
|
+
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_bf16_t) : 0);
|
|
970
|
+
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
971
|
+
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
972
|
+
for (nk_size_t col = 0; col < column_count; col++)
|
|
973
|
+
norms[col] = nk_dots_reduce_sumsq_bf16_(b + col * b_stride_elements, depth);
|
|
974
|
+
}
|
|
975
|
+
|
|
976
|
+
NK_PUBLIC void nk_dots_packed_bf16_sapphireamx( //
|
|
977
|
+
nk_bf16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
978
|
+
nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
|
|
979
|
+
nk_unused_(cols_count);
|
|
980
|
+
|
|
981
|
+
// Parse packed B header
|
|
982
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
983
|
+
nk_size_t const column_tiles_count = header->full_column_tiles;
|
|
984
|
+
nk_size_t const depth_tiles_count = header->full_depth_tiles;
|
|
985
|
+
nk_size_t const column_remainder_count = header->column_remainder_count;
|
|
986
|
+
|
|
987
|
+
// Packed B data regions
|
|
988
|
+
nk_bf16_t const *b_tiles_base = (nk_bf16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
|
|
989
|
+
nk_bf16_t const *col_edge_ptr = (nk_bf16_t const *)((char const *)b_packed + header->column_edge_offset);
|
|
990
|
+
|
|
991
|
+
// Stride conversions
|
|
992
|
+
nk_size_t const a_stride_elements = a_stride_bytes / sizeof(nk_bf16_t);
|
|
993
|
+
nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
|
|
994
|
+
|
|
995
|
+
// Tile dimensions
|
|
996
|
+
nk_size_t const tile_depth = 32; // depth elements per BF16 tile
|
|
997
|
+
nk_size_t const tile_size = 512; // elements per packed tile
|
|
998
|
+
nk_size_t const full_cols = column_tiles_count * 16;
|
|
999
|
+
|
|
1000
|
+
// Block counts (32 × 32 output blocks = 2 × 2 tiles)
|
|
1001
|
+
nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
|
|
1002
|
+
nk_size_t const col_blocks_count = column_tiles_count / 2;
|
|
1003
|
+
|
|
1004
|
+
if (depth_tiles_count == 0) return;
|
|
1005
|
+
|
|
1006
|
+
// Tile buffers for A (only used for edge tiles)
|
|
1007
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
|
|
1008
|
+
nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
|
|
1009
|
+
|
|
1010
|
+
// Precompute: number of full depth-tiles (no masking needed)
|
|
1011
|
+
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
1012
|
+
nk_size_t const depth_remainder = depth % tile_depth;
|
|
1013
|
+
|
|
1014
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
1015
|
+
|
|
1016
|
+
// Loop order: row_blocks outer, col_blocks inner - maximizes A tile L2 cache reuse
|
|
1017
|
+
// A tiles stay in L2 while we sweep through all col_blocks for a given row_block
|
|
1018
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
1019
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
1020
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
1021
|
+
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
1022
|
+
|
|
1023
|
+
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
1024
|
+
nk_size_t const col_block_start = column_block_idx * 32;
|
|
1025
|
+
nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
|
|
1026
|
+
nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
|
|
1027
|
+
|
|
1028
|
+
// Zero accumulators (TMM4-7 stay resident across entire depth loop)
|
|
1029
|
+
_tile_zero(4);
|
|
1030
|
+
_tile_zero(5);
|
|
1031
|
+
_tile_zero(6);
|
|
1032
|
+
_tile_zero(7);
|
|
1033
|
+
|
|
1034
|
+
// Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
|
|
1035
|
+
if (is_full_row_block && full_depth_tiles_count > 0) {
|
|
1036
|
+
nk_bf16_t const *a_upper_base = a + row_block_start * a_stride_elements;
|
|
1037
|
+
nk_bf16_t const *a_lower_base = a + (row_block_start + 16) * a_stride_elements;
|
|
1038
|
+
|
|
1039
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
1040
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
1041
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
|
|
1042
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
1043
|
+
|
|
1044
|
+
// Prologue: load first depth tile
|
|
1045
|
+
_tile_loadd(0, a_upper_base, a_stride_bytes);
|
|
1046
|
+
_tile_loadd(1, a_lower_base, a_stride_bytes);
|
|
1047
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1048
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1049
|
+
|
|
1050
|
+
// Main loop: 2-deep software pipelining
|
|
1051
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < full_depth_tiles_count - 1; depth_tile_idx++) {
|
|
1052
|
+
nk_size_t const next_depth_offset = (depth_tile_idx + 1) * tile_depth;
|
|
1053
|
+
|
|
1054
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
1055
|
+
_tile_dpbf16ps(5, 0, 3);
|
|
1056
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
1057
|
+
_tile_dpbf16ps(7, 1, 3);
|
|
1058
|
+
|
|
1059
|
+
_tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
|
|
1060
|
+
_tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
|
|
1061
|
+
b_tile_left = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
1062
|
+
depth_tile_idx + 1) *
|
|
1063
|
+
tile_size);
|
|
1064
|
+
b_tile_right = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
|
|
1065
|
+
depth_tile_idx + 1) *
|
|
1066
|
+
tile_size);
|
|
1067
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1068
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1069
|
+
}
|
|
1070
|
+
|
|
1071
|
+
// Epilogue: final depth tile
|
|
1072
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
1073
|
+
_tile_dpbf16ps(5, 0, 3);
|
|
1074
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
1075
|
+
_tile_dpbf16ps(7, 1, 3);
|
|
1076
|
+
|
|
1077
|
+
// Handle partial depth-tile (if any)
|
|
1078
|
+
if (depth_remainder > 0) {
|
|
1079
|
+
nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
|
|
1080
|
+
|
|
1081
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_elements, 16,
|
|
1082
|
+
depth_remainder);
|
|
1083
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_elements, 16,
|
|
1084
|
+
depth_remainder);
|
|
1085
|
+
|
|
1086
|
+
b_tile_left = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
1087
|
+
full_depth_tiles_count) *
|
|
1088
|
+
tile_size);
|
|
1089
|
+
b_tile_right = (nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
|
|
1090
|
+
full_depth_tiles_count) *
|
|
1091
|
+
tile_size);
|
|
1092
|
+
|
|
1093
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1094
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1095
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1096
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1097
|
+
|
|
1098
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
1099
|
+
_tile_dpbf16ps(5, 0, 3);
|
|
1100
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
1101
|
+
_tile_dpbf16ps(7, 1, 3);
|
|
1102
|
+
}
|
|
1103
|
+
}
|
|
1104
|
+
// Full row-block but only partial depth tile (depth < tile_depth)
|
|
1105
|
+
else if (is_full_row_block) {
|
|
1106
|
+
nk_bf16_t const *a_upper_base = a + row_block_start * a_stride_elements;
|
|
1107
|
+
nk_bf16_t const *a_lower_base = a + (row_block_start + 16) * a_stride_elements;
|
|
1108
|
+
|
|
1109
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_elements, 16, depth_remainder);
|
|
1110
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_elements, 16, depth_remainder);
|
|
1111
|
+
|
|
1112
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
1113
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
1114
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
|
|
1115
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
1116
|
+
|
|
1117
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1118
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1119
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1120
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1121
|
+
|
|
1122
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
1123
|
+
_tile_dpbf16ps(5, 0, 3);
|
|
1124
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
1125
|
+
_tile_dpbf16ps(7, 1, 3);
|
|
1126
|
+
}
|
|
1127
|
+
// Slow path: edge row-block → buffered load with masking
|
|
1128
|
+
else {
|
|
1129
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1130
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1131
|
+
|
|
1132
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1133
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1134
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
|
|
1135
|
+
: depth_remainder;
|
|
1136
|
+
|
|
1137
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper,
|
|
1138
|
+
a + row_block_start * a_stride_elements + depth_offset,
|
|
1139
|
+
a_stride_elements, rows_in_upper_tile, valid_depth);
|
|
1140
|
+
if (rows_in_lower_tile > 0) {
|
|
1141
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
|
|
1142
|
+
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
1143
|
+
a_stride_elements, rows_in_lower_tile, valid_depth);
|
|
1144
|
+
}
|
|
1145
|
+
|
|
1146
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
1147
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
1148
|
+
(b_column_left_base + depth_tile_idx) * tile_size);
|
|
1149
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
|
|
1150
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
1151
|
+
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
1152
|
+
|
|
1153
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1154
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1155
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1156
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1157
|
+
|
|
1158
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
1159
|
+
_tile_dpbf16ps(5, 0, 3);
|
|
1160
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
1161
|
+
_tile_dpbf16ps(7, 1, 3);
|
|
1162
|
+
}
|
|
1163
|
+
}
|
|
1164
|
+
|
|
1165
|
+
// Store accumulators to output (once per output block)
|
|
1166
|
+
if (is_full_row_block) {
|
|
1167
|
+
nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
|
|
1168
|
+
_tile_stored(4, c_block, c_stride_bytes);
|
|
1169
|
+
_tile_stored(5, c_block + 16, c_stride_bytes);
|
|
1170
|
+
_tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
|
|
1171
|
+
_tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
|
|
1172
|
+
}
|
|
1173
|
+
else {
|
|
1174
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
1175
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
1176
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
1177
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
1178
|
+
nk_dots_bf16_output2x2_sapphireamx_(&c_accum_buffer,
|
|
1179
|
+
c + row_block_start * c_stride_elements + col_block_start,
|
|
1180
|
+
c_stride_elements, valid_rows_count, 32);
|
|
1181
|
+
}
|
|
1182
|
+
}
|
|
1183
|
+
}
|
|
1184
|
+
|
|
1185
|
+
// Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
|
|
1186
|
+
if (column_tiles_count % 2 == 1) {
|
|
1187
|
+
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
1188
|
+
nk_size_t const col_start = column_tile_idx * 16;
|
|
1189
|
+
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
1190
|
+
|
|
1191
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
1192
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
1193
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
|
|
1194
|
+
: (rows_count - row_block_start);
|
|
1195
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1196
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1197
|
+
|
|
1198
|
+
nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
1199
|
+
|
|
1200
|
+
_tile_zero(4);
|
|
1201
|
+
_tile_zero(6);
|
|
1202
|
+
|
|
1203
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1204
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1205
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
1206
|
+
|
|
1207
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_elements + depth_offset,
|
|
1208
|
+
a_stride_elements, rows_in_upper_tile, valid_depth);
|
|
1209
|
+
if (rows_in_lower_tile > 0) {
|
|
1210
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
|
|
1211
|
+
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
1212
|
+
a_stride_elements, rows_in_lower_tile, valid_depth);
|
|
1213
|
+
}
|
|
1214
|
+
|
|
1215
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
|
|
1216
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
1217
|
+
(b_column_base + depth_tile_idx) * tile_size);
|
|
1218
|
+
|
|
1219
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1220
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1221
|
+
_tile_loadd(2, b_tile->data, 64);
|
|
1222
|
+
|
|
1223
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
1224
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
1225
|
+
}
|
|
1226
|
+
|
|
1227
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
1228
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
1229
|
+
|
|
1230
|
+
nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
|
|
1231
|
+
c_stride_elements, rows_in_upper_tile, 16);
|
|
1232
|
+
if (rows_in_lower_tile > 0) {
|
|
1233
|
+
nk_dots_bf16_store_sapphireamx_(&c_lower_state,
|
|
1234
|
+
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
1235
|
+
c_stride_elements, rows_in_lower_tile, 16);
|
|
1236
|
+
}
|
|
1237
|
+
}
|
|
1238
|
+
}
|
|
1239
|
+
|
|
1240
|
+
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
1241
|
+
if (column_remainder_count > 0) {
|
|
1242
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
1243
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
1244
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32
|
|
1245
|
+
: (rows_count - row_block_start);
|
|
1246
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1247
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1248
|
+
|
|
1249
|
+
nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
1250
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
1251
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
1252
|
+
|
|
1253
|
+
_tile_zero(4);
|
|
1254
|
+
_tile_zero(6);
|
|
1255
|
+
|
|
1256
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1257
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1258
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
1259
|
+
|
|
1260
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_elements + depth_offset,
|
|
1261
|
+
a_stride_elements, rows_in_upper_tile, valid_depth);
|
|
1262
|
+
if (rows_in_lower_tile > 0) {
|
|
1263
|
+
nk_dots_bf16_load_a_sapphireamx_(&a_tile_lower,
|
|
1264
|
+
a + (row_block_start + 16) * a_stride_elements + depth_offset,
|
|
1265
|
+
a_stride_elements, rows_in_lower_tile, valid_depth);
|
|
1266
|
+
}
|
|
1267
|
+
|
|
1268
|
+
nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
1269
|
+
valid_depth);
|
|
1270
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
1271
|
+
|
|
1272
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1273
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1274
|
+
_tile_loadd(2, b_tile.data, 64);
|
|
1275
|
+
|
|
1276
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
1277
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
1278
|
+
}
|
|
1279
|
+
|
|
1280
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
1281
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
1282
|
+
|
|
1283
|
+
nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
|
|
1284
|
+
c_stride_elements, rows_in_upper_tile, column_remainder_count);
|
|
1285
|
+
if (rows_in_lower_tile > 0) {
|
|
1286
|
+
nk_dots_bf16_store_sapphireamx_(&c_lower_state,
|
|
1287
|
+
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
1288
|
+
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
1289
|
+
}
|
|
1290
|
+
}
|
|
1291
|
+
}
|
|
1292
|
+
|
|
1293
|
+
_tile_release();
|
|
1294
|
+
}
|
|
1295
|
+
|
|
1296
|
+
/**
 *  @brief  Narrows an F32 result matrix to BF16 in place, packing rows densely.
 *
 *  @param  c             Buffer holding `row_count` F32 rows, `c_stride` bytes apart; on return it
 *                        holds `row_count` BF16 rows of `column_count` elements, back to back.
 *  @param  row_count     Number of rows.
 *  @param  column_count  Number of valid values per row.
 *  @param  c_stride      Byte stride between the F32 input rows.
 *
 *  In-place is safe because the destination cursor, 2 bytes per value and dense, never gets ahead
 *  of the source cursor, 4 bytes per value and strided. Each 16-wide chunk is read before its
 *  narrower result is written.
 */
NK_PUBLIC void nk_dots_compact_bf16_sapphireamx( //
    void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride) {

    nk_size_t const src_pitch = c_stride / sizeof(nk_f32_t);

    for (nk_size_t row = 0; row < row_count; ++row) {
        nk_f32_t const *src = (nk_f32_t const *)c + row * src_pitch;
        nk_bf16_t *dst = (nk_bf16_t *)c + row * column_count;
        nk_size_t col = 0;

        // Whole 16-lane chunks: one ZMM of F32 narrows to one YMM of BF16
        while (col + 16 <= column_count) {
            __m256bh const narrowed = _mm512_cvtneps_pbh(_mm512_loadu_ps(src + col));
            _mm256_storeu_si256((__m256i *)(dst + col), nk_m256i_from_m256bh_(narrowed));
            col += 16;
        }

        // Fewer than 16 values remain: masked load and store of just those lanes
        nk_size_t const left = column_count - col;
        if (left != 0) {
            __mmask16 const lanes = (__mmask16)((1u << left) - 1);
            __m256bh const narrowed = _mm512_cvtneps_pbh(_mm512_maskz_loadu_ps(lanes, src + col));
            _mm256_mask_storeu_epi16(dst + col, lanes, nk_m256i_from_m256bh_(narrowed));
        }
    }
}
|
|
1324
|
+
|
|
1325
|
+
NK_PUBLIC void nk_dots_symmetric_bf16_sapphireamx( //
|
|
1326
|
+
nk_bf16_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
|
|
1327
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
|
|
1328
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1329
|
+
|
|
1330
|
+
nk_size_t const stride_elements = stride / sizeof(nk_bf16_t);
|
|
1331
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
1332
|
+
|
|
1333
|
+
// Handle row slicing: compute rows [row_start, row_end)
|
|
1334
|
+
nk_size_t const row_end = (row_count == 0)
|
|
1335
|
+
? n_vectors
|
|
1336
|
+
: (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
|
|
1337
|
+
|
|
1338
|
+
// Round depth up to multiple of 96 (3 tiles × 32 elements)
|
|
1339
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
1340
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
1341
|
+
|
|
1342
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tiles[3];
|
|
1343
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_src_tiles[3];
|
|
1344
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tiles[3];
|
|
1345
|
+
nk_dots_bf16_state_sapphireamx_t state;
|
|
1346
|
+
|
|
1347
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
1348
|
+
|
|
1349
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
1350
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
1351
|
+
|
|
1352
|
+
for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
|
|
1353
|
+
nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
|
|
1354
|
+
|
|
1355
|
+
nk_dots_bf16_init_sapphireamx_(&state);
|
|
1356
|
+
|
|
1357
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
1358
|
+
nk_size_t const depth_base = depth_group_idx * 96;
|
|
1359
|
+
|
|
1360
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
1361
|
+
nk_size_t const depth_start = depth_base + tile_idx * 32;
|
|
1362
|
+
nk_size_t const valid_depth = (depth_start + 32 <= depth)
|
|
1363
|
+
? 32
|
|
1364
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
1365
|
+
|
|
1366
|
+
nk_dots_bf16_load_a_sapphireamx_( //
|
|
1367
|
+
&a_tiles[tile_idx], //
|
|
1368
|
+
vectors + row_tile * stride_elements + depth_start, //
|
|
1369
|
+
stride_elements, valid_rows, valid_depth);
|
|
1370
|
+
|
|
1371
|
+
if (row_tile == col_tile) {
|
|
1372
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
1373
|
+
}
|
|
1374
|
+
else {
|
|
1375
|
+
nk_dots_bf16_load_a_sapphireamx_( //
|
|
1376
|
+
&b_src_tiles[tile_idx], //
|
|
1377
|
+
vectors + col_tile * stride_elements + depth_start, //
|
|
1378
|
+
stride_elements, valid_cols, valid_depth);
|
|
1379
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
1380
|
+
}
|
|
1381
|
+
}
|
|
1382
|
+
|
|
1383
|
+
nk_dots_bf16_update_sapphireamx_( //
|
|
1384
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
1385
|
+
}
|
|
1386
|
+
|
|
1387
|
+
nk_dots_bf16_store_sapphireamx_( //
|
|
1388
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
1389
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
1390
|
+
}
|
|
1391
|
+
}
|
|
1392
|
+
}
|
|
1393
|
+
|
|
1394
|
+
#pragma endregion // Half Precision Floats
|
|
1395
|
+
|
|
1396
|
+
#pragma region Signed Integers
|
|
1397
|
+
|
|
1398
|
+
/**
 *  @brief  Number of bytes `nk_dots_pack_i8_sapphireamx` writes for a B matrix with
 *          `column_count` rows (output columns) of `depth` signed 8-bit values each.
 *
 *  Buffer layout, in order:
 *    1. `nk_dots_amx_packed_header_t`;
 *    2. one 1 KB tile (16 rows × 64 I8) per (16-row group, 64-deep depth slab), stored linearly
 *       with all depth slabs of a row group contiguous; the last slab is zero-padded;
 *    3. the `column_count % 16` leftover rows, row-major, `depth` bytes each;
 *    4. one 4-byte squared norm per row.
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_i8_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const rows_per_tile = 16;  // TMM rows
    nk_size_t const depth_per_tile = 64; // I8 values per 64-byte TMM row
    nk_size_t const bytes_per_tile = rows_per_tile * depth_per_tile * sizeof(nk_i8_t); // 1 KB

    nk_size_t const grouped_rows = column_count / rows_per_tile;
    nk_size_t const leftover_rows = column_count % rows_per_tile;
    nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, depth_per_tile);

    nk_size_t const header_bytes = sizeof(nk_dots_amx_packed_header_t);
    nk_size_t const tile_region_bytes = grouped_rows * depth_tiles * bytes_per_tile;
    nk_size_t const edge_region_bytes = leftover_rows * depth * sizeof(nk_i8_t);
    nk_size_t const norms_region_bytes = column_count * sizeof(nk_u32_t);

    return header_bytes + tile_region_bytes + edge_region_bytes + norms_region_bytes;
}
|
|
1421
|
+
|
|
1422
|
+
NK_PUBLIC void nk_dots_pack_i8_sapphireamx( //
|
|
1423
|
+
nk_i8_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
1424
|
+
nk_size_t b_stride, void *b_packed) {
|
|
1425
|
+
|
|
1426
|
+
// AMX I8 tile dimensions: 16 rows × 64 columns (1024 I8 elements = 1KB)
|
|
1427
|
+
nk_size_t const tmm_rows = 16;
|
|
1428
|
+
nk_size_t const tmm_cols = 64;
|
|
1429
|
+
nk_size_t const tile_elements = 1024;
|
|
1430
|
+
nk_size_t const tile_bytes = tile_elements * sizeof(nk_i8_t);
|
|
1431
|
+
|
|
1432
|
+
// Compute layout dimensions
|
|
1433
|
+
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
1434
|
+
nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
|
|
1435
|
+
nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
|
|
1436
|
+
nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
|
|
1437
|
+
|
|
1438
|
+
// Write header with layout metadata
|
|
1439
|
+
nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
|
|
1440
|
+
header->full_column_tiles = (nk_u32_t)column_tiles_count;
|
|
1441
|
+
header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
|
|
1442
|
+
header->column_remainder_count = (nk_u32_t)column_remainder_count;
|
|
1443
|
+
|
|
1444
|
+
// Compute memory region offsets
|
|
1445
|
+
nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
|
|
1446
|
+
nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
|
|
1447
|
+
header->column_edge_offset = (nk_u32_t)column_edge_offset;
|
|
1448
|
+
|
|
1449
|
+
// Pointers to packed data regions
|
|
1450
|
+
nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + tiles_offset);
|
|
1451
|
+
nk_i8_t *column_edge_ptr = (nk_i8_t *)((char *)b_packed + column_edge_offset);
|
|
1452
|
+
|
|
1453
|
+
// Zero-initialize all tiles (handles depth remainder padding)
|
|
1454
|
+
for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
|
|
1455
|
+
|
|
1456
|
+
// Pack tiles using LINEAR ordering: tile_index = column_tile × depth_tiles_count + depth_tile
|
|
1457
|
+
// This provides sequential memory access when streaming along depth dimension.
|
|
1458
|
+
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
1459
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1460
|
+
|
|
1461
|
+
// Linear tile index: all depth-tiles for one column-tile are contiguous
|
|
1462
|
+
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
1463
|
+
nk_i8_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
1464
|
+
|
|
1465
|
+
// Source coordinates in original B matrix
|
|
1466
|
+
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
1467
|
+
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
1468
|
+
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
1469
|
+
: (depth - src_column_start);
|
|
1470
|
+
|
|
1471
|
+
// Pack with quad-interleaving as required by TDPBSSD instruction.
|
|
1472
|
+
// AMX expects: [col0_row0, col1_row0, col2_row0, col3_row0, col0_row1, ...]
|
|
1473
|
+
// Formula: packed_idx = (column / 4) × 64 + row × 4 + (column % 4)
|
|
1474
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
1475
|
+
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
1476
|
+
nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
|
|
1477
|
+
nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
|
|
1478
|
+
tile_output[dst_idx] = b[src_idx];
|
|
1479
|
+
}
|
|
1480
|
+
}
|
|
1481
|
+
}
|
|
1482
|
+
}
|
|
1483
|
+
|
|
1484
|
+
// Pack column-remainder rows in simple row-major format (for AVX-512 fallback)
|
|
1485
|
+
if (column_remainder_count > 0) {
|
|
1486
|
+
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
1487
|
+
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
1488
|
+
for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
|
|
1489
|
+
column_edge_ptr[row_idx * depth + column_idx] =
|
|
1490
|
+
b[(remainder_start_row + row_idx) * b_stride + column_idx];
|
|
1491
|
+
}
|
|
1492
|
+
}
|
|
1493
|
+
}
|
|
1494
|
+
|
|
1495
|
+
// Compute and store per-column norms for angular/euclidean distance
|
|
1496
|
+
nk_size_t norms_offset = column_edge_offset +
|
|
1497
|
+
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_i8_t) : 0);
|
|
1498
|
+
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
1499
|
+
nk_u32_t *norms = (nk_u32_t *)((char *)b_packed + norms_offset);
|
|
1500
|
+
for (nk_size_t col = 0; col < column_count; col++) norms[col] = nk_dots_reduce_sumsq_i8_(b + col * b_stride, depth);
|
|
1501
|
+
}
|
|
1502
|
+
|
|
1503
|
+
NK_PUBLIC void nk_dots_packed_i8_sapphireamx( //
|
|
1504
|
+
nk_i8_t const *a, void const *b_packed, nk_i32_t *c, //
|
|
1505
|
+
nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
|
|
1506
|
+
nk_unused_(cols_count);
|
|
1507
|
+
|
|
1508
|
+
// Parse packed B header
|
|
1509
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
1510
|
+
nk_size_t const column_tiles_count = header->full_column_tiles;
|
|
1511
|
+
nk_size_t const depth_tiles_count = header->full_depth_tiles;
|
|
1512
|
+
nk_size_t const column_remainder_count = header->column_remainder_count;
|
|
1513
|
+
|
|
1514
|
+
// Packed B data regions
|
|
1515
|
+
nk_i8_t const *b_tiles_base = (nk_i8_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
|
|
1516
|
+
nk_i8_t const *col_edge_ptr = (nk_i8_t const *)((char const *)b_packed + header->column_edge_offset);
|
|
1517
|
+
|
|
1518
|
+
// Stride conversions
|
|
1519
|
+
nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_i32_t);
|
|
1520
|
+
|
|
1521
|
+
// Tile dimensions
|
|
1522
|
+
nk_size_t const tile_depth = 64; // depth elements per INT8 tile
|
|
1523
|
+
nk_size_t const tile_size = 1024; // bytes per packed tile
|
|
1524
|
+
nk_size_t const full_cols = column_tiles_count * 16;
|
|
1525
|
+
|
|
1526
|
+
// Block counts (32 × 32 output blocks = 2 × 2 tiles)
|
|
1527
|
+
nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
|
|
1528
|
+
nk_size_t const col_blocks_count = column_tiles_count / 2;
|
|
1529
|
+
|
|
1530
|
+
if (depth_tiles_count == 0) return;
|
|
1531
|
+
|
|
1532
|
+
// Tile buffers for A (only used for edge tiles)
|
|
1533
|
+
nk_dots_i8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
|
|
1534
|
+
nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
|
|
1535
|
+
|
|
1536
|
+
// Precompute: number of full depth-tiles (no masking needed)
|
|
1537
|
+
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
1538
|
+
nk_size_t const depth_remainder = depth % tile_depth;
|
|
1539
|
+
|
|
1540
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
1541
|
+
|
|
1542
|
+
// Process all 32 × 32 row × column blocks (including partial edge blocks)
|
|
1543
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
1544
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
1545
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
1546
|
+
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
1547
|
+
|
|
1548
|
+
// Process full column-blocks (pairs of 16-column tiles = 32 columns)
|
|
1549
|
+
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
1550
|
+
nk_size_t const col_block_start = column_block_idx * 32;
|
|
1551
|
+
|
|
1552
|
+
// B tile base indices (linear layout: col_tile × depth_tiles_count + depth_tile)
|
|
1553
|
+
nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
|
|
1554
|
+
nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
|
|
1555
|
+
|
|
1556
|
+
// Zero accumulators (TMM4-7 stay resident across entire depth loop)
|
|
1557
|
+
_tile_zero(4); // C[upper, left]
|
|
1558
|
+
_tile_zero(5); // C[upper, right]
|
|
1559
|
+
_tile_zero(6); // C[lower, left]
|
|
1560
|
+
_tile_zero(7); // C[lower, right]
|
|
1561
|
+
|
|
1562
|
+
// Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
|
|
1563
|
+
if (is_full_row_block && full_depth_tiles_count > 0) {
|
|
1564
|
+
// A row pointers for direct load
|
|
1565
|
+
nk_i8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
|
|
1566
|
+
nk_i8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
|
|
1567
|
+
|
|
1568
|
+
// B tile pointers
|
|
1569
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
|
|
1570
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
1571
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
|
|
1572
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
1573
|
+
|
|
1574
|
+
// Prologue: load first depth tile into TMM0-3
|
|
1575
|
+
_tile_loadd(0, a_upper_base, a_stride_bytes);
|
|
1576
|
+
_tile_loadd(1, a_lower_base, a_stride_bytes);
|
|
1577
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1578
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1579
|
+
|
|
1580
|
+
// Main loop: 2-deep software pipelining (compute current while loading next)
|
|
1581
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < full_depth_tiles_count - 1; depth_tile_idx++) {
|
|
1582
|
+
nk_size_t const next_depth_offset = (depth_tile_idx + 1) * tile_depth;
|
|
1583
|
+
|
|
1584
|
+
_tile_dpbssd(4, 0, 2);
|
|
1585
|
+
_tile_dpbssd(5, 0, 3);
|
|
1586
|
+
_tile_dpbssd(6, 1, 2);
|
|
1587
|
+
_tile_dpbssd(7, 1, 3);
|
|
1588
|
+
|
|
1589
|
+
_tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
|
|
1590
|
+
_tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
|
|
1591
|
+
b_tile_left = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
1592
|
+
(b_column_left_base + depth_tile_idx + 1) *
|
|
1593
|
+
tile_size);
|
|
1594
|
+
b_tile_right = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
|
|
1595
|
+
depth_tile_idx + 1) *
|
|
1596
|
+
tile_size);
|
|
1597
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1598
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1599
|
+
}
|
|
1600
|
+
|
|
1601
|
+
// Epilogue: final depth tile (no next to load)
|
|
1602
|
+
_tile_dpbssd(4, 0, 2);
|
|
1603
|
+
_tile_dpbssd(5, 0, 3);
|
|
1604
|
+
_tile_dpbssd(6, 1, 2);
|
|
1605
|
+
_tile_dpbssd(7, 1, 3);
|
|
1606
|
+
|
|
1607
|
+
// Handle partial depth-tile (if any) with buffered load
|
|
1608
|
+
if (depth_remainder > 0) {
|
|
1609
|
+
nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
|
|
1610
|
+
|
|
1611
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_bytes, 16,
|
|
1612
|
+
depth_remainder);
|
|
1613
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_bytes, 16,
|
|
1614
|
+
depth_remainder);
|
|
1615
|
+
|
|
1616
|
+
b_tile_left = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
1617
|
+
full_depth_tiles_count) *
|
|
1618
|
+
tile_size);
|
|
1619
|
+
b_tile_right = (nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
|
|
1620
|
+
full_depth_tiles_count) *
|
|
1621
|
+
tile_size);
|
|
1622
|
+
|
|
1623
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1624
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1625
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1626
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1627
|
+
|
|
1628
|
+
_tile_dpbssd(4, 0, 2);
|
|
1629
|
+
_tile_dpbssd(5, 0, 3);
|
|
1630
|
+
_tile_dpbssd(6, 1, 2);
|
|
1631
|
+
_tile_dpbssd(7, 1, 3);
|
|
1632
|
+
}
|
|
1633
|
+
}
|
|
1634
|
+
// Full row-block but only partial depth tile (depth < tile_depth)
|
|
1635
|
+
else if (is_full_row_block) {
|
|
1636
|
+
nk_i8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
|
|
1637
|
+
nk_i8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
|
|
1638
|
+
|
|
1639
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_bytes, 16, depth_remainder);
|
|
1640
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_bytes, 16, depth_remainder);
|
|
1641
|
+
|
|
1642
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
|
|
1643
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
1644
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
|
|
1645
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
1646
|
+
|
|
1647
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1648
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1649
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1650
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1651
|
+
|
|
1652
|
+
_tile_dpbssd(4, 0, 2);
|
|
1653
|
+
_tile_dpbssd(5, 0, 3);
|
|
1654
|
+
_tile_dpbssd(6, 1, 2);
|
|
1655
|
+
_tile_dpbssd(7, 1, 3);
|
|
1656
|
+
}
|
|
1657
|
+
// Slow path: edge row-block → always use buffered load with masking
|
|
1658
|
+
else {
|
|
1659
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1660
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1661
|
+
|
|
1662
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1663
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1664
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
|
|
1665
|
+
: depth_remainder;
|
|
1666
|
+
|
|
1667
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
1668
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
1669
|
+
if (rows_in_lower_tile > 0) {
|
|
1670
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
|
|
1671
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
1672
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
1673
|
+
}
|
|
1674
|
+
|
|
1675
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
|
|
1676
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
1677
|
+
(b_column_left_base + depth_tile_idx) * tile_size);
|
|
1678
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
|
|
1679
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
1680
|
+
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
1681
|
+
|
|
1682
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1683
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1684
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
1685
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
1686
|
+
|
|
1687
|
+
_tile_dpbssd(4, 0, 2);
|
|
1688
|
+
_tile_dpbssd(5, 0, 3);
|
|
1689
|
+
_tile_dpbssd(6, 1, 2);
|
|
1690
|
+
_tile_dpbssd(7, 1, 3);
|
|
1691
|
+
}
|
|
1692
|
+
}
|
|
1693
|
+
|
|
1694
|
+
// Store accumulators to output (once per output block, not per depth tile)
|
|
1695
|
+
if (is_full_row_block) {
|
|
1696
|
+
nk_i32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
|
|
1697
|
+
_tile_stored(4, c_block, c_stride_bytes);
|
|
1698
|
+
_tile_stored(5, c_block + 16, c_stride_bytes);
|
|
1699
|
+
_tile_stored(6, (nk_i32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
|
|
1700
|
+
_tile_stored(7, (nk_i32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
|
|
1701
|
+
}
|
|
1702
|
+
else {
|
|
1703
|
+
// Slow path: edge row-block needs masked output
|
|
1704
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
1705
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
1706
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
1707
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
1708
|
+
nk_dots_i8_output2x2_sapphireamx_(&c_accum_buffer,
|
|
1709
|
+
c + row_block_start * c_stride_elements + col_block_start,
|
|
1710
|
+
c_stride_elements, valid_rows_count, 32);
|
|
1711
|
+
}
|
|
1712
|
+
}
|
|
1713
|
+
|
|
1714
|
+
// Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
|
|
1715
|
+
if (column_tiles_count % 2 == 1) {
|
|
1716
|
+
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
1717
|
+
nk_size_t const col_start = column_tile_idx * 16;
|
|
1718
|
+
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
1719
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1720
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1721
|
+
|
|
1722
|
+
// Use 1 × 2 blocking for single column-tile (2 row-tiles × 1 column-tile)
|
|
1723
|
+
nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
1724
|
+
|
|
1725
|
+
_tile_zero(4);
|
|
1726
|
+
_tile_zero(6);
|
|
1727
|
+
|
|
1728
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1729
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1730
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
1731
|
+
|
|
1732
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
1733
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
1734
|
+
if (rows_in_lower_tile > 0) {
|
|
1735
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
|
|
1736
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
1737
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
1738
|
+
}
|
|
1739
|
+
|
|
1740
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile =
|
|
1741
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
1742
|
+
(b_column_base + depth_tile_idx) * tile_size);
|
|
1743
|
+
|
|
1744
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1745
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1746
|
+
_tile_loadd(2, b_tile->data, 64);
|
|
1747
|
+
|
|
1748
|
+
_tile_dpbssd(4, 0, 2);
|
|
1749
|
+
_tile_dpbssd(6, 1, 2);
|
|
1750
|
+
}
|
|
1751
|
+
|
|
1752
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
1753
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
1754
|
+
|
|
1755
|
+
nk_dots_i8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
|
|
1756
|
+
c_stride_elements, rows_in_upper_tile, 16);
|
|
1757
|
+
if (rows_in_lower_tile > 0) {
|
|
1758
|
+
nk_dots_i8_store_sapphireamx_(&c_lower_state,
|
|
1759
|
+
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
1760
|
+
c_stride_elements, rows_in_lower_tile, 16);
|
|
1761
|
+
}
|
|
1762
|
+
}
|
|
1763
|
+
|
|
1764
|
+
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
1765
|
+
if (column_remainder_count > 0) {
|
|
1766
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
1767
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
1768
|
+
|
|
1769
|
+
nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
1770
|
+
nk_dots_i8_a16x64_sapphireamx_t b_as_a;
|
|
1771
|
+
nk_dots_i8_b64x16_sapphireamx_t b_tile;
|
|
1772
|
+
|
|
1773
|
+
_tile_zero(4);
|
|
1774
|
+
_tile_zero(6);
|
|
1775
|
+
|
|
1776
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
1777
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
1778
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
1779
|
+
|
|
1780
|
+
// Load A tiles
|
|
1781
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
1782
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
1783
|
+
if (rows_in_lower_tile > 0) {
|
|
1784
|
+
nk_dots_i8_load_a_sapphireamx_(&a_tile_lower,
|
|
1785
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
1786
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
1787
|
+
}
|
|
1788
|
+
|
|
1789
|
+
// Load B edge data (row-major: b_edge[row × depth + column]) and pack into B tile
|
|
1790
|
+
// Each "row" in edge data corresponds to one output column
|
|
1791
|
+
nk_dots_i8_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
1792
|
+
valid_depth);
|
|
1793
|
+
nk_dots_pack_i8_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
1794
|
+
|
|
1795
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
1796
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
1797
|
+
_tile_loadd(2, b_tile.data, 64);
|
|
1798
|
+
|
|
1799
|
+
_tile_dpbssd(4, 0, 2);
|
|
1800
|
+
_tile_dpbssd(6, 1, 2);
|
|
1801
|
+
}
|
|
1802
|
+
|
|
1803
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
1804
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
1805
|
+
|
|
1806
|
+
nk_dots_i8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
|
|
1807
|
+
c_stride_elements, rows_in_upper_tile, column_remainder_count);
|
|
1808
|
+
if (rows_in_lower_tile > 0) {
|
|
1809
|
+
nk_dots_i8_store_sapphireamx_(&c_lower_state,
|
|
1810
|
+
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
1811
|
+
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
1812
|
+
}
|
|
1813
|
+
}
|
|
1814
|
+
}
|
|
1815
|
+
|
|
1816
|
+
_tile_release();
|
|
1817
|
+
}
|
|
1818
|
+
|
|
1819
|
+
/**
 * Compacts an I32 dot-product matrix, in place, into I8 normalized scores:
 *     out[i][j] = saturate_i8(127 * c[i][j] * rsqrt(a_squared_norms[i]) * rsqrt(b_squared_norms[j]))
 * Each reciprocal square root is a hardware estimate refined by one Newton-Raphson step. A zero b norm,
 * or a non-positive a norm, contributes a factor of 0. The float -> int conversion uses the current MXCSR
 * rounding mode (round-to-nearest by default), followed by a signed-saturating narrow to I8.
 *
 * The I8 result is written densely (row `i` at byte offset `i * column_count`) over the start of the
 * same buffer. Row by row this is safe: every I8 byte lands at or before I32 bytes already consumed.
 *
 * @param c               In: `row_count` rows of I32, `c_stride` bytes apart. Out: dense I8 rows.
 * @param row_count       Number of rows.
 * @param column_count    Number of columns.
 * @param c_stride        Byte stride of the I32 input rows (presumably >= 4 * column_count and a multiple of 4).
 * @param a_squared_norms Per-row squared norms (`row_count` entries).
 * @param b_squared_norms Per-column squared norms (`column_count` entries).
 */
NK_PUBLIC void nk_dots_compact_i8_sapphireamx( //
    void *c, nk_size_t row_count, nk_size_t column_count, nk_size_t c_stride, nk_i32_t const *a_squared_norms,
    nk_i32_t const *b_squared_norms) {

    nk_size_t const c_stride_i32 = c_stride / sizeof(nk_i32_t);
    nk_i32_t const *c_i32 = (nk_i32_t const *)c;
    nk_i8_t *c_i8 = (nk_i8_t *)c;

    // 1/sqrt(b) is recomputed per row instead of being cached inside `c`. Every byte past the I8 output may
    // still hold I32 rows that have not been read yet, so no part of the buffer is free scratch space until
    // compaction has finished. Caching there would overwrite pending dot products with floats.
    __m512 const half_vec = _mm512_set1_ps(0.5f);
    __m512 const three_halves_vec = _mm512_set1_ps(1.5f);
    __m512 const scale_vec = _mm512_set1_ps(127.0f);
    __m512i const zero_i32 = _mm512_setzero_si512();

    for (nk_size_t row_idx = 0; row_idx < row_count; row_idx++) {
        nk_i32_t const *src_row = c_i32 + row_idx * c_stride_i32;
        nk_i8_t *dst_row = c_i8 + row_idx * column_count;

        // Scalar 1/sqrt(a) for this row (estimate + one Newton-Raphson step); 0 for a non-positive norm
        nk_f32_t const a_norm_f32 = (nk_f32_t)a_squared_norms[row_idx];
        nk_f32_t a_rsqrt_val = 0.0f;
        if (a_norm_f32 > 0.0f) {
            __m128 a_vec = _mm_set_ss(a_norm_f32);
            __m128 rsqrt_s = _mm_rsqrt_ss(a_vec);
            rsqrt_s = _mm_mul_ss(
                rsqrt_s, _mm_sub_ss(_mm_set_ss(1.5f),
                                    _mm_mul_ss(_mm_set_ss(0.5f), _mm_mul_ss(a_vec, _mm_mul_ss(rsqrt_s, rsqrt_s)))));
            a_rsqrt_val = _mm_cvtss_f32(rsqrt_s);
        }
        __m512 const row_scale = _mm512_mul_ps(_mm512_set1_ps(a_rsqrt_val), scale_vec);

        // 16 columns per step; the final partial step runs the same code under a lane mask
        for (nk_size_t column_idx = 0; column_idx < column_count; column_idx += 16) {
            nk_size_t const remaining = column_count - column_idx;
            __mmask16 const lane_mask = (remaining >= 16) ? (__mmask16)0xFFFF
                                                          : (__mmask16)((1u << remaining) - 1);

            // 1/sqrt(b) for these columns with one Newton-Raphson step; zero norms map to 0 instead of inf/NaN
            __m512i const b_norms_i32 = _mm512_maskz_loadu_epi32(lane_mask, b_squared_norms + column_idx);
            __m512 const b_norms_f32 = _mm512_cvtepi32_ps(b_norms_i32);
            __m512 b_rsqrt_vec = _mm512_rsqrt14_ps(b_norms_f32);
            b_rsqrt_vec = _mm512_mul_ps(
                b_rsqrt_vec,
                _mm512_sub_ps(three_halves_vec,
                              _mm512_mul_ps(half_vec,
                                            _mm512_mul_ps(b_norms_f32, _mm512_mul_ps(b_rsqrt_vec, b_rsqrt_vec)))));
            __mmask16 const nonzero_mask = _mm512_cmpneq_epi32_mask(b_norms_i32, zero_i32);
            b_rsqrt_vec = _mm512_maskz_mov_ps(nonzero_mask, b_rsqrt_vec);

            __m512i const c_vals = _mm512_maskz_loadu_epi32(lane_mask, src_row + column_idx);
            __m512 const c_f32 = _mm512_cvtepi32_ps(c_vals);
            __m512 const normalized = _mm512_mul_ps(_mm512_mul_ps(c_f32, row_scale), b_rsqrt_vec);
            __m512i const result_i32 = _mm512_cvtps_epi32(normalized);
            // Saturating pack I32 -> I8 (16 values -> 16 bytes in the low 128 bits)
            __m128i const result_i8 = _mm512_cvtsepi32_epi8(result_i32);
            _mm_mask_storeu_epi8(dst_row + column_idx, lane_mask, result_i8);
        }
    }
}
|
|
1912
|
+
|
|
1913
|
+
NK_PUBLIC void nk_dots_symmetric_i8_sapphireamx( //
|
|
1914
|
+
nk_i8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
|
|
1915
|
+
nk_size_t stride, nk_i32_t *result, nk_size_t result_stride, //
|
|
1916
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
1917
|
+
|
|
1918
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_i32_t);
|
|
1919
|
+
|
|
1920
|
+
// Handle row slicing: compute rows [row_start, row_end)
|
|
1921
|
+
nk_size_t const row_end = (row_count == 0)
|
|
1922
|
+
? n_vectors
|
|
1923
|
+
: (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
|
|
1924
|
+
|
|
1925
|
+
// Round depth up to multiple of 192 (3 tiles × 64 elements)
|
|
1926
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
|
|
1927
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
1928
|
+
|
|
1929
|
+
nk_dots_i8_a16x64_sapphireamx_t a_tiles[3];
|
|
1930
|
+
nk_dots_i8_a16x64_sapphireamx_t b_src_tiles[3];
|
|
1931
|
+
nk_dots_i8_b64x16_sapphireamx_t b_tiles[3];
|
|
1932
|
+
nk_dots_i8_state_sapphireamx_t state;
|
|
1933
|
+
|
|
1934
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
1935
|
+
|
|
1936
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
1937
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
1938
|
+
|
|
1939
|
+
for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
|
|
1940
|
+
nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
|
|
1941
|
+
|
|
1942
|
+
nk_dots_i8_init_sapphireamx_(&state);
|
|
1943
|
+
|
|
1944
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
1945
|
+
nk_size_t const depth_base = depth_group_idx * 192;
|
|
1946
|
+
|
|
1947
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
1948
|
+
nk_size_t const depth_start = depth_base + tile_idx * 64;
|
|
1949
|
+
nk_size_t const valid_depth = (depth_start + 64 <= depth)
|
|
1950
|
+
? 64
|
|
1951
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
1952
|
+
|
|
1953
|
+
nk_dots_i8_load_a_sapphireamx_( //
|
|
1954
|
+
&a_tiles[tile_idx], //
|
|
1955
|
+
vectors + row_tile * stride + depth_start, //
|
|
1956
|
+
stride, valid_rows, valid_depth);
|
|
1957
|
+
|
|
1958
|
+
if (row_tile == col_tile) {
|
|
1959
|
+
nk_dots_pack_i8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
1960
|
+
}
|
|
1961
|
+
else {
|
|
1962
|
+
nk_dots_i8_load_a_sapphireamx_( //
|
|
1963
|
+
&b_src_tiles[tile_idx], //
|
|
1964
|
+
vectors + col_tile * stride + depth_start, //
|
|
1965
|
+
stride, valid_cols, valid_depth);
|
|
1966
|
+
nk_dots_pack_i8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
1967
|
+
}
|
|
1968
|
+
}
|
|
1969
|
+
|
|
1970
|
+
nk_dots_i8_update_sapphireamx_( //
|
|
1971
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
1972
|
+
}
|
|
1973
|
+
|
|
1974
|
+
nk_dots_i8_store_sapphireamx_( //
|
|
1975
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
1976
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
1977
|
+
}
|
|
1978
|
+
}
|
|
1979
|
+
}
|
|
1980
|
+
|
|
1981
|
+
#pragma endregion // Signed Integers
|
|
1982
|
+
|
|
1983
|
+
#pragma region Unsigned Integers
|
|
1984
|
+
|
|
1985
|
+
/**
 * Bytes needed by `nk_dots_pack_u8_sapphireamx` for a `column_count` x `depth` U8 matrix.
 * The U8 packer writes exactly the I8 layout (only the TDPBUUD interpretation of the bytes differs),
 * so the size is delegated to the I8 variant.
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_u8_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const packed_bytes = nk_dots_packed_size_i8_sapphireamx(column_count, depth);
    return packed_bytes;
}
|
|
1989
|
+
|
|
1990
|
+
NK_PUBLIC void nk_dots_pack_u8_sapphireamx( //
|
|
1991
|
+
nk_u8_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
1992
|
+
nk_size_t b_stride, void *b_packed) {
|
|
1993
|
+
|
|
1994
|
+
nk_size_t const tmm_rows = 16;
|
|
1995
|
+
nk_size_t const tmm_cols = 64;
|
|
1996
|
+
nk_size_t const tile_elements = 1024;
|
|
1997
|
+
nk_size_t const tile_bytes = tile_elements * sizeof(nk_u8_t);
|
|
1998
|
+
|
|
1999
|
+
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
2000
|
+
nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
|
|
2001
|
+
nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
|
|
2002
|
+
nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
|
|
2003
|
+
|
|
2004
|
+
nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
|
|
2005
|
+
header->full_column_tiles = (nk_u32_t)column_tiles_count;
|
|
2006
|
+
header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
|
|
2007
|
+
header->column_remainder_count = (nk_u32_t)column_remainder_count;
|
|
2008
|
+
|
|
2009
|
+
nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
|
|
2010
|
+
nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
|
|
2011
|
+
header->column_edge_offset = (nk_u32_t)column_edge_offset;
|
|
2012
|
+
|
|
2013
|
+
nk_u8_t *tiles_ptr = (nk_u8_t *)((char *)b_packed + tiles_offset);
|
|
2014
|
+
nk_u8_t *column_edge_ptr = (nk_u8_t *)((char *)b_packed + column_edge_offset);
|
|
2015
|
+
|
|
2016
|
+
for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
|
|
2017
|
+
|
|
2018
|
+
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
2019
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2020
|
+
|
|
2021
|
+
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
2022
|
+
nk_u8_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
2023
|
+
|
|
2024
|
+
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
2025
|
+
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
2026
|
+
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
2027
|
+
: (depth - src_column_start);
|
|
2028
|
+
|
|
2029
|
+
// Pack with quad-interleaving as required by TDPBUUD instruction.
|
|
2030
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
2031
|
+
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
2032
|
+
nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
|
|
2033
|
+
nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
|
|
2034
|
+
tile_output[dst_idx] = b[src_idx];
|
|
2035
|
+
}
|
|
2036
|
+
}
|
|
2037
|
+
}
|
|
2038
|
+
}
|
|
2039
|
+
|
|
2040
|
+
if (column_remainder_count > 0) {
|
|
2041
|
+
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
2042
|
+
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
2043
|
+
for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
|
|
2044
|
+
column_edge_ptr[row_idx * depth + column_idx] =
|
|
2045
|
+
b[(remainder_start_row + row_idx) * b_stride + column_idx];
|
|
2046
|
+
}
|
|
2047
|
+
}
|
|
2048
|
+
}
|
|
2049
|
+
|
|
2050
|
+
// Compute and store per-column norms for angular/euclidean distance
|
|
2051
|
+
nk_size_t norms_offset = column_edge_offset +
|
|
2052
|
+
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_u8_t) : 0);
|
|
2053
|
+
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
2054
|
+
nk_u32_t *norms = (nk_u32_t *)((char *)b_packed + norms_offset);
|
|
2055
|
+
for (nk_size_t col = 0; col < column_count; col++) norms[col] = nk_dots_reduce_sumsq_u8_(b + col * b_stride, depth);
|
|
2056
|
+
}
|
|
2057
|
+
|
|
2058
|
+
NK_PUBLIC void nk_dots_packed_u8_sapphireamx( //
|
|
2059
|
+
nk_u8_t const *a, void const *b_packed, nk_u32_t *c, //
|
|
2060
|
+
nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
|
|
2061
|
+
nk_unused_(cols_count);
|
|
2062
|
+
|
|
2063
|
+
// Parse packed B header
|
|
2064
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
2065
|
+
nk_size_t const column_tiles_count = header->full_column_tiles;
|
|
2066
|
+
nk_size_t const depth_tiles_count = header->full_depth_tiles;
|
|
2067
|
+
nk_size_t const column_remainder_count = header->column_remainder_count;
|
|
2068
|
+
|
|
2069
|
+
// Packed B data regions
|
|
2070
|
+
nk_u8_t const *b_tiles_base = (nk_u8_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
|
|
2071
|
+
nk_u8_t const *col_edge_ptr = (nk_u8_t const *)((char const *)b_packed + header->column_edge_offset);
|
|
2072
|
+
|
|
2073
|
+
// Stride conversions
|
|
2074
|
+
nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_u32_t);
|
|
2075
|
+
|
|
2076
|
+
// Tile dimensions
|
|
2077
|
+
nk_size_t const tile_depth = 64; // depth elements per U8 tile
|
|
2078
|
+
nk_size_t const tile_size = 1024; // bytes per packed tile
|
|
2079
|
+
nk_size_t const full_cols = column_tiles_count * 16;
|
|
2080
|
+
|
|
2081
|
+
// Block counts (32 × 32 output blocks = 2 × 2 tiles)
|
|
2082
|
+
nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
|
|
2083
|
+
nk_size_t const col_blocks_count = column_tiles_count / 2;
|
|
2084
|
+
|
|
2085
|
+
if (depth_tiles_count == 0) return;
|
|
2086
|
+
|
|
2087
|
+
// Tile buffers for A (only used for edge tiles)
|
|
2088
|
+
nk_dots_u8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
|
|
2089
|
+
nk_dots_u8_state2x2_sapphireamx_t c_accum_buffer;
|
|
2090
|
+
|
|
2091
|
+
// Precompute: number of full depth-tiles
|
|
2092
|
+
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
2093
|
+
nk_size_t const depth_remainder = depth % tile_depth;
|
|
2094
|
+
|
|
2095
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
2096
|
+
|
|
2097
|
+
// Process all 32 × 32 row × column blocks
|
|
2098
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
2099
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
2100
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
2101
|
+
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
2102
|
+
|
|
2103
|
+
// Process full column-blocks (pairs of 16-column tiles = 32 columns)
|
|
2104
|
+
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
2105
|
+
nk_size_t const col_block_start = column_block_idx * 32;
|
|
2106
|
+
|
|
2107
|
+
// B tile base indices
|
|
2108
|
+
nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
|
|
2109
|
+
nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
|
|
2110
|
+
|
|
2111
|
+
// Zero accumulators (TMM4-7 stay resident across entire depth loop)
|
|
2112
|
+
_tile_zero(4);
|
|
2113
|
+
_tile_zero(5);
|
|
2114
|
+
_tile_zero(6);
|
|
2115
|
+
_tile_zero(7);
|
|
2116
|
+
|
|
2117
|
+
// Fast path: full row-block with full depth-tiles → direct A load with 2-deep pipelining
|
|
2118
|
+
if (is_full_row_block && full_depth_tiles_count > 0) {
|
|
2119
|
+
nk_u8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
|
|
2120
|
+
nk_u8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
|
|
2121
|
+
|
|
2122
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
|
|
2123
|
+
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
2124
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_right =
|
|
2125
|
+
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
2126
|
+
|
|
2127
|
+
// Prologue: load first depth tile into TMM0-3
|
|
2128
|
+
_tile_loadd(0, a_upper_base, a_stride_bytes);
|
|
2129
|
+
_tile_loadd(1, a_lower_base, a_stride_bytes);
|
|
2130
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
2131
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
2132
|
+
|
|
2133
|
+
// Main loop: 2-deep software pipelining (compute current while loading next)
|
|
2134
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < full_depth_tiles_count - 1; depth_tile_idx++) {
|
|
2135
|
+
nk_size_t const next_depth_offset = (depth_tile_idx + 1) * tile_depth;
|
|
2136
|
+
|
|
2137
|
+
_tile_dpbuud(4, 0, 2);
|
|
2138
|
+
_tile_dpbuud(5, 0, 3);
|
|
2139
|
+
_tile_dpbuud(6, 1, 2);
|
|
2140
|
+
_tile_dpbuud(7, 1, 3);
|
|
2141
|
+
|
|
2142
|
+
_tile_loadd(0, a_upper_base + next_depth_offset, a_stride_bytes);
|
|
2143
|
+
_tile_loadd(1, a_lower_base + next_depth_offset, a_stride_bytes);
|
|
2144
|
+
b_tile_left = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
2145
|
+
(b_column_left_base + depth_tile_idx + 1) *
|
|
2146
|
+
tile_size);
|
|
2147
|
+
b_tile_right = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
|
|
2148
|
+
depth_tile_idx + 1) *
|
|
2149
|
+
tile_size);
|
|
2150
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
2151
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
2152
|
+
}
|
|
2153
|
+
|
|
2154
|
+
// Epilogue: final depth tile (no next to load)
|
|
2155
|
+
_tile_dpbuud(4, 0, 2);
|
|
2156
|
+
_tile_dpbuud(5, 0, 3);
|
|
2157
|
+
_tile_dpbuud(6, 1, 2);
|
|
2158
|
+
_tile_dpbuud(7, 1, 3);
|
|
2159
|
+
|
|
2160
|
+
// Handle partial depth-tile (if any) with buffered load
|
|
2161
|
+
if (depth_remainder > 0) {
|
|
2162
|
+
nk_size_t const depth_offset = full_depth_tiles_count * tile_depth;
|
|
2163
|
+
|
|
2164
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a_upper_base + depth_offset, a_stride_bytes, 16,
|
|
2165
|
+
depth_remainder);
|
|
2166
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_lower, a_lower_base + depth_offset, a_stride_bytes, 16,
|
|
2167
|
+
depth_remainder);
|
|
2168
|
+
|
|
2169
|
+
b_tile_left = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_left_base +
|
|
2170
|
+
full_depth_tiles_count) *
|
|
2171
|
+
tile_size);
|
|
2172
|
+
b_tile_right = (nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + (b_column_right_base +
|
|
2173
|
+
full_depth_tiles_count) *
|
|
2174
|
+
tile_size);
|
|
2175
|
+
|
|
2176
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2177
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2178
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
2179
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
2180
|
+
|
|
2181
|
+
_tile_dpbuud(4, 0, 2);
|
|
2182
|
+
_tile_dpbuud(5, 0, 3);
|
|
2183
|
+
_tile_dpbuud(6, 1, 2);
|
|
2184
|
+
_tile_dpbuud(7, 1, 3);
|
|
2185
|
+
}
|
|
2186
|
+
}
|
|
2187
|
+
// Full row-block but only partial depth tile (depth < tile_depth)
|
|
2188
|
+
else if (is_full_row_block) {
|
|
2189
|
+
nk_u8_t const *a_upper_base = a + row_block_start * a_stride_bytes;
|
|
2190
|
+
nk_u8_t const *a_lower_base = a + (row_block_start + 16) * a_stride_bytes;
|
|
2191
|
+
|
|
2192
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a_upper_base, a_stride_bytes, 16, depth_remainder);
|
|
2193
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_lower, a_lower_base, a_stride_bytes, 16, depth_remainder);
|
|
2194
|
+
|
|
2195
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
|
|
2196
|
+
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_left_base * tile_size);
|
|
2197
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_right =
|
|
2198
|
+
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base + b_column_right_base * tile_size);
|
|
2199
|
+
|
|
2200
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2201
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2202
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
2203
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
2204
|
+
|
|
2205
|
+
_tile_dpbuud(4, 0, 2);
|
|
2206
|
+
_tile_dpbuud(5, 0, 3);
|
|
2207
|
+
_tile_dpbuud(6, 1, 2);
|
|
2208
|
+
_tile_dpbuud(7, 1, 3);
|
|
2209
|
+
}
|
|
2210
|
+
// Slow path: edge row-block → always use buffered load
|
|
2211
|
+
else {
|
|
2212
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2213
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2214
|
+
|
|
2215
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2216
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2217
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth
|
|
2218
|
+
: depth_remainder;
|
|
2219
|
+
|
|
2220
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2221
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2222
|
+
if (rows_in_lower_tile > 0) {
|
|
2223
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
|
|
2224
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2225
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2226
|
+
}
|
|
2227
|
+
|
|
2228
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_left =
|
|
2229
|
+
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
2230
|
+
(b_column_left_base + depth_tile_idx) * tile_size);
|
|
2231
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile_right =
|
|
2232
|
+
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
2233
|
+
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
2234
|
+
|
|
2235
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2236
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2237
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
2238
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
2239
|
+
|
|
2240
|
+
_tile_dpbuud(4, 0, 2);
|
|
2241
|
+
_tile_dpbuud(5, 0, 3);
|
|
2242
|
+
_tile_dpbuud(6, 1, 2);
|
|
2243
|
+
_tile_dpbuud(7, 1, 3);
|
|
2244
|
+
}
|
|
2245
|
+
}
|
|
2246
|
+
|
|
2247
|
+
// Store accumulators to output (once per output block, not per depth tile)
|
|
2248
|
+
if (is_full_row_block) {
|
|
2249
|
+
nk_u32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
|
|
2250
|
+
_tile_stored(4, c_block, c_stride_bytes);
|
|
2251
|
+
_tile_stored(5, c_block + 16, c_stride_bytes);
|
|
2252
|
+
_tile_stored(6, (nk_u32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
|
|
2253
|
+
_tile_stored(7, (nk_u32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
|
|
2254
|
+
}
|
|
2255
|
+
else {
|
|
2256
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
2257
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
2258
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
2259
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
2260
|
+
nk_dots_u8_output2x2_sapphireamx_(&c_accum_buffer,
|
|
2261
|
+
c + row_block_start * c_stride_elements + col_block_start,
|
|
2262
|
+
c_stride_elements, valid_rows_count, 32);
|
|
2263
|
+
}
|
|
2264
|
+
}
|
|
2265
|
+
|
|
2266
|
+
// Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
|
|
2267
|
+
if (column_tiles_count % 2 == 1) {
|
|
2268
|
+
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
2269
|
+
nk_size_t const col_start = column_tile_idx * 16;
|
|
2270
|
+
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
2271
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2272
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2273
|
+
|
|
2274
|
+
nk_dots_u8_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
2275
|
+
|
|
2276
|
+
_tile_zero(4);
|
|
2277
|
+
_tile_zero(6);
|
|
2278
|
+
|
|
2279
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2280
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2281
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2282
|
+
|
|
2283
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2284
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2285
|
+
if (rows_in_lower_tile > 0) {
|
|
2286
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
|
|
2287
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2288
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2289
|
+
}
|
|
2290
|
+
|
|
2291
|
+
nk_dots_u8_b64x16_sapphireamx_t const *b_tile =
|
|
2292
|
+
(nk_dots_u8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
2293
|
+
(b_column_base + depth_tile_idx) * tile_size);
|
|
2294
|
+
|
|
2295
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2296
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2297
|
+
_tile_loadd(2, b_tile->data, 64);
|
|
2298
|
+
|
|
2299
|
+
_tile_dpbuud(4, 0, 2);
|
|
2300
|
+
_tile_dpbuud(6, 1, 2);
|
|
2301
|
+
}
|
|
2302
|
+
|
|
2303
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
2304
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
2305
|
+
|
|
2306
|
+
nk_dots_u8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
|
|
2307
|
+
c_stride_elements, rows_in_upper_tile, 16);
|
|
2308
|
+
if (rows_in_lower_tile > 0) {
|
|
2309
|
+
nk_dots_u8_store_sapphireamx_(&c_lower_state,
|
|
2310
|
+
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
2311
|
+
c_stride_elements, rows_in_lower_tile, 16);
|
|
2312
|
+
}
|
|
2313
|
+
}
|
|
2314
|
+
|
|
2315
|
+
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
2316
|
+
if (column_remainder_count > 0) {
|
|
2317
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2318
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2319
|
+
|
|
2320
|
+
nk_dots_u8_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
2321
|
+
nk_dots_u8_a16x64_sapphireamx_t b_as_a;
|
|
2322
|
+
nk_dots_u8_b64x16_sapphireamx_t b_tile;
|
|
2323
|
+
|
|
2324
|
+
_tile_zero(4);
|
|
2325
|
+
_tile_zero(6);
|
|
2326
|
+
|
|
2327
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2328
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2329
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2330
|
+
|
|
2331
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2332
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2333
|
+
if (rows_in_lower_tile > 0) {
|
|
2334
|
+
nk_dots_u8_load_a_sapphireamx_(&a_tile_lower,
|
|
2335
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2336
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2337
|
+
}
|
|
2338
|
+
|
|
2339
|
+
nk_dots_u8_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
2340
|
+
valid_depth);
|
|
2341
|
+
nk_dots_pack_u8_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
2342
|
+
|
|
2343
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2344
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2345
|
+
_tile_loadd(2, b_tile.data, 64);
|
|
2346
|
+
|
|
2347
|
+
_tile_dpbuud(4, 0, 2);
|
|
2348
|
+
_tile_dpbuud(6, 1, 2);
|
|
2349
|
+
}
|
|
2350
|
+
|
|
2351
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
2352
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
2353
|
+
|
|
2354
|
+
nk_dots_u8_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
|
|
2355
|
+
c_stride_elements, rows_in_upper_tile, column_remainder_count);
|
|
2356
|
+
if (rows_in_lower_tile > 0) {
|
|
2357
|
+
nk_dots_u8_store_sapphireamx_(&c_lower_state,
|
|
2358
|
+
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
2359
|
+
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
2360
|
+
}
|
|
2361
|
+
}
|
|
2362
|
+
}
|
|
2363
|
+
|
|
2364
|
+
_tile_release();
|
|
2365
|
+
}
|
|
2366
|
+
|
|
2367
|
+
NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx( //
|
|
2368
|
+
nk_u8_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
|
|
2369
|
+
nk_size_t stride, nk_u32_t *result, nk_size_t result_stride, //
|
|
2370
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
2371
|
+
|
|
2372
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_u32_t);
|
|
2373
|
+
|
|
2374
|
+
// Handle row slicing: compute rows [row_start, row_end)
|
|
2375
|
+
nk_size_t const row_end = (row_count == 0)
|
|
2376
|
+
? n_vectors
|
|
2377
|
+
: (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
|
|
2378
|
+
|
|
2379
|
+
// Round depth up to multiple of 192 (3 tiles × 64 elements)
|
|
2380
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
|
|
2381
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
2382
|
+
|
|
2383
|
+
nk_dots_u8_a16x64_sapphireamx_t a_tiles[3];
|
|
2384
|
+
nk_dots_u8_a16x64_sapphireamx_t b_src_tiles[3];
|
|
2385
|
+
nk_dots_u8_b64x16_sapphireamx_t b_tiles[3];
|
|
2386
|
+
nk_dots_u8_state_sapphireamx_t state;
|
|
2387
|
+
|
|
2388
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
2389
|
+
|
|
2390
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
2391
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
2392
|
+
|
|
2393
|
+
for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
|
|
2394
|
+
nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
|
|
2395
|
+
|
|
2396
|
+
nk_dots_u8_init_sapphireamx_(&state);
|
|
2397
|
+
|
|
2398
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
2399
|
+
nk_size_t const depth_base = depth_group_idx * 192;
|
|
2400
|
+
|
|
2401
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
2402
|
+
nk_size_t const depth_start = depth_base + tile_idx * 64;
|
|
2403
|
+
nk_size_t const valid_depth = (depth_start + 64 <= depth)
|
|
2404
|
+
? 64
|
|
2405
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
2406
|
+
|
|
2407
|
+
nk_dots_u8_load_a_sapphireamx_( //
|
|
2408
|
+
&a_tiles[tile_idx], //
|
|
2409
|
+
vectors + row_tile * stride + depth_start, //
|
|
2410
|
+
stride, valid_rows, valid_depth);
|
|
2411
|
+
|
|
2412
|
+
if (row_tile == col_tile) {
|
|
2413
|
+
nk_dots_pack_u8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
2414
|
+
}
|
|
2415
|
+
else {
|
|
2416
|
+
nk_dots_u8_load_a_sapphireamx_( //
|
|
2417
|
+
&b_src_tiles[tile_idx], //
|
|
2418
|
+
vectors + col_tile * stride + depth_start, //
|
|
2419
|
+
stride, valid_cols, valid_depth);
|
|
2420
|
+
nk_dots_pack_u8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
2421
|
+
}
|
|
2422
|
+
}
|
|
2423
|
+
|
|
2424
|
+
nk_dots_u8_update_sapphireamx_( //
|
|
2425
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
2426
|
+
}
|
|
2427
|
+
|
|
2428
|
+
nk_dots_u8_store_sapphireamx_( //
|
|
2429
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
2430
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
2431
|
+
}
|
|
2432
|
+
}
|
|
2433
|
+
}
|
|
2434
|
+
|
|
2435
|
+
#pragma endregion // Unsigned Integers
|
|
2436
|
+
|
|
2437
|
+
#pragma region Quarter Precision E4M3
|
|
2438
|
+
|
|
2439
|
+
/**
 * @brief Bytes needed for `nk_dots_pack_e4m3_sapphireamx` output.
 *
 * E4M3 values are widened to BF16 while packing, and the tiles keep the BF16 shape (32 depth elements per
 * tile row), so the packed buffer is exactly as large as the BF16 one.
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e4m3_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const bf16_layout_bytes = nk_dots_packed_size_bf16_sapphireamx(column_count, depth);
    return bf16_layout_bytes;
}
|
|
2443
|
+
|
|
2444
|
+
NK_PUBLIC void nk_dots_pack_e4m3_sapphireamx( //
|
|
2445
|
+
nk_e4m3_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
2446
|
+
nk_size_t b_stride, void *b_packed) {
|
|
2447
|
+
|
|
2448
|
+
nk_size_t const tmm_rows = 16;
|
|
2449
|
+
nk_size_t const tmm_cols = 32; // Same depth granularity as BF16
|
|
2450
|
+
nk_size_t const tile_elements = 512;
|
|
2451
|
+
nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);
|
|
2452
|
+
|
|
2453
|
+
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
2454
|
+
nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
|
|
2455
|
+
nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
|
|
2456
|
+
nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
|
|
2457
|
+
|
|
2458
|
+
nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
|
|
2459
|
+
header->full_column_tiles = (nk_u32_t)column_tiles_count;
|
|
2460
|
+
header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
|
|
2461
|
+
header->column_remainder_count = (nk_u32_t)column_remainder_count;
|
|
2462
|
+
|
|
2463
|
+
nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
|
|
2464
|
+
nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
|
|
2465
|
+
header->column_edge_offset = (nk_u32_t)column_edge_offset;
|
|
2466
|
+
|
|
2467
|
+
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
|
|
2468
|
+
nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
|
|
2469
|
+
|
|
2470
|
+
for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
|
|
2471
|
+
|
|
2472
|
+
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
2473
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2474
|
+
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
2475
|
+
nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
2476
|
+
|
|
2477
|
+
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
2478
|
+
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
2479
|
+
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
2480
|
+
: (depth - src_column_start);
|
|
2481
|
+
|
|
2482
|
+
// Convert E4M3 to BF16 and pack with pair-interleaving
|
|
2483
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
2484
|
+
nk_size_t src_row = src_row_start + row_idx;
|
|
2485
|
+
// Load 32 E4M3 bytes and convert to BF16
|
|
2486
|
+
__mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
|
|
2487
|
+
__m256i e4m3_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
|
|
2488
|
+
__m512i bf16_row = nk_e4m3x32_to_bf16x32_icelake_(e4m3_row);
|
|
2489
|
+
// Store with pair-interleaving
|
|
2490
|
+
nk_bf16_t bf16_buf[32];
|
|
2491
|
+
_mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
|
|
2492
|
+
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
2493
|
+
nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
|
|
2494
|
+
tile_output[dst_idx] = bf16_buf[column_idx];
|
|
2495
|
+
}
|
|
2496
|
+
}
|
|
2497
|
+
}
|
|
2498
|
+
}
|
|
2499
|
+
|
|
2500
|
+
// Pack column-remainder rows (convert E4M3 to BF16)
|
|
2501
|
+
if (column_remainder_count > 0) {
|
|
2502
|
+
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
2503
|
+
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
2504
|
+
for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
|
|
2505
|
+
nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
|
|
2506
|
+
__mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
|
|
2507
|
+
__m256i e4m3_chunk = _mm256_maskz_loadu_epi8(
|
|
2508
|
+
column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
|
|
2509
|
+
__m512i bf16_chunk = nk_e4m3x32_to_bf16x32_icelake_(e4m3_chunk);
|
|
2510
|
+
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
|
|
2511
|
+
}
|
|
2512
|
+
}
|
|
2513
|
+
}
|
|
2514
|
+
|
|
2515
|
+
// Compute and store per-column norms for angular/euclidean distance
|
|
2516
|
+
nk_size_t norms_offset = column_edge_offset +
|
|
2517
|
+
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_bf16_t) : 0);
|
|
2518
|
+
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
2519
|
+
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
2520
|
+
for (nk_size_t col = 0; col < column_count; col++)
|
|
2521
|
+
norms[col] = nk_dots_reduce_sumsq_e4m3_(b + col * b_stride, depth);
|
|
2522
|
+
}
|
|
2523
|
+
|
|
2524
|
+
NK_PUBLIC void nk_dots_packed_e4m3_sapphireamx( //
|
|
2525
|
+
nk_e4m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
2526
|
+
nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
|
|
2527
|
+
nk_unused_(cols_count);
|
|
2528
|
+
|
|
2529
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
2530
|
+
nk_size_t const column_tiles_count = header->full_column_tiles;
|
|
2531
|
+
nk_size_t const depth_tiles_count = header->full_depth_tiles;
|
|
2532
|
+
nk_size_t const column_remainder_count = header->column_remainder_count;
|
|
2533
|
+
|
|
2534
|
+
// B tiles are already in BF16 format
|
|
2535
|
+
nk_bf16_t const *b_tiles_base = (nk_bf16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
|
|
2536
|
+
nk_bf16_t const *col_edge_ptr = (nk_bf16_t const *)((char const *)b_packed + header->column_edge_offset);
|
|
2537
|
+
|
|
2538
|
+
nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
|
|
2539
|
+
nk_size_t const tile_depth = 32;
|
|
2540
|
+
nk_size_t const tile_size = 512;
|
|
2541
|
+
nk_size_t const full_cols = column_tiles_count * 16;
|
|
2542
|
+
|
|
2543
|
+
nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
|
|
2544
|
+
nk_size_t const col_blocks_count = column_tiles_count / 2;
|
|
2545
|
+
|
|
2546
|
+
if (depth_tiles_count == 0) return;
|
|
2547
|
+
|
|
2548
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
|
|
2549
|
+
nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
|
|
2550
|
+
|
|
2551
|
+
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
2552
|
+
nk_size_t const depth_remainder = depth % tile_depth;
|
|
2553
|
+
|
|
2554
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
2555
|
+
|
|
2556
|
+
// Loop order: row_blocks outer, col_blocks inner
|
|
2557
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
2558
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
2559
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
2560
|
+
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
2561
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2562
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2563
|
+
|
|
2564
|
+
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
2565
|
+
nk_size_t const col_block_start = column_block_idx * 32;
|
|
2566
|
+
nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
|
|
2567
|
+
nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
|
|
2568
|
+
|
|
2569
|
+
// Zero accumulators (TMM4-7 stay resident across entire depth loop)
|
|
2570
|
+
_tile_zero(4);
|
|
2571
|
+
_tile_zero(5);
|
|
2572
|
+
_tile_zero(6);
|
|
2573
|
+
_tile_zero(7);
|
|
2574
|
+
|
|
2575
|
+
// FP8 always uses buffered load for E4M3 → BF16 conversion
|
|
2576
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2577
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2578
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2579
|
+
|
|
2580
|
+
// Load A with FP8 → BF16 conversion
|
|
2581
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2582
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2583
|
+
if (rows_in_lower_tile > 0) {
|
|
2584
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
|
|
2585
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2586
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2587
|
+
}
|
|
2588
|
+
|
|
2589
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
2590
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2591
|
+
(b_column_left_base + depth_tile_idx) * tile_size);
|
|
2592
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
|
|
2593
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2594
|
+
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
2595
|
+
|
|
2596
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2597
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2598
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
2599
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
2600
|
+
|
|
2601
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
2602
|
+
_tile_dpbf16ps(5, 0, 3);
|
|
2603
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
2604
|
+
_tile_dpbf16ps(7, 1, 3);
|
|
2605
|
+
}
|
|
2606
|
+
|
|
2607
|
+
// Store accumulators to output (once per output block)
|
|
2608
|
+
if (is_full_row_block) {
|
|
2609
|
+
nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
|
|
2610
|
+
_tile_stored(4, c_block, c_stride_bytes);
|
|
2611
|
+
_tile_stored(5, c_block + 16, c_stride_bytes);
|
|
2612
|
+
_tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
|
|
2613
|
+
_tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
|
|
2614
|
+
}
|
|
2615
|
+
else {
|
|
2616
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
2617
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
2618
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
2619
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
2620
|
+
nk_dots_bf16_output2x2_sapphireamx_(&c_accum_buffer,
|
|
2621
|
+
c + row_block_start * c_stride_elements + col_block_start,
|
|
2622
|
+
c_stride_elements, valid_rows_count, 32);
|
|
2623
|
+
}
|
|
2624
|
+
}
|
|
2625
|
+
|
|
2626
|
+
// Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
|
|
2627
|
+
if (column_tiles_count % 2 == 1) {
|
|
2628
|
+
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
2629
|
+
nk_size_t const col_start = column_tile_idx * 16;
|
|
2630
|
+
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
2631
|
+
|
|
2632
|
+
nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
2633
|
+
_tile_zero(4);
|
|
2634
|
+
_tile_zero(6);
|
|
2635
|
+
|
|
2636
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2637
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2638
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2639
|
+
|
|
2640
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2641
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2642
|
+
if (rows_in_lower_tile > 0) {
|
|
2643
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
|
|
2644
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2645
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2646
|
+
}
|
|
2647
|
+
|
|
2648
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
|
|
2649
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2650
|
+
(b_column_base + depth_tile_idx) * tile_size);
|
|
2651
|
+
|
|
2652
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2653
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2654
|
+
_tile_loadd(2, b_tile->data, 64);
|
|
2655
|
+
|
|
2656
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
2657
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
2658
|
+
}
|
|
2659
|
+
|
|
2660
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
2661
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
2662
|
+
|
|
2663
|
+
nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
|
|
2664
|
+
c_stride_elements, rows_in_upper_tile, 16);
|
|
2665
|
+
if (rows_in_lower_tile > 0) {
|
|
2666
|
+
nk_dots_bf16_store_sapphireamx_(&c_lower_state,
|
|
2667
|
+
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
2668
|
+
c_stride_elements, rows_in_lower_tile, 16);
|
|
2669
|
+
}
|
|
2670
|
+
}
|
|
2671
|
+
|
|
2672
|
+
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
2673
|
+
if (column_remainder_count > 0) {
|
|
2674
|
+
nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
2675
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
2676
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
2677
|
+
|
|
2678
|
+
_tile_zero(4);
|
|
2679
|
+
_tile_zero(6);
|
|
2680
|
+
|
|
2681
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2682
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2683
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2684
|
+
|
|
2685
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2686
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2687
|
+
if (rows_in_lower_tile > 0) {
|
|
2688
|
+
nk_dots_e4m3_load_a_sapphireamx_(&a_tile_lower,
|
|
2689
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2690
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2691
|
+
}
|
|
2692
|
+
|
|
2693
|
+
// B edge data is already in BF16 format
|
|
2694
|
+
nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
2695
|
+
valid_depth);
|
|
2696
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
2697
|
+
|
|
2698
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2699
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2700
|
+
_tile_loadd(2, b_tile.data, 64);
|
|
2701
|
+
|
|
2702
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
2703
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
2704
|
+
}
|
|
2705
|
+
|
|
2706
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
2707
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
2708
|
+
|
|
2709
|
+
nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
|
|
2710
|
+
c_stride_elements, rows_in_upper_tile, column_remainder_count);
|
|
2711
|
+
if (rows_in_lower_tile > 0) {
|
|
2712
|
+
nk_dots_bf16_store_sapphireamx_(&c_lower_state,
|
|
2713
|
+
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
2714
|
+
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
2715
|
+
}
|
|
2716
|
+
}
|
|
2717
|
+
}
|
|
2718
|
+
|
|
2719
|
+
_tile_release();
|
|
2720
|
+
}
|
|
2721
|
+
|
|
2722
|
+
#pragma endregion // Quarter Precision E4M3
|
|
2723
|
+
|
|
2724
|
+
#pragma region Quarter Precision E5M2
|
|
2725
|
+
|
|
2726
|
+
/* Bytes required for a packed E5M2 B matrix. E5M2 values are widened to BF16 while packing
 * (see nk_dots_pack_e5m2_sapphireamx), so the packed footprint is exactly the BF16 one. */
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const packed_bytes = nk_dots_packed_size_bf16_sapphireamx(column_count, depth);
    return packed_bytes;
}
|
|
2729
|
+
|
|
2730
|
+
NK_PUBLIC void nk_dots_pack_e5m2_sapphireamx( //
|
|
2731
|
+
nk_e5m2_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
2732
|
+
nk_size_t b_stride, void *b_packed) {
|
|
2733
|
+
|
|
2734
|
+
nk_size_t const tmm_rows = 16;
|
|
2735
|
+
nk_size_t const tmm_cols = 32;
|
|
2736
|
+
nk_size_t const tile_elements = 512;
|
|
2737
|
+
nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);
|
|
2738
|
+
|
|
2739
|
+
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
2740
|
+
nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
|
|
2741
|
+
nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
|
|
2742
|
+
nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
|
|
2743
|
+
|
|
2744
|
+
nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
|
|
2745
|
+
header->full_column_tiles = (nk_u32_t)column_tiles_count;
|
|
2746
|
+
header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
|
|
2747
|
+
header->column_remainder_count = (nk_u32_t)column_remainder_count;
|
|
2748
|
+
|
|
2749
|
+
nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
|
|
2750
|
+
nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
|
|
2751
|
+
header->column_edge_offset = (nk_u32_t)column_edge_offset;
|
|
2752
|
+
|
|
2753
|
+
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
|
|
2754
|
+
nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
|
|
2755
|
+
|
|
2756
|
+
for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
|
|
2757
|
+
|
|
2758
|
+
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
2759
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2760
|
+
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
2761
|
+
nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
2762
|
+
|
|
2763
|
+
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
2764
|
+
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
2765
|
+
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
2766
|
+
: (depth - src_column_start);
|
|
2767
|
+
|
|
2768
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
2769
|
+
nk_size_t src_row = src_row_start + row_idx;
|
|
2770
|
+
__mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
|
|
2771
|
+
__m256i e5m2_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
|
|
2772
|
+
__m512i bf16_row = nk_e5m2x32_to_bf16x32_icelake_(e5m2_row);
|
|
2773
|
+
nk_bf16_t bf16_buf[32];
|
|
2774
|
+
_mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
|
|
2775
|
+
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
2776
|
+
nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
|
|
2777
|
+
tile_output[dst_idx] = bf16_buf[column_idx];
|
|
2778
|
+
}
|
|
2779
|
+
}
|
|
2780
|
+
}
|
|
2781
|
+
}
|
|
2782
|
+
|
|
2783
|
+
if (column_remainder_count > 0) {
|
|
2784
|
+
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
2785
|
+
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
2786
|
+
for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
|
|
2787
|
+
nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
|
|
2788
|
+
__mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
|
|
2789
|
+
__m256i e5m2_chunk = _mm256_maskz_loadu_epi8(
|
|
2790
|
+
column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
|
|
2791
|
+
__m512i bf16_chunk = nk_e5m2x32_to_bf16x32_icelake_(e5m2_chunk);
|
|
2792
|
+
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
|
|
2793
|
+
}
|
|
2794
|
+
}
|
|
2795
|
+
}
|
|
2796
|
+
|
|
2797
|
+
// Compute and store per-column norms for angular/euclidean distance
|
|
2798
|
+
nk_size_t norms_offset = column_edge_offset +
|
|
2799
|
+
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_bf16_t) : 0);
|
|
2800
|
+
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
2801
|
+
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
2802
|
+
for (nk_size_t col = 0; col < column_count; col++)
|
|
2803
|
+
norms[col] = nk_dots_reduce_sumsq_e5m2_(b + col * b_stride, depth);
|
|
2804
|
+
}
|
|
2805
|
+
|
|
2806
|
+
NK_PUBLIC void nk_dots_packed_e5m2_sapphireamx( //
|
|
2807
|
+
nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
2808
|
+
nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
|
|
2809
|
+
nk_unused_(cols_count);
|
|
2810
|
+
|
|
2811
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
2812
|
+
nk_size_t const column_tiles_count = header->full_column_tiles;
|
|
2813
|
+
nk_size_t const depth_tiles_count = header->full_depth_tiles;
|
|
2814
|
+
nk_size_t const column_remainder_count = header->column_remainder_count;
|
|
2815
|
+
|
|
2816
|
+
nk_bf16_t const *b_tiles_base = (nk_bf16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
|
|
2817
|
+
nk_bf16_t const *col_edge_ptr = (nk_bf16_t const *)((char const *)b_packed + header->column_edge_offset);
|
|
2818
|
+
|
|
2819
|
+
nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
|
|
2820
|
+
nk_size_t const tile_depth = 32;
|
|
2821
|
+
nk_size_t const tile_size = 512;
|
|
2822
|
+
nk_size_t const full_cols = column_tiles_count * 16;
|
|
2823
|
+
|
|
2824
|
+
nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
|
|
2825
|
+
nk_size_t const col_blocks_count = column_tiles_count / 2;
|
|
2826
|
+
|
|
2827
|
+
if (depth_tiles_count == 0) return;
|
|
2828
|
+
|
|
2829
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
|
|
2830
|
+
nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
|
|
2831
|
+
|
|
2832
|
+
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
2833
|
+
nk_size_t const depth_remainder = depth % tile_depth;
|
|
2834
|
+
|
|
2835
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
2836
|
+
|
|
2837
|
+
// Loop order: row_blocks outer, col_blocks inner
|
|
2838
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
2839
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
2840
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
2841
|
+
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
2842
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
2843
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
2844
|
+
|
|
2845
|
+
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
2846
|
+
nk_size_t const col_block_start = column_block_idx * 32;
|
|
2847
|
+
nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
|
|
2848
|
+
nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
|
|
2849
|
+
|
|
2850
|
+
// Zero accumulators (TMM4-7 stay resident across entire depth loop)
|
|
2851
|
+
_tile_zero(4);
|
|
2852
|
+
_tile_zero(5);
|
|
2853
|
+
_tile_zero(6);
|
|
2854
|
+
_tile_zero(7);
|
|
2855
|
+
|
|
2856
|
+
// FP8 always uses buffered load for E5M2 → BF16 conversion
|
|
2857
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2858
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2859
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2860
|
+
|
|
2861
|
+
// Load A with FP8 → BF16 conversion
|
|
2862
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2863
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2864
|
+
if (rows_in_lower_tile > 0) {
|
|
2865
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
|
|
2866
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2867
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2868
|
+
}
|
|
2869
|
+
|
|
2870
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
2871
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2872
|
+
(b_column_left_base + depth_tile_idx) * tile_size);
|
|
2873
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
|
|
2874
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2875
|
+
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
2876
|
+
|
|
2877
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2878
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2879
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
2880
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
2881
|
+
|
|
2882
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
2883
|
+
_tile_dpbf16ps(5, 0, 3);
|
|
2884
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
2885
|
+
_tile_dpbf16ps(7, 1, 3);
|
|
2886
|
+
}
|
|
2887
|
+
|
|
2888
|
+
// Store accumulators to output (once per output block)
|
|
2889
|
+
if (is_full_row_block) {
|
|
2890
|
+
nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
|
|
2891
|
+
_tile_stored(4, c_block, c_stride_bytes);
|
|
2892
|
+
_tile_stored(5, c_block + 16, c_stride_bytes);
|
|
2893
|
+
_tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
|
|
2894
|
+
_tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
|
|
2895
|
+
}
|
|
2896
|
+
else {
|
|
2897
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
2898
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
2899
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
2900
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
2901
|
+
nk_dots_bf16_output2x2_sapphireamx_(&c_accum_buffer,
|
|
2902
|
+
c + row_block_start * c_stride_elements + col_block_start,
|
|
2903
|
+
c_stride_elements, valid_rows_count, 32);
|
|
2904
|
+
}
|
|
2905
|
+
}
|
|
2906
|
+
|
|
2907
|
+
// Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
|
|
2908
|
+
if (column_tiles_count % 2 == 1) {
|
|
2909
|
+
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
2910
|
+
nk_size_t const col_start = column_tile_idx * 16;
|
|
2911
|
+
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
2912
|
+
|
|
2913
|
+
nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
2914
|
+
_tile_zero(4);
|
|
2915
|
+
_tile_zero(6);
|
|
2916
|
+
|
|
2917
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2918
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2919
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2920
|
+
|
|
2921
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2922
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2923
|
+
if (rows_in_lower_tile > 0) {
|
|
2924
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
|
|
2925
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2926
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2927
|
+
}
|
|
2928
|
+
|
|
2929
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
|
|
2930
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
2931
|
+
(b_column_base + depth_tile_idx) * tile_size);
|
|
2932
|
+
|
|
2933
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2934
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2935
|
+
_tile_loadd(2, b_tile->data, 64);
|
|
2936
|
+
|
|
2937
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
2938
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
2939
|
+
}
|
|
2940
|
+
|
|
2941
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
2942
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
2943
|
+
|
|
2944
|
+
nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
|
|
2945
|
+
c_stride_elements, rows_in_upper_tile, 16);
|
|
2946
|
+
if (rows_in_lower_tile > 0) {
|
|
2947
|
+
nk_dots_bf16_store_sapphireamx_(&c_lower_state,
|
|
2948
|
+
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
2949
|
+
c_stride_elements, rows_in_lower_tile, 16);
|
|
2950
|
+
}
|
|
2951
|
+
}
|
|
2952
|
+
|
|
2953
|
+
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
2954
|
+
if (column_remainder_count > 0) {
|
|
2955
|
+
nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
2956
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
2957
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
2958
|
+
|
|
2959
|
+
_tile_zero(4);
|
|
2960
|
+
_tile_zero(6);
|
|
2961
|
+
|
|
2962
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
2963
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
2964
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
2965
|
+
|
|
2966
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
2967
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
2968
|
+
if (rows_in_lower_tile > 0) {
|
|
2969
|
+
nk_dots_e5m2_load_a_sapphireamx_(&a_tile_lower,
|
|
2970
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
2971
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
2972
|
+
}
|
|
2973
|
+
|
|
2974
|
+
nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
2975
|
+
valid_depth);
|
|
2976
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
2977
|
+
|
|
2978
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
2979
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
2980
|
+
_tile_loadd(2, b_tile.data, 64);
|
|
2981
|
+
|
|
2982
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
2983
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
2984
|
+
}
|
|
2985
|
+
|
|
2986
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
2987
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
2988
|
+
|
|
2989
|
+
nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
|
|
2990
|
+
c_stride_elements, rows_in_upper_tile, column_remainder_count);
|
|
2991
|
+
if (rows_in_lower_tile > 0) {
|
|
2992
|
+
nk_dots_bf16_store_sapphireamx_(&c_lower_state,
|
|
2993
|
+
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
2994
|
+
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
2995
|
+
}
|
|
2996
|
+
}
|
|
2997
|
+
}
|
|
2998
|
+
|
|
2999
|
+
_tile_release();
|
|
3000
|
+
}
|
|
3001
|
+
|
|
3002
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_sapphireamx( //
|
|
3003
|
+
nk_e5m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
|
|
3004
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
|
|
3005
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
3006
|
+
|
|
3007
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
3008
|
+
|
|
3009
|
+
// Handle row slicing: compute rows [row_start, row_end)
|
|
3010
|
+
nk_size_t const row_end = (row_count == 0)
|
|
3011
|
+
? n_vectors
|
|
3012
|
+
: (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
|
|
3013
|
+
|
|
3014
|
+
// Round depth up to multiple of 96 (3 tiles × 32 elements)
|
|
3015
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
3016
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
3017
|
+
|
|
3018
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tiles[3];
|
|
3019
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_src_tiles[3];
|
|
3020
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tiles[3];
|
|
3021
|
+
nk_dots_bf16_state_sapphireamx_t state;
|
|
3022
|
+
|
|
3023
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
3024
|
+
|
|
3025
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
3026
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
3027
|
+
|
|
3028
|
+
for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
|
|
3029
|
+
nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
|
|
3030
|
+
|
|
3031
|
+
nk_dots_bf16_init_sapphireamx_(&state);
|
|
3032
|
+
|
|
3033
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
3034
|
+
nk_size_t const depth_base = depth_group_idx * 96;
|
|
3035
|
+
|
|
3036
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
3037
|
+
nk_size_t const depth_start = depth_base + tile_idx * 32;
|
|
3038
|
+
nk_size_t const valid_depth = (depth_start + 32 <= depth)
|
|
3039
|
+
? 32
|
|
3040
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
3041
|
+
|
|
3042
|
+
nk_dots_e5m2_load_a_sapphireamx_( //
|
|
3043
|
+
&a_tiles[tile_idx], //
|
|
3044
|
+
vectors + row_tile * stride + depth_start, //
|
|
3045
|
+
stride, valid_rows, valid_depth);
|
|
3046
|
+
|
|
3047
|
+
if (row_tile == col_tile) {
|
|
3048
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3049
|
+
}
|
|
3050
|
+
else {
|
|
3051
|
+
nk_dots_e5m2_load_a_sapphireamx_( //
|
|
3052
|
+
&b_src_tiles[tile_idx], //
|
|
3053
|
+
vectors + col_tile * stride + depth_start, //
|
|
3054
|
+
stride, valid_cols, valid_depth);
|
|
3055
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3056
|
+
}
|
|
3057
|
+
}
|
|
3058
|
+
|
|
3059
|
+
nk_dots_bf16_update_sapphireamx_( //
|
|
3060
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
3061
|
+
}
|
|
3062
|
+
|
|
3063
|
+
nk_dots_bf16_store_sapphireamx_( //
|
|
3064
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
3065
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
3066
|
+
}
|
|
3067
|
+
}
|
|
3068
|
+
}
|
|
3069
|
+
|
|
3070
|
+
NK_PUBLIC void nk_dots_symmetric_e4m3_sapphireamx( //
|
|
3071
|
+
nk_e4m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
|
|
3072
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
|
|
3073
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
3074
|
+
|
|
3075
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
3076
|
+
|
|
3077
|
+
// Handle row slicing: compute rows [row_start, row_end)
|
|
3078
|
+
nk_size_t const row_end = (row_count == 0)
|
|
3079
|
+
? n_vectors
|
|
3080
|
+
: (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
|
|
3081
|
+
|
|
3082
|
+
// Round depth up to multiple of 96 (3 tiles × 32 elements)
|
|
3083
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
3084
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
3085
|
+
|
|
3086
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tiles[3];
|
|
3087
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_src_tiles[3];
|
|
3088
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tiles[3];
|
|
3089
|
+
nk_dots_bf16_state_sapphireamx_t state;
|
|
3090
|
+
|
|
3091
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
3092
|
+
|
|
3093
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
3094
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
3095
|
+
|
|
3096
|
+
for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
|
|
3097
|
+
nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
|
|
3098
|
+
|
|
3099
|
+
nk_dots_bf16_init_sapphireamx_(&state);
|
|
3100
|
+
|
|
3101
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
3102
|
+
nk_size_t const depth_base = depth_group_idx * 96;
|
|
3103
|
+
|
|
3104
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
3105
|
+
nk_size_t const depth_start = depth_base + tile_idx * 32;
|
|
3106
|
+
nk_size_t const valid_depth = (depth_start + 32 <= depth)
|
|
3107
|
+
? 32
|
|
3108
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
3109
|
+
|
|
3110
|
+
nk_dots_e4m3_load_a_sapphireamx_( //
|
|
3111
|
+
&a_tiles[tile_idx], //
|
|
3112
|
+
vectors + row_tile * stride + depth_start, //
|
|
3113
|
+
stride, valid_rows, valid_depth);
|
|
3114
|
+
|
|
3115
|
+
if (row_tile == col_tile) {
|
|
3116
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3117
|
+
}
|
|
3118
|
+
else {
|
|
3119
|
+
nk_dots_e4m3_load_a_sapphireamx_( //
|
|
3120
|
+
&b_src_tiles[tile_idx], //
|
|
3121
|
+
vectors + col_tile * stride + depth_start, //
|
|
3122
|
+
stride, valid_cols, valid_depth);
|
|
3123
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3124
|
+
}
|
|
3125
|
+
}
|
|
3126
|
+
|
|
3127
|
+
nk_dots_bf16_update_sapphireamx_( //
|
|
3128
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
3129
|
+
}
|
|
3130
|
+
|
|
3131
|
+
nk_dots_bf16_store_sapphireamx_( //
|
|
3132
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
3133
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
3134
|
+
}
|
|
3135
|
+
}
|
|
3136
|
+
}
|
|
3137
|
+
|
|
3138
|
+
#pragma endregion // Quarter Precision E5M2
|
|
3139
|
+
|
|
3140
|
+
#pragma region Micro Precision E2M3
|
|
3141
|
+
|
|
3142
|
+
/* Load E2M3 A tile with E2M3 to signed I8 conversion via VPERMB LUT.
|
|
3143
|
+
* Each E2M3 byte encodes: bit 5 = sign, bits 4:0 = magnitude (5-bit index).
|
|
3144
|
+
* The LUT maps 5-bit magnitude to value * 16, then sign is applied via conditional negation.
|
|
3145
|
+
* Result is stored in INT8 tile for use with _tile_dpbssd.
|
|
3146
|
+
*/
|
|
3147
|
+
NK_INTERNAL void nk_dots_e2m3_load_a_sapphireamx_( //
|
|
3148
|
+
nk_dots_i8_a16x64_sapphireamx_t *a_tile, //
|
|
3149
|
+
nk_e2m3_t const *src, nk_size_t src_stride, //
|
|
3150
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
3151
|
+
|
|
3152
|
+
// Build 64-byte LUT for VPERMB: 32 entries replicated to fill both halves.
|
|
3153
|
+
// magnitude → value×16:
|
|
3154
|
+
// e=0 (step 2): {0,2,4,6,8,10,12,14},
|
|
3155
|
+
// e=1 (step 2): {16,18,20,22,24,26,28,30},
|
|
3156
|
+
// e=2 (step 4): {32,36,40,44,48,52,56,60},
|
|
3157
|
+
// e=3 (step 8): {64,72,80,88,96,104,112,120}
|
|
3158
|
+
NK_ALIGN64 static nk_u8_t const lut_bytes[64] = {
|
|
3159
|
+
0, 2, 4, 6, 8, 10, 12, 14, //
|
|
3160
|
+
16, 18, 20, 22, 24, 26, 28, 30, //
|
|
3161
|
+
32, 36, 40, 44, 48, 52, 56, 60, //
|
|
3162
|
+
64, 72, 80, 88, 96, 104, 112, 120, //
|
|
3163
|
+
0, 2, 4, 6, 8, 10, 12, 14, //
|
|
3164
|
+
16, 18, 20, 22, 24, 26, 28, 30, //
|
|
3165
|
+
32, 36, 40, 44, 48, 52, 56, 60, //
|
|
3166
|
+
64, 72, 80, 88, 96, 104, 112, 120, //
|
|
3167
|
+
};
|
|
3168
|
+
__m512i magnitude_lut_u8x64 = _mm512_load_si512((__m512i const *)lut_bytes);
|
|
3169
|
+
__m512i sign_mask_u8x64 = _mm512_set1_epi8(0x20);
|
|
3170
|
+
__m512i magnitude_mask_u8x64 = _mm512_set1_epi8(0x1F);
|
|
3171
|
+
__m512i zero_i8x64 = _mm512_setzero_si512();
|
|
3172
|
+
|
|
3173
|
+
__mmask64 column_mask = (valid_cols >= 64) ? 0xFFFFFFFFFFFFFFFFULL : ((__mmask64)1 << valid_cols) - 1;
|
|
3174
|
+
|
|
3175
|
+
for (nk_size_t row = 0; row < 16; row++) {
|
|
3176
|
+
if (row < valid_rows) {
|
|
3177
|
+
__m512i raw_u8x64 = _mm512_maskz_loadu_epi8(column_mask, src + row * src_stride);
|
|
3178
|
+
__m512i magnitude_u8x64 = _mm512_and_si512(raw_u8x64, magnitude_mask_u8x64);
|
|
3179
|
+
__m512i unsigned_value_u8x64 = _mm512_permutexvar_epi8(magnitude_u8x64, magnitude_lut_u8x64);
|
|
3180
|
+
__mmask64 negate_mask = _mm512_test_epi8_mask(raw_u8x64, sign_mask_u8x64);
|
|
3181
|
+
__m512i signed_value_i8x64 = _mm512_mask_sub_epi8(unsigned_value_u8x64, negate_mask, zero_i8x64,
|
|
3182
|
+
unsigned_value_u8x64);
|
|
3183
|
+
_mm512_store_si512(a_tile->data[row], signed_value_i8x64);
|
|
3184
|
+
}
|
|
3185
|
+
else { _mm512_store_si512(a_tile->data[row], zero_i8x64); }
|
|
3186
|
+
}
|
|
3187
|
+
nk_compiler_barrier_sapphireamx_();
|
|
3188
|
+
}
|
|
3189
|
+
|
|
3190
|
+
/* Store E2M3 accumulator: read I32 state, convert to F32, multiply by 1/256, store as F32. */
|
|
3191
|
+
NK_INTERNAL void nk_dots_e2m3_store_sapphireamx_( //
|
|
3192
|
+
nk_dots_i8_state_sapphireamx_t const *state, //
|
|
3193
|
+
nk_f32_t *dst, nk_size_t dst_stride_elements, //
|
|
3194
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
3195
|
+
|
|
3196
|
+
__mmask16 column_mask = (valid_cols >= 16) ? 0xFFFF : ((__mmask16)1 << valid_cols) - 1;
|
|
3197
|
+
__m512 scale = _mm512_set1_ps(1.0f / 256.0f);
|
|
3198
|
+
|
|
3199
|
+
for (nk_size_t row = 0; row < valid_rows; row++) {
|
|
3200
|
+
__m512i i32_row = _mm512_load_si512(state->data[row]);
|
|
3201
|
+
__m512 f32_row = _mm512_mul_ps(_mm512_cvtepi32_ps(i32_row), scale);
|
|
3202
|
+
_mm512_mask_storeu_ps(dst + row * dst_stride_elements, column_mask, f32_row);
|
|
3203
|
+
}
|
|
3204
|
+
}
|
|
3205
|
+
|
|
3206
|
+
/* Store E2M3 2x2 accumulator state to F32 output matrix with masking for edge tiles. */
|
|
3207
|
+
NK_INTERNAL void nk_dots_e2m3_output2x2_sapphireamx_( //
|
|
3208
|
+
nk_dots_i8_state2x2_sapphireamx_t const *state, //
|
|
3209
|
+
nk_f32_t *dst, nk_size_t dst_stride_elements, //
|
|
3210
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
3211
|
+
|
|
3212
|
+
nk_size_t const rows_upper = (valid_rows > 16) ? 16 : valid_rows;
|
|
3213
|
+
nk_size_t const cols_left = (valid_cols > 16) ? 16 : valid_cols;
|
|
3214
|
+
nk_size_t const cols_right = (valid_cols > 16) ? valid_cols - 16 : 0;
|
|
3215
|
+
|
|
3216
|
+
if (rows_upper > 0 && cols_left > 0)
|
|
3217
|
+
nk_dots_e2m3_store_sapphireamx_(&state->c[0][0], dst, dst_stride_elements, rows_upper, cols_left);
|
|
3218
|
+
if (rows_upper > 0 && cols_right > 0)
|
|
3219
|
+
nk_dots_e2m3_store_sapphireamx_(&state->c[0][1], dst + 16, dst_stride_elements, rows_upper, cols_right);
|
|
3220
|
+
|
|
3221
|
+
if (valid_rows > 16) {
|
|
3222
|
+
nk_size_t const rows_lower = valid_rows - 16;
|
|
3223
|
+
nk_f32_t *dst_lower = dst + 16 * dst_stride_elements;
|
|
3224
|
+
if (cols_left > 0)
|
|
3225
|
+
nk_dots_e2m3_store_sapphireamx_(&state->c[1][0], dst_lower, dst_stride_elements, rows_lower, cols_left);
|
|
3226
|
+
if (cols_right > 0)
|
|
3227
|
+
nk_dots_e2m3_store_sapphireamx_(&state->c[1][1], dst_lower + 16, dst_stride_elements, rows_lower,
|
|
3228
|
+
cols_right);
|
|
3229
|
+
}
|
|
3230
|
+
}
|
|
3231
|
+
|
|
3232
|
+
/* Bytes required for a packed E2M3 B matrix. Packing converts E2M3 to signed INT8 (value × 16)
 * and uses the INT8 tile layout (16 rows × 64 values per tile), so the INT8 sizing applies as-is. */
NK_PUBLIC nk_size_t nk_dots_packed_size_e2m3_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const packed_bytes = nk_dots_packed_size_i8_sapphireamx(column_count, depth);
    return packed_bytes;
}
|
|
3236
|
+
|
|
3237
|
+
NK_PUBLIC void nk_dots_pack_e2m3_sapphireamx( //
|
|
3238
|
+
nk_e2m3_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
3239
|
+
nk_size_t b_stride, void *b_packed) {
|
|
3240
|
+
|
|
3241
|
+
// AMX I8 tile dimensions: 16 rows x 64 columns (1024 I8 elements = 1KB)
|
|
3242
|
+
nk_size_t const tmm_rows = 16;
|
|
3243
|
+
nk_size_t const tmm_cols = 64;
|
|
3244
|
+
nk_size_t const tile_elements = 1024;
|
|
3245
|
+
nk_size_t const tile_bytes = tile_elements * sizeof(nk_i8_t);
|
|
3246
|
+
|
|
3247
|
+
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
3248
|
+
nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
|
|
3249
|
+
nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
|
|
3250
|
+
nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
|
|
3251
|
+
|
|
3252
|
+
nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
|
|
3253
|
+
header->full_column_tiles = (nk_u32_t)column_tiles_count;
|
|
3254
|
+
header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
|
|
3255
|
+
header->column_remainder_count = (nk_u32_t)column_remainder_count;
|
|
3256
|
+
|
|
3257
|
+
nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
|
|
3258
|
+
nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
|
|
3259
|
+
header->column_edge_offset = (nk_u32_t)column_edge_offset;
|
|
3260
|
+
|
|
3261
|
+
nk_i8_t *tiles_ptr = (nk_i8_t *)((char *)b_packed + tiles_offset);
|
|
3262
|
+
nk_i8_t *column_edge_ptr = (nk_i8_t *)((char *)b_packed + column_edge_offset);
|
|
3263
|
+
|
|
3264
|
+
// Zero-initialize all tiles (handles depth remainder padding)
|
|
3265
|
+
for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
|
|
3266
|
+
|
|
3267
|
+
// E2M3 magnitude-to-value LUT (value * 16)
|
|
3268
|
+
static nk_u8_t const lut_magnitude[32] = {
|
|
3269
|
+
0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, //
|
|
3270
|
+
32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120, //
|
|
3271
|
+
};
|
|
3272
|
+
|
|
3273
|
+
// Pack tiles with E2M3 -> I8 conversion and quad-interleaving
|
|
3274
|
+
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
3275
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3276
|
+
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
3277
|
+
nk_i8_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
3278
|
+
|
|
3279
|
+
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
3280
|
+
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
3281
|
+
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
3282
|
+
: (depth - src_column_start);
|
|
3283
|
+
|
|
3284
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
3285
|
+
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
3286
|
+
nk_size_t const src_idx = (src_row_start + row_idx) * b_stride + src_column_start + column_idx;
|
|
3287
|
+
nk_size_t const dst_idx = (column_idx / 4) * 64 + row_idx * 4 + (column_idx % 4);
|
|
3288
|
+
nk_u8_t raw = b[src_idx];
|
|
3289
|
+
nk_u8_t magnitude = raw & 0x1F;
|
|
3290
|
+
nk_i8_t val = (nk_i8_t)lut_magnitude[magnitude];
|
|
3291
|
+
if (raw & 0x20) val = -val;
|
|
3292
|
+
tile_output[dst_idx] = val;
|
|
3293
|
+
}
|
|
3294
|
+
}
|
|
3295
|
+
}
|
|
3296
|
+
}
|
|
3297
|
+
|
|
3298
|
+
// Pack column-remainder rows (convert E2M3 to I8)
|
|
3299
|
+
if (column_remainder_count > 0) {
|
|
3300
|
+
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
3301
|
+
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
3302
|
+
for (nk_size_t column_idx = 0; column_idx < depth; column_idx++) {
|
|
3303
|
+
nk_u8_t raw = b[(remainder_start_row + row_idx) * b_stride + column_idx];
|
|
3304
|
+
nk_u8_t magnitude = raw & 0x1F;
|
|
3305
|
+
nk_i8_t val = (nk_i8_t)lut_magnitude[magnitude];
|
|
3306
|
+
if (raw & 0x20) val = -val;
|
|
3307
|
+
column_edge_ptr[row_idx * depth + column_idx] = val;
|
|
3308
|
+
}
|
|
3309
|
+
}
|
|
3310
|
+
}
|
|
3311
|
+
|
|
3312
|
+
// Compute and store per-column norms for angular/euclidean distance
|
|
3313
|
+
nk_size_t norms_offset = column_edge_offset +
|
|
3314
|
+
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_i8_t) : 0);
|
|
3315
|
+
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
3316
|
+
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
3317
|
+
for (nk_size_t col = 0; col < column_count; col++)
|
|
3318
|
+
norms[col] = nk_dots_reduce_sumsq_e2m3_(b + col * b_stride, depth);
|
|
3319
|
+
}
|
|
3320
|
+
|
|
3321
|
+
NK_PUBLIC void nk_dots_packed_e2m3_sapphireamx( //
|
|
3322
|
+
nk_e2m3_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
3323
|
+
nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
|
|
3324
|
+
nk_unused_(cols_count);
|
|
3325
|
+
|
|
3326
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
3327
|
+
nk_size_t const column_tiles_count = header->full_column_tiles;
|
|
3328
|
+
nk_size_t const depth_tiles_count = header->full_depth_tiles;
|
|
3329
|
+
nk_size_t const column_remainder_count = header->column_remainder_count;
|
|
3330
|
+
|
|
3331
|
+
// B tiles are already in I8 format
|
|
3332
|
+
nk_i8_t const *b_tiles_base = (nk_i8_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
|
|
3333
|
+
nk_i8_t const *col_edge_ptr = (nk_i8_t const *)((char const *)b_packed + header->column_edge_offset);
|
|
3334
|
+
|
|
3335
|
+
nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
|
|
3336
|
+
nk_size_t const tile_depth = 64;
|
|
3337
|
+
nk_size_t const tile_size = 1024;
|
|
3338
|
+
nk_size_t const full_cols = column_tiles_count * 16;
|
|
3339
|
+
|
|
3340
|
+
nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
|
|
3341
|
+
nk_size_t const col_blocks_count = column_tiles_count / 2;
|
|
3342
|
+
|
|
3343
|
+
if (depth_tiles_count == 0) return;
|
|
3344
|
+
|
|
3345
|
+
nk_dots_i8_a16x64_sapphireamx_t a_tile_upper, a_tile_lower;
|
|
3346
|
+
nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
|
|
3347
|
+
|
|
3348
|
+
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
3349
|
+
nk_size_t const depth_remainder = depth % tile_depth;
|
|
3350
|
+
|
|
3351
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
3352
|
+
|
|
3353
|
+
// Loop order: row_blocks outer, col_blocks inner
|
|
3354
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
3355
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
3356
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
3357
|
+
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
3358
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
3359
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
3360
|
+
|
|
3361
|
+
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
3362
|
+
nk_size_t const col_block_start = column_block_idx * 32;
|
|
3363
|
+
nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
|
|
3364
|
+
nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
|
|
3365
|
+
|
|
3366
|
+
// Zero accumulators (TMM4-7 stay resident across entire depth loop)
|
|
3367
|
+
_tile_zero(4);
|
|
3368
|
+
_tile_zero(5);
|
|
3369
|
+
_tile_zero(6);
|
|
3370
|
+
_tile_zero(7);
|
|
3371
|
+
|
|
3372
|
+
// E2M3 always uses buffered load for E2M3 -> I8 conversion
|
|
3373
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3374
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3375
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3376
|
+
|
|
3377
|
+
// Load A with E2M3 -> I8 conversion
|
|
3378
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3379
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
3380
|
+
if (rows_in_lower_tile > 0) {
|
|
3381
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
|
|
3382
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3383
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
3384
|
+
}
|
|
3385
|
+
|
|
3386
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_left =
|
|
3387
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
3388
|
+
(b_column_left_base + depth_tile_idx) * tile_size);
|
|
3389
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile_right =
|
|
3390
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
3391
|
+
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
3392
|
+
|
|
3393
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
3394
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
3395
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
3396
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
3397
|
+
|
|
3398
|
+
_tile_dpbssd(4, 0, 2);
|
|
3399
|
+
_tile_dpbssd(5, 0, 3);
|
|
3400
|
+
_tile_dpbssd(6, 1, 2);
|
|
3401
|
+
_tile_dpbssd(7, 1, 3);
|
|
3402
|
+
}
|
|
3403
|
+
|
|
3404
|
+
// Store accumulators to output (once per output block)
|
|
3405
|
+
// Can't directly store I32 tiles to F32 output, must buffer + convert
|
|
3406
|
+
if (is_full_row_block) {
|
|
3407
|
+
nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
|
|
3408
|
+
nk_dots_i8_state2x2_sapphireamx_t c_accum_buffer;
|
|
3409
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
3410
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
3411
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
3412
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
3413
|
+
nk_dots_e2m3_output2x2_sapphireamx_(&c_accum_buffer, c_block, c_stride_elements, valid_rows_count, 32);
|
|
3414
|
+
}
|
|
3415
|
+
else {
|
|
3416
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
3417
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
3418
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
3419
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
3420
|
+
nk_dots_e2m3_output2x2_sapphireamx_(&c_accum_buffer,
|
|
3421
|
+
c + row_block_start * c_stride_elements + col_block_start,
|
|
3422
|
+
c_stride_elements, valid_rows_count, 32);
|
|
3423
|
+
}
|
|
3424
|
+
}
|
|
3425
|
+
|
|
3426
|
+
// Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
|
|
3427
|
+
if (column_tiles_count % 2 == 1) {
|
|
3428
|
+
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
3429
|
+
nk_size_t const col_start = column_tile_idx * 16;
|
|
3430
|
+
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
3431
|
+
|
|
3432
|
+
nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
3433
|
+
_tile_zero(4);
|
|
3434
|
+
_tile_zero(6);
|
|
3435
|
+
|
|
3436
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3437
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3438
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3439
|
+
|
|
3440
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3441
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
3442
|
+
if (rows_in_lower_tile > 0) {
|
|
3443
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
|
|
3444
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3445
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
3446
|
+
}
|
|
3447
|
+
|
|
3448
|
+
nk_dots_i8_b64x16_sapphireamx_t const *b_tile =
|
|
3449
|
+
(nk_dots_i8_b64x16_sapphireamx_t const *)(b_tiles_base +
|
|
3450
|
+
(b_column_base + depth_tile_idx) * tile_size);
|
|
3451
|
+
|
|
3452
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
3453
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
3454
|
+
_tile_loadd(2, b_tile->data, 64);
|
|
3455
|
+
|
|
3456
|
+
_tile_dpbssd(4, 0, 2);
|
|
3457
|
+
_tile_dpbssd(6, 1, 2);
|
|
3458
|
+
}
|
|
3459
|
+
|
|
3460
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
3461
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
3462
|
+
|
|
3463
|
+
nk_dots_e2m3_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
|
|
3464
|
+
c_stride_elements, rows_in_upper_tile, 16);
|
|
3465
|
+
if (rows_in_lower_tile > 0) {
|
|
3466
|
+
nk_dots_e2m3_store_sapphireamx_(&c_lower_state,
|
|
3467
|
+
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
3468
|
+
c_stride_elements, rows_in_lower_tile, 16);
|
|
3469
|
+
}
|
|
3470
|
+
}
|
|
3471
|
+
|
|
3472
|
+
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
3473
|
+
if (column_remainder_count > 0) {
|
|
3474
|
+
nk_dots_i8_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
3475
|
+
nk_dots_i8_a16x64_sapphireamx_t b_as_a;
|
|
3476
|
+
nk_dots_i8_b64x16_sapphireamx_t b_tile;
|
|
3477
|
+
|
|
3478
|
+
_tile_zero(4);
|
|
3479
|
+
_tile_zero(6);
|
|
3480
|
+
|
|
3481
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3482
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3483
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3484
|
+
|
|
3485
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3486
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
3487
|
+
if (rows_in_lower_tile > 0) {
|
|
3488
|
+
nk_dots_e2m3_load_a_sapphireamx_(&a_tile_lower,
|
|
3489
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3490
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
3491
|
+
}
|
|
3492
|
+
|
|
3493
|
+
// B edge data is already in I8 format
|
|
3494
|
+
nk_dots_i8_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
3495
|
+
valid_depth);
|
|
3496
|
+
nk_dots_pack_i8_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
3497
|
+
|
|
3498
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
3499
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
3500
|
+
_tile_loadd(2, b_tile.data, 64);
|
|
3501
|
+
|
|
3502
|
+
_tile_dpbssd(4, 0, 2);
|
|
3503
|
+
_tile_dpbssd(6, 1, 2);
|
|
3504
|
+
}
|
|
3505
|
+
|
|
3506
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
3507
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
3508
|
+
|
|
3509
|
+
nk_dots_e2m3_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
|
|
3510
|
+
c_stride_elements, rows_in_upper_tile, column_remainder_count);
|
|
3511
|
+
if (rows_in_lower_tile > 0) {
|
|
3512
|
+
nk_dots_e2m3_store_sapphireamx_(&c_lower_state,
|
|
3513
|
+
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
3514
|
+
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
3515
|
+
}
|
|
3516
|
+
}
|
|
3517
|
+
}
|
|
3518
|
+
|
|
3519
|
+
_tile_release();
|
|
3520
|
+
}
|
|
3521
|
+
|
|
3522
|
+
NK_PUBLIC void nk_dots_symmetric_e2m3_sapphireamx( //
|
|
3523
|
+
nk_e2m3_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
|
|
3524
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
|
|
3525
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
3526
|
+
|
|
3527
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
3528
|
+
|
|
3529
|
+
// Handle row slicing: compute rows [row_start, row_end)
|
|
3530
|
+
nk_size_t const row_end = (row_count == 0)
|
|
3531
|
+
? n_vectors
|
|
3532
|
+
: (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
|
|
3533
|
+
|
|
3534
|
+
// Round depth up to multiple of 192 (3 tiles x 64 elements)
|
|
3535
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 64);
|
|
3536
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
3537
|
+
|
|
3538
|
+
nk_dots_i8_a16x64_sapphireamx_t a_tiles[3];
|
|
3539
|
+
nk_dots_i8_a16x64_sapphireamx_t b_src_tiles[3];
|
|
3540
|
+
nk_dots_i8_b64x16_sapphireamx_t b_tiles[3];
|
|
3541
|
+
nk_dots_i8_state_sapphireamx_t state;
|
|
3542
|
+
|
|
3543
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
3544
|
+
|
|
3545
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
3546
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
3547
|
+
|
|
3548
|
+
for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
|
|
3549
|
+
nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
|
|
3550
|
+
|
|
3551
|
+
nk_dots_i8_init_sapphireamx_(&state);
|
|
3552
|
+
|
|
3553
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
3554
|
+
nk_size_t const depth_base = depth_group_idx * 192;
|
|
3555
|
+
|
|
3556
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
3557
|
+
nk_size_t const depth_start = depth_base + tile_idx * 64;
|
|
3558
|
+
nk_size_t const valid_depth = (depth_start + 64 <= depth)
|
|
3559
|
+
? 64
|
|
3560
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
3561
|
+
|
|
3562
|
+
nk_dots_e2m3_load_a_sapphireamx_( //
|
|
3563
|
+
&a_tiles[tile_idx], //
|
|
3564
|
+
vectors + row_tile * stride + depth_start, //
|
|
3565
|
+
stride, valid_rows, valid_depth);
|
|
3566
|
+
|
|
3567
|
+
if (row_tile == col_tile) {
|
|
3568
|
+
nk_dots_pack_i8_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3569
|
+
}
|
|
3570
|
+
else {
|
|
3571
|
+
nk_dots_e2m3_load_a_sapphireamx_( //
|
|
3572
|
+
&b_src_tiles[tile_idx], //
|
|
3573
|
+
vectors + col_tile * stride + depth_start, //
|
|
3574
|
+
stride, valid_cols, valid_depth);
|
|
3575
|
+
nk_dots_pack_i8_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3576
|
+
}
|
|
3577
|
+
}
|
|
3578
|
+
|
|
3579
|
+
nk_dots_i8_update_sapphireamx_( //
|
|
3580
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
3581
|
+
}
|
|
3582
|
+
|
|
3583
|
+
nk_dots_e2m3_store_sapphireamx_( //
|
|
3584
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
3585
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
3586
|
+
}
|
|
3587
|
+
}
|
|
3588
|
+
}
|
|
3589
|
+
|
|
3590
|
+
#pragma endregion // Micro Precision E2M3
|
|
3591
|
+
|
|
3592
|
+
#pragma region Micro Precision E3M2
|
|
3593
|
+
|
|
3594
|
+
/* Load E3M2 A tile with FP8 to BF16 conversion */
|
|
3595
|
+
NK_INTERNAL void nk_dots_e3m2_load_a_sapphireamx_( //
|
|
3596
|
+
nk_dots_bf16_a16x32_sapphireamx_t *a_tile, //
|
|
3597
|
+
nk_e3m2_t const *src, nk_size_t src_stride, //
|
|
3598
|
+
nk_size_t valid_rows, nk_size_t valid_cols) {
|
|
3599
|
+
|
|
3600
|
+
__mmask32 column_mask = (valid_cols >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << valid_cols) - 1;
|
|
3601
|
+
__m512i zero = _mm512_setzero_si512();
|
|
3602
|
+
|
|
3603
|
+
for (nk_size_t row_idx = 0; row_idx < 16; row_idx++) {
|
|
3604
|
+
if (row_idx < valid_rows) {
|
|
3605
|
+
__m256i e3m2_row = _mm256_maskz_loadu_epi8(column_mask, src + row_idx * src_stride);
|
|
3606
|
+
__m512i bf16_row = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row);
|
|
3607
|
+
_mm512_store_si512((__m512i *)a_tile->data[row_idx], bf16_row);
|
|
3608
|
+
}
|
|
3609
|
+
else { _mm512_store_si512((__m512i *)a_tile->data[row_idx], zero); }
|
|
3610
|
+
}
|
|
3611
|
+
nk_compiler_barrier_sapphireamx_();
|
|
3612
|
+
}
|
|
3613
|
+
|
|
3614
|
+
/**
 * Returns the size in bytes of a packed E3M2 B matrix.
 *
 * The E3M2 packer widens values to BF16 and uses the BF16 tile layout, so the required size equals the
 * BF16 packed size for the same shape.
 */
NK_PUBLIC nk_size_t nk_dots_packed_size_e3m2_sapphireamx(nk_size_t column_count, nk_size_t depth) {
    nk_size_t const packed_bytes = nk_dots_packed_size_bf16_sapphireamx(column_count, depth);
    return packed_bytes;
}
|
|
3617
|
+
|
|
3618
|
+
NK_PUBLIC void nk_dots_pack_e3m2_sapphireamx( //
|
|
3619
|
+
nk_e3m2_t const *b, nk_size_t column_count, nk_size_t depth, //
|
|
3620
|
+
nk_size_t b_stride, void *b_packed) {
|
|
3621
|
+
|
|
3622
|
+
nk_size_t const tmm_rows = 16;
|
|
3623
|
+
nk_size_t const tmm_cols = 32;
|
|
3624
|
+
nk_size_t const tile_elements = 512;
|
|
3625
|
+
nk_size_t const tile_bytes = tile_elements * sizeof(nk_bf16_t);
|
|
3626
|
+
|
|
3627
|
+
nk_size_t const column_tiles_count = column_count / tmm_rows;
|
|
3628
|
+
nk_size_t const depth_tiles_count = nk_size_divide_round_up_(depth, tmm_cols);
|
|
3629
|
+
nk_size_t const column_remainder_count = column_count - column_tiles_count * tmm_rows;
|
|
3630
|
+
nk_size_t const total_tiles = column_tiles_count * depth_tiles_count;
|
|
3631
|
+
|
|
3632
|
+
nk_dots_amx_packed_header_t *header = (nk_dots_amx_packed_header_t *)b_packed;
|
|
3633
|
+
header->full_column_tiles = (nk_u32_t)column_tiles_count;
|
|
3634
|
+
header->full_depth_tiles = (nk_u32_t)depth_tiles_count;
|
|
3635
|
+
header->column_remainder_count = (nk_u32_t)column_remainder_count;
|
|
3636
|
+
|
|
3637
|
+
nk_size_t const tiles_offset = sizeof(nk_dots_amx_packed_header_t);
|
|
3638
|
+
nk_size_t const column_edge_offset = tiles_offset + total_tiles * tile_bytes;
|
|
3639
|
+
header->column_edge_offset = (nk_u32_t)column_edge_offset;
|
|
3640
|
+
|
|
3641
|
+
nk_bf16_t *tiles_ptr = (nk_bf16_t *)((char *)b_packed + tiles_offset);
|
|
3642
|
+
nk_bf16_t *column_edge_ptr = (nk_bf16_t *)((char *)b_packed + column_edge_offset);
|
|
3643
|
+
|
|
3644
|
+
for (nk_size_t idx = 0; idx < total_tiles * tile_elements; idx++) tiles_ptr[idx] = 0;
|
|
3645
|
+
|
|
3646
|
+
for (nk_size_t column_tile_idx = 0; column_tile_idx < column_tiles_count; column_tile_idx++) {
|
|
3647
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3648
|
+
nk_size_t const tile_index = column_tile_idx * depth_tiles_count + depth_tile_idx;
|
|
3649
|
+
nk_bf16_t *tile_output = tiles_ptr + tile_index * tile_elements;
|
|
3650
|
+
|
|
3651
|
+
nk_size_t const src_row_start = column_tile_idx * tmm_rows;
|
|
3652
|
+
nk_size_t const src_column_start = depth_tile_idx * tmm_cols;
|
|
3653
|
+
nk_size_t const columns_to_pack = (src_column_start + tmm_cols <= depth) ? tmm_cols
|
|
3654
|
+
: (depth - src_column_start);
|
|
3655
|
+
|
|
3656
|
+
for (nk_size_t row_idx = 0; row_idx < tmm_rows; row_idx++) {
|
|
3657
|
+
nk_size_t src_row = src_row_start + row_idx;
|
|
3658
|
+
__mmask32 column_mask = (columns_to_pack >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns_to_pack) - 1;
|
|
3659
|
+
__m256i e3m2_row = _mm256_maskz_loadu_epi8(column_mask, b + src_row * b_stride + src_column_start);
|
|
3660
|
+
__m512i bf16_row = nk_e3m2x32_to_bf16x32_icelake_(e3m2_row);
|
|
3661
|
+
nk_bf16_t bf16_buf[32];
|
|
3662
|
+
_mm512_storeu_si512((__m512i *)bf16_buf, bf16_row);
|
|
3663
|
+
for (nk_size_t column_idx = 0; column_idx < columns_to_pack; column_idx++) {
|
|
3664
|
+
nk_size_t const dst_idx = (column_idx / 2) * 32 + row_idx * 2 + (column_idx % 2);
|
|
3665
|
+
tile_output[dst_idx] = bf16_buf[column_idx];
|
|
3666
|
+
}
|
|
3667
|
+
}
|
|
3668
|
+
}
|
|
3669
|
+
}
|
|
3670
|
+
|
|
3671
|
+
if (column_remainder_count > 0) {
|
|
3672
|
+
nk_size_t const remainder_start_row = column_tiles_count * tmm_rows;
|
|
3673
|
+
for (nk_size_t row_idx = 0; row_idx < column_remainder_count; row_idx++) {
|
|
3674
|
+
for (nk_size_t column_idx = 0; column_idx < depth; column_idx += 32) {
|
|
3675
|
+
nk_size_t columns = (column_idx + 32 <= depth) ? 32 : (depth - column_idx);
|
|
3676
|
+
__mmask32 column_mask = (columns >= 32) ? 0xFFFFFFFF : ((__mmask32)1 << columns) - 1;
|
|
3677
|
+
__m256i e3m2_chunk = _mm256_maskz_loadu_epi8(
|
|
3678
|
+
column_mask, b + (remainder_start_row + row_idx) * b_stride + column_idx);
|
|
3679
|
+
__m512i bf16_chunk = nk_e3m2x32_to_bf16x32_icelake_(e3m2_chunk);
|
|
3680
|
+
_mm512_mask_storeu_epi16(column_edge_ptr + row_idx * depth + column_idx, column_mask, bf16_chunk);
|
|
3681
|
+
}
|
|
3682
|
+
}
|
|
3683
|
+
}
|
|
3684
|
+
|
|
3685
|
+
// Compute and store per-column norms for angular/euclidean distance
|
|
3686
|
+
nk_size_t norms_offset = column_edge_offset +
|
|
3687
|
+
(column_remainder_count > 0 ? column_remainder_count * depth * sizeof(nk_bf16_t) : 0);
|
|
3688
|
+
header->norms_byte_offset = (nk_u32_t)norms_offset;
|
|
3689
|
+
nk_f32_t *norms = (nk_f32_t *)((char *)b_packed + norms_offset);
|
|
3690
|
+
for (nk_size_t col = 0; col < column_count; col++)
|
|
3691
|
+
norms[col] = nk_dots_reduce_sumsq_e3m2_(b + col * b_stride, depth);
|
|
3692
|
+
}
|
|
3693
|
+
|
|
3694
|
+
NK_PUBLIC void nk_dots_packed_e3m2_sapphireamx( //
|
|
3695
|
+
nk_e3m2_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
3696
|
+
nk_size_t rows_count, nk_size_t cols_count, nk_size_t depth, nk_size_t a_stride_bytes, nk_size_t c_stride_bytes) {
|
|
3697
|
+
nk_unused_(cols_count);
|
|
3698
|
+
|
|
3699
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
3700
|
+
nk_size_t const column_tiles_count = header->full_column_tiles;
|
|
3701
|
+
nk_size_t const depth_tiles_count = header->full_depth_tiles;
|
|
3702
|
+
nk_size_t const column_remainder_count = header->column_remainder_count;
|
|
3703
|
+
|
|
3704
|
+
nk_bf16_t const *b_tiles_base = (nk_bf16_t const *)((char const *)b_packed + sizeof(nk_dots_amx_packed_header_t));
|
|
3705
|
+
nk_bf16_t const *col_edge_ptr = (nk_bf16_t const *)((char const *)b_packed + header->column_edge_offset);
|
|
3706
|
+
|
|
3707
|
+
nk_size_t const c_stride_elements = c_stride_bytes / sizeof(nk_f32_t);
|
|
3708
|
+
nk_size_t const tile_depth = 32;
|
|
3709
|
+
nk_size_t const tile_size = 512;
|
|
3710
|
+
nk_size_t const full_cols = column_tiles_count * 16;
|
|
3711
|
+
|
|
3712
|
+
nk_size_t const row_blocks_count = nk_size_divide_round_up_(rows_count, 32);
|
|
3713
|
+
nk_size_t const col_blocks_count = column_tiles_count / 2;
|
|
3714
|
+
|
|
3715
|
+
if (depth_tiles_count == 0) return;
|
|
3716
|
+
|
|
3717
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tile_upper, a_tile_lower;
|
|
3718
|
+
nk_dots_bf16_state2x2_sapphireamx_t c_accum_buffer;
|
|
3719
|
+
|
|
3720
|
+
nk_size_t const full_depth_tiles_count = depth / tile_depth;
|
|
3721
|
+
nk_size_t const depth_remainder = depth % tile_depth;
|
|
3722
|
+
|
|
3723
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
3724
|
+
|
|
3725
|
+
// Loop order: row_blocks outer, col_blocks inner
|
|
3726
|
+
for (nk_size_t row_block_idx = 0; row_block_idx < row_blocks_count; row_block_idx++) {
|
|
3727
|
+
nk_size_t const row_block_start = row_block_idx * 32;
|
|
3728
|
+
nk_size_t const valid_rows_count = (row_block_start + 32 <= rows_count) ? 32 : (rows_count - row_block_start);
|
|
3729
|
+
nk_size_t const is_full_row_block = (valid_rows_count == 32);
|
|
3730
|
+
nk_size_t const rows_in_upper_tile = (valid_rows_count > 16) ? 16 : valid_rows_count;
|
|
3731
|
+
nk_size_t const rows_in_lower_tile = (valid_rows_count > 16) ? valid_rows_count - 16 : 0;
|
|
3732
|
+
|
|
3733
|
+
for (nk_size_t column_block_idx = 0; column_block_idx < col_blocks_count; column_block_idx++) {
|
|
3734
|
+
nk_size_t const col_block_start = column_block_idx * 32;
|
|
3735
|
+
nk_size_t const b_column_left_base = (column_block_idx * 2) * depth_tiles_count;
|
|
3736
|
+
nk_size_t const b_column_right_base = (column_block_idx * 2 + 1) * depth_tiles_count;
|
|
3737
|
+
|
|
3738
|
+
// Zero accumulators (TMM4-7 stay resident across entire depth loop)
|
|
3739
|
+
_tile_zero(4);
|
|
3740
|
+
_tile_zero(5);
|
|
3741
|
+
_tile_zero(6);
|
|
3742
|
+
_tile_zero(7);
|
|
3743
|
+
|
|
3744
|
+
// FP8 always uses buffered load for E3M2 -> BF16 conversion
|
|
3745
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3746
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3747
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3748
|
+
|
|
3749
|
+
// Load A with FP8 -> BF16 conversion
|
|
3750
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3751
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
3752
|
+
if (rows_in_lower_tile > 0) {
|
|
3753
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
|
|
3754
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3755
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
3756
|
+
}
|
|
3757
|
+
|
|
3758
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_left =
|
|
3759
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
3760
|
+
(b_column_left_base + depth_tile_idx) * tile_size);
|
|
3761
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile_right =
|
|
3762
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
3763
|
+
(b_column_right_base + depth_tile_idx) * tile_size);
|
|
3764
|
+
|
|
3765
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
3766
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
3767
|
+
_tile_loadd(2, b_tile_left->data, 64);
|
|
3768
|
+
_tile_loadd(3, b_tile_right->data, 64);
|
|
3769
|
+
|
|
3770
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
3771
|
+
_tile_dpbf16ps(5, 0, 3);
|
|
3772
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
3773
|
+
_tile_dpbf16ps(7, 1, 3);
|
|
3774
|
+
}
|
|
3775
|
+
|
|
3776
|
+
// Store accumulators to output (once per output block)
|
|
3777
|
+
if (is_full_row_block) {
|
|
3778
|
+
nk_f32_t *c_block = c + row_block_start * c_stride_elements + col_block_start;
|
|
3779
|
+
_tile_stored(4, c_block, c_stride_bytes);
|
|
3780
|
+
_tile_stored(5, c_block + 16, c_stride_bytes);
|
|
3781
|
+
_tile_stored(6, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes), c_stride_bytes);
|
|
3782
|
+
_tile_stored(7, (nk_f32_t *)((char *)c_block + 16 * c_stride_bytes) + 16, c_stride_bytes);
|
|
3783
|
+
}
|
|
3784
|
+
else {
|
|
3785
|
+
_tile_stored(4, c_accum_buffer.c[0][0].data, 64);
|
|
3786
|
+
_tile_stored(5, c_accum_buffer.c[0][1].data, 64);
|
|
3787
|
+
_tile_stored(6, c_accum_buffer.c[1][0].data, 64);
|
|
3788
|
+
_tile_stored(7, c_accum_buffer.c[1][1].data, 64);
|
|
3789
|
+
nk_dots_bf16_output2x2_sapphireamx_(&c_accum_buffer,
|
|
3790
|
+
c + row_block_start * c_stride_elements + col_block_start,
|
|
3791
|
+
c_stride_elements, valid_rows_count, 32);
|
|
3792
|
+
}
|
|
3793
|
+
}
|
|
3794
|
+
|
|
3795
|
+
// Handle odd column-tile (single 16-column tile if column_tiles_count is odd)
|
|
3796
|
+
if (column_tiles_count % 2 == 1) {
|
|
3797
|
+
nk_size_t const column_tile_idx = column_tiles_count - 1;
|
|
3798
|
+
nk_size_t const col_start = column_tile_idx * 16;
|
|
3799
|
+
nk_size_t const b_column_base = column_tile_idx * depth_tiles_count;
|
|
3800
|
+
|
|
3801
|
+
nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
3802
|
+
_tile_zero(4);
|
|
3803
|
+
_tile_zero(6);
|
|
3804
|
+
|
|
3805
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3806
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3807
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3808
|
+
|
|
3809
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3810
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
3811
|
+
if (rows_in_lower_tile > 0) {
|
|
3812
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
|
|
3813
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3814
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
3815
|
+
}
|
|
3816
|
+
|
|
3817
|
+
nk_dots_bf16_b32x16_sapphireamx_t const *b_tile =
|
|
3818
|
+
(nk_dots_bf16_b32x16_sapphireamx_t const *)(b_tiles_base +
|
|
3819
|
+
(b_column_base + depth_tile_idx) * tile_size);
|
|
3820
|
+
|
|
3821
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
3822
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
3823
|
+
_tile_loadd(2, b_tile->data, 64);
|
|
3824
|
+
|
|
3825
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
3826
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
3827
|
+
}
|
|
3828
|
+
|
|
3829
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
3830
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
3831
|
+
|
|
3832
|
+
nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + col_start,
|
|
3833
|
+
c_stride_elements, rows_in_upper_tile, 16);
|
|
3834
|
+
if (rows_in_lower_tile > 0) {
|
|
3835
|
+
nk_dots_bf16_store_sapphireamx_(&c_lower_state,
|
|
3836
|
+
c + (row_block_start + 16) * c_stride_elements + col_start,
|
|
3837
|
+
c_stride_elements, rows_in_lower_tile, 16);
|
|
3838
|
+
}
|
|
3839
|
+
}
|
|
3840
|
+
|
|
3841
|
+
// Handle column-edge (remaining columns < 16) using AMX with partial tiles
|
|
3842
|
+
if (column_remainder_count > 0) {
|
|
3843
|
+
nk_dots_bf16_state_sapphireamx_t c_upper_state, c_lower_state;
|
|
3844
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_as_a;
|
|
3845
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tile;
|
|
3846
|
+
|
|
3847
|
+
_tile_zero(4);
|
|
3848
|
+
_tile_zero(6);
|
|
3849
|
+
|
|
3850
|
+
for (nk_size_t depth_tile_idx = 0; depth_tile_idx < depth_tiles_count; depth_tile_idx++) {
|
|
3851
|
+
nk_size_t const depth_offset = depth_tile_idx * tile_depth;
|
|
3852
|
+
nk_size_t const valid_depth = (depth_tile_idx < full_depth_tiles_count) ? tile_depth : depth_remainder;
|
|
3853
|
+
|
|
3854
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_upper, a + row_block_start * a_stride_bytes + depth_offset,
|
|
3855
|
+
a_stride_bytes, rows_in_upper_tile, valid_depth);
|
|
3856
|
+
if (rows_in_lower_tile > 0) {
|
|
3857
|
+
nk_dots_e3m2_load_a_sapphireamx_(&a_tile_lower,
|
|
3858
|
+
a + (row_block_start + 16) * a_stride_bytes + depth_offset,
|
|
3859
|
+
a_stride_bytes, rows_in_lower_tile, valid_depth);
|
|
3860
|
+
}
|
|
3861
|
+
|
|
3862
|
+
nk_dots_bf16_load_a_sapphireamx_(&b_as_a, col_edge_ptr + depth_offset, depth, column_remainder_count,
|
|
3863
|
+
valid_depth);
|
|
3864
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_as_a, &b_tile);
|
|
3865
|
+
|
|
3866
|
+
_tile_loadd(0, a_tile_upper.data, 64);
|
|
3867
|
+
_tile_loadd(1, a_tile_lower.data, 64);
|
|
3868
|
+
_tile_loadd(2, b_tile.data, 64);
|
|
3869
|
+
|
|
3870
|
+
_tile_dpbf16ps(4, 0, 2);
|
|
3871
|
+
_tile_dpbf16ps(6, 1, 2);
|
|
3872
|
+
}
|
|
3873
|
+
|
|
3874
|
+
_tile_stored(4, c_upper_state.data, 64);
|
|
3875
|
+
_tile_stored(6, c_lower_state.data, 64);
|
|
3876
|
+
|
|
3877
|
+
nk_dots_bf16_store_sapphireamx_(&c_upper_state, c + row_block_start * c_stride_elements + full_cols,
|
|
3878
|
+
c_stride_elements, rows_in_upper_tile, column_remainder_count);
|
|
3879
|
+
if (rows_in_lower_tile > 0) {
|
|
3880
|
+
nk_dots_bf16_store_sapphireamx_(&c_lower_state,
|
|
3881
|
+
c + (row_block_start + 16) * c_stride_elements + full_cols,
|
|
3882
|
+
c_stride_elements, rows_in_lower_tile, column_remainder_count);
|
|
3883
|
+
}
|
|
3884
|
+
}
|
|
3885
|
+
}
|
|
3886
|
+
|
|
3887
|
+
_tile_release();
|
|
3888
|
+
}
|
|
3889
|
+
|
|
3890
|
+
NK_PUBLIC void nk_dots_symmetric_e3m2_sapphireamx( //
|
|
3891
|
+
nk_e3m2_t const *vectors, nk_size_t n_vectors, nk_size_t depth, //
|
|
3892
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride, //
|
|
3893
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
3894
|
+
|
|
3895
|
+
nk_size_t const stride_elements = stride; // sizeof(nk_e3m2_t) == 1, so bytes == elements
|
|
3896
|
+
nk_size_t const result_stride_elements = result_stride / sizeof(nk_f32_t);
|
|
3897
|
+
|
|
3898
|
+
// Handle row slicing: compute rows [row_start, row_end)
|
|
3899
|
+
nk_size_t const row_end = (row_count == 0)
|
|
3900
|
+
? n_vectors
|
|
3901
|
+
: (row_start + row_count < n_vectors ? row_start + row_count : n_vectors);
|
|
3902
|
+
|
|
3903
|
+
// Round depth up to multiple of 96 (3 tiles x 32 bf16 elements)
|
|
3904
|
+
nk_size_t const depth_tiles = nk_size_divide_round_up_(depth, 32);
|
|
3905
|
+
nk_size_t const depth_tile_groups = nk_size_divide_round_up_(depth_tiles, 3);
|
|
3906
|
+
|
|
3907
|
+
nk_dots_bf16_a16x32_sapphireamx_t a_tiles[3];
|
|
3908
|
+
nk_dots_bf16_a16x32_sapphireamx_t b_src_tiles[3];
|
|
3909
|
+
nk_dots_bf16_b32x16_sapphireamx_t b_tiles[3];
|
|
3910
|
+
nk_dots_bf16_state_sapphireamx_t state;
|
|
3911
|
+
|
|
3912
|
+
nk_amx_tile_configure_sapphireamx_();
|
|
3913
|
+
|
|
3914
|
+
for (nk_size_t row_tile = row_start; row_tile < row_end; row_tile += 16) {
|
|
3915
|
+
nk_size_t const valid_rows = (row_tile + 16 <= row_end) ? 16 : (row_end - row_tile);
|
|
3916
|
+
|
|
3917
|
+
for (nk_size_t col_tile = 0; col_tile < n_vectors; col_tile += 16) {
|
|
3918
|
+
nk_size_t const valid_cols = (col_tile + 16 <= n_vectors) ? 16 : (n_vectors - col_tile);
|
|
3919
|
+
|
|
3920
|
+
nk_dots_bf16_init_sapphireamx_(&state);
|
|
3921
|
+
|
|
3922
|
+
for (nk_size_t depth_group_idx = 0; depth_group_idx < depth_tile_groups; depth_group_idx++) {
|
|
3923
|
+
nk_size_t const depth_base = depth_group_idx * 96;
|
|
3924
|
+
|
|
3925
|
+
for (int tile_idx = 0; tile_idx < 3; tile_idx++) {
|
|
3926
|
+
nk_size_t const depth_start = depth_base + tile_idx * 32;
|
|
3927
|
+
nk_size_t const valid_depth = (depth_start + 32 <= depth)
|
|
3928
|
+
? 32
|
|
3929
|
+
: (depth > depth_start ? depth - depth_start : 0);
|
|
3930
|
+
|
|
3931
|
+
nk_dots_e3m2_load_a_sapphireamx_( //
|
|
3932
|
+
&a_tiles[tile_idx], //
|
|
3933
|
+
vectors + row_tile * stride_elements + depth_start, //
|
|
3934
|
+
stride_elements, valid_rows, valid_depth);
|
|
3935
|
+
|
|
3936
|
+
if (row_tile == col_tile) {
|
|
3937
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&a_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3938
|
+
}
|
|
3939
|
+
else {
|
|
3940
|
+
nk_dots_e3m2_load_a_sapphireamx_( //
|
|
3941
|
+
&b_src_tiles[tile_idx], //
|
|
3942
|
+
vectors + col_tile * stride_elements + depth_start, //
|
|
3943
|
+
stride_elements, valid_cols, valid_depth);
|
|
3944
|
+
nk_dots_pack_bf16_transposed_sapphireamx_(&b_src_tiles[tile_idx], &b_tiles[tile_idx]);
|
|
3945
|
+
}
|
|
3946
|
+
}
|
|
3947
|
+
|
|
3948
|
+
nk_dots_bf16_update_sapphireamx_( //
|
|
3949
|
+
&state, &a_tiles[0], &a_tiles[1], &a_tiles[2], &b_tiles[0], &b_tiles[1], &b_tiles[2]);
|
|
3950
|
+
}
|
|
3951
|
+
|
|
3952
|
+
nk_dots_bf16_store_sapphireamx_( //
|
|
3953
|
+
&state, result + row_tile * result_stride_elements + col_tile, //
|
|
3954
|
+
result_stride_elements, valid_rows, valid_cols);
|
|
3955
|
+
}
|
|
3956
|
+
}
|
|
3957
|
+
}
|
|
3958
|
+
|
|
3959
|
+
#pragma endregion // Micro Precision E3M2
|
|
3960
|
+
|
|
3961
|
+
#if defined(__clang__)
|
|
3962
|
+
#pragma clang attribute pop
|
|
3963
|
+
#elif defined(__GNUC__)
|
|
3964
|
+
#pragma GCC pop_options
|
|
3965
|
+
#endif
|
|
3966
|
+
|
|
3967
|
+
#if defined(__cplusplus)
|
|
3968
|
+
} // extern "C"
|
|
3969
|
+
#endif
|
|
3970
|
+
|
|
3971
|
+
#endif // NK_TARGET_SAPPHIREAMX
|
|
3972
|
+
#endif // NK_TARGET_X86_
|
|
3973
|
+
#endif // NK_DOTS_SAPPHIREAMX_H
|