numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
package/include/numkong/dots/serial.h
@@ -0,0 +1,2844 @@
/**
 * @brief SWAR-accelerated Batched Dot Products for SIMD-free CPUs.
 * @file include/numkong/dots/serial.h
 * @author Ash Vardanian
 * @date December 27, 2025
 *
 * @sa include/numkong/dots.h for API overview and use cases
 *
 * This file provides two macro families for generating GEMM kernels:
 *
 * - nk_define_cross_packed_: vectorized inner-products between rows of A and Bᵀ
 * - nk_define_cross_symmetric_: vectorized inner-products between rows and columns of A
 *
 * Both use the same B packing format (see below), enabling pack-once-use-anywhere.
 *
 * @section packing B Matrix Packing Format
 *
 * Computing C = A × Bᵀ where:
 *
 * - A[row_count, depth] row-major: A[i, k] at address A + i × lda + k
 * - B[column_count, depth] row-major (pre-transposed): B[j, k] at address B + j × ldb + k
 * - C[row_count, column_count] row-major: C[i, j] at address C + i × ldc + j
 *
 * The API convention stores B as Bᵀ for efficient SIMD access:
 *
 * - A[i, k:k+4] is contiguous in row-major (good)
 * - B[j, k:k+4] is contiguous in row-major (good - already transposed)
 *
 * Packing adds row grouping (group_size = 16) for:
 *
 * - Zero-padding on edges (avoids boundary checks in inner loop)
 * - Cache-friendly blocking in outer loops
 *
 * Memory layout example - B[8, 8] with 8 output columns (j), 8 depth (k):
 *
 *         k=0  k=1  k=2  k=3  k=4  k=5  k=6  k=7
 *       ┌─────────────────────────────────────────┐
 *   j=0 │  a0   a1   a2   a3   a4   a5   a6   a7  │
 *   j=1 │  b0   b1   b2   b3   b4   b5   b6   b7  │
 *   j=2 │  c0   c1   c2   c3   c4   c5   c6   c7  │
 *   j=3 │  d0   d1   d2   d3   d4   d5   d6   d7  │
 *   j=4 │  e0   e1   e2   e3   e4   e5   e6   e7  │
 *   j=5 │  f0   f1   f2   f3   f4   f5   f6   f7  │
 *   j=6 │  g0   g1   g2   g3   g4   g5   g6   g7  │
 *   j=7 │  h0   h1   h2   h3   h4   h5   h6   h7  │
 *       └─────────────────────────────────────────┘
 *
 * Packed as B_packed[column_count_padded, depth] (grouped for alignment):
 *
 * Group 0 (j=0..7, padded to 16):
 *   ┌──────────────────────────────┬─────────┐
 *   │  a0 a1 a2 a3 a4 a5 a6 a7     │ j=0     │ ← row 0 copied as-is
 *   │  b0 b1 b2 b3 b4 b5 b6 b7     │ j=1     │
 *   │  c0 c1 c2 c3 c4 c5 c6 c7     │ j=2     │
 *   │  d0 d1 d2 d3 d4 d5 d6 d7     │ j=3     │
 *   │  e0 e1 e2 e3 e4 e5 e6 e7     │ j=4     │
 *   │  f0 f1 f2 f3 f4 f5 f6 f7     │ j=5     │
 *   │  g0 g1 g2 g3 g4 g5 g6 g7     │ j=6     │
 *   │  h0 h1 h2 h3 h4 h5 h6 h7     │ j=7     │
 *   │  00 00 00 00 00 00 00 00     │ padding │
 *   │  ...                         │ ...     │
 *   └──────────────────────────────┴─────────┘
 *
 * Addressing formula for B_packed[j, k]:
 *
 *     group = j / group_size
 *     j_in_group = j % group_size
 *     B_packed[j, k] = packed[group * group_size * depth + j_in_group * depth + k]
 *
 * Inner loop accesses B_packed[j, k:k+simd] which is contiguous - just ptr + k.
 */
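
/* Worked example of the addressing formula above (illustrative sketch): with
 * group_size = 16 and depth = 8, element B_packed[19, 3] falls in group
 * 19 / 16 = 1 at j_in_group = 19 % 16 = 3, so its offset from the start of
 * the packed payload is 1 * 16 * 8 + 3 * 8 + 3 = 155 values. */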

#ifndef NK_DOTS_SERIAL_H
#define NK_DOTS_SERIAL_H

#include "numkong/types.h"
#include "numkong/cast/serial.h"    // `nk_partial_load_b32x4_serial_`
#include "numkong/dot/serial.h"     // `nk_dot_f32x4_state_serial_t`
#include "numkong/spatial/serial.h" // `nk_f32_sqrt_serial`
#include "numkong/reduce.h"         // `nk_reduce_moments_*`

#if defined(__cplusplus)
extern "C" {
#endif

/* Packed buffer header (64-byte aligned).
 * Used by all packed matmul backends (serial, NEON, AVX-512, SVE).
 *
 * Important units clarification:
 * - For types where dimensions_per_value = 1 (f32, i8, u8, etc.): dimensions == values
 * - For sub-byte types (i4x2, u4x2): dimensions ≠ values
 *   - dimensions = individual 4-bit nibbles (e.g., 128 nibbles)
 *   - values = storage bytes containing nibbles (e.g., 64 bytes for 128 nibbles)
 *   - dimensions_per_value = 2 (2 nibbles per byte)
 */
typedef struct {
    nk_u32_t column_count;        // Actual number of columns (not padded)
    nk_u32_t depth_dimensions;    // Logical depth in dimensions (nibbles for i4/u4, values for i8/f32)
    nk_u32_t depth_padded_values; // Padded depth in storage values (bytes for i4/u4, values for i8/f32)
    nk_u32_t reserved[13];        // Padding to 64 bytes
} nk_cross_packed_buffer_header_t;
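
/* Compile-time check of the 64-byte layout claimed above (illustrative sketch,
 * assuming a C11 toolchain; the original file may guard this differently):
 * 3 + 13 = 16 nk_u32_t fields, i.e. exactly 64 bytes. */
#if !defined(__cplusplus) && defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)
_Static_assert(sizeof(nk_cross_packed_buffer_header_t) == 64, "packed header must stay 64 bytes");
#endif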

/* Norm compute helpers for packing.
 * Each computes the norm (sum-of-squares or popcount) of a contiguous row.
 * Used by `nk_define_cross_pack_` to append per-column norms to packed buffers.
 */
NK_INTERNAL nk_f64_t nk_dots_reduce_sumsq_f64_(nk_f64_t const *data, nk_size_t count) {
    nk_f64_t sum, sumsq;
    nk_reduce_moments_f64(data, count, sizeof(nk_f64_t), &sum, &sumsq);
    return sumsq;
}
NK_INTERNAL nk_f64_t nk_dots_reduce_sumsq_f32_(nk_f32_t const *data, nk_size_t count) {
    nk_f64_t sum, sumsq;
    nk_reduce_moments_f32(data, count, sizeof(nk_f32_t), &sum, &sumsq);
    return sumsq;
}
NK_INTERNAL nk_f32_t nk_dots_reduce_sumsq_f16_(nk_f16_t const *data, nk_size_t count) {
    nk_f32_t sum, sumsq;
    nk_reduce_moments_f16(data, count, sizeof(nk_f16_t), &sum, &sumsq);
    return sumsq;
}
NK_INTERNAL nk_f32_t nk_dots_reduce_sumsq_bf16_(nk_bf16_t const *data, nk_size_t count) {
    nk_f32_t sum, sumsq;
    nk_reduce_moments_bf16(data, count, sizeof(nk_bf16_t), &sum, &sumsq);
    return sumsq;
}
NK_INTERNAL nk_f32_t nk_dots_reduce_sumsq_e4m3_(nk_e4m3_t const *data, nk_size_t count) {
    nk_f32_t sum, sumsq;
    nk_reduce_moments_e4m3(data, count, sizeof(nk_e4m3_t), &sum, &sumsq);
    return sumsq;
}
NK_INTERNAL nk_f32_t nk_dots_reduce_sumsq_e5m2_(nk_e5m2_t const *data, nk_size_t count) {
    nk_f32_t sum, sumsq;
    nk_reduce_moments_e5m2(data, count, sizeof(nk_e5m2_t), &sum, &sumsq);
    return sumsq;
}
NK_INTERNAL nk_f32_t nk_dots_reduce_sumsq_e2m3_(nk_e2m3_t const *data, nk_size_t count) {
    nk_f32_t sum, sumsq;
    nk_reduce_moments_e2m3(data, count, sizeof(nk_e2m3_t), &sum, &sumsq);
    return sumsq;
}
NK_INTERNAL nk_f32_t nk_dots_reduce_sumsq_e3m2_(nk_e3m2_t const *data, nk_size_t count) {
    nk_f32_t sum, sumsq;
    nk_reduce_moments_e3m2(data, count, sizeof(nk_e3m2_t), &sum, &sumsq);
    return sumsq;
}
NK_INTERNAL nk_u32_t nk_dots_reduce_sumsq_i8_(nk_i8_t const *data, nk_size_t count) {
    nk_i64_t sum;
    nk_u64_t sumsq;
    nk_reduce_moments_i8(data, count, sizeof(nk_i8_t), &sum, &sumsq);
    return (nk_u32_t)sumsq;
}
NK_INTERNAL nk_u32_t nk_dots_reduce_sumsq_u8_(nk_u8_t const *data, nk_size_t count) {
    nk_u64_t sum, sumsq;
    nk_reduce_moments_u8(data, count, sizeof(nk_u8_t), &sum, &sumsq);
    return (nk_u32_t)sumsq;
}
NK_INTERNAL nk_u32_t nk_dots_reduce_sumsq_i4_(nk_i4x2_t const *data, nk_size_t count) {
    nk_i64_t sum;
    nk_u64_t sumsq;
    nk_reduce_moments_i4(data, count, sizeof(nk_i4x2_t), &sum, &sumsq);
    return (nk_u32_t)sumsq;
}
NK_INTERNAL nk_u32_t nk_dots_reduce_sumsq_u4_(nk_u4x2_t const *data, nk_size_t count) {
    nk_u64_t sum, sumsq;
    nk_reduce_moments_u4(data, count, sizeof(nk_u4x2_t), &sum, &sumsq);
    return (nk_u32_t)sumsq;
}
NK_INTERNAL nk_u32_t nk_dots_reduce_sum_u1_(nk_u1x8_t const *data, nk_size_t count_bits) {
    nk_u64_t sum, sumsq;
    nk_reduce_moments_u1(data, count_bits, sizeof(nk_u1x8_t), &sum, &sumsq);
    return (nk_u32_t)sum;
}
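
/* Sketch of how the cached norms pay off downstream (illustrative; the helper
 * name is hypothetical, the real consumers are the normalized kernels built on
 * `nk_define_cross_normalized_packed_`, and `nk_f32_sqrt_serial` is assumed to
 * keep an f32 → f32 signature): one dot product plus two pack-time
 * sums-of-squares give a cosine similarity without re-reading either row. */
NK_INTERNAL nk_f32_t nk_dots_cosine_from_norms_sketch_(nk_f32_t dot, nk_f32_t a_sumsq, nk_f32_t b_sumsq) {
    nk_f32_t const denominator = nk_f32_sqrt_serial(a_sumsq * b_sumsq); // norms cached at pack time
    return denominator != 0 ? dot / denominator : 0;
}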

/* Combined moment trampolines for compensated GEMM.
 * Each computes BOTH sum and norm (sum-of-squares) in a single nk_reduce_moments call.
 * Used by nk_define_cross_compensated_pack_ to store both in the packed buffer.
 */
NK_INTERNAL void nk_dots_reduce_moments_i8_(nk_i8_t const *data, nk_size_t count, nk_i32_t *sum, nk_u32_t *norm) {
    nk_i64_t s;
    nk_u64_t sq;
    nk_reduce_moments_i8(data, count, sizeof(nk_i8_t), &s, &sq);
    *sum = (nk_i32_t)s;
    *norm = (nk_u32_t)sq;
}
NK_INTERNAL void nk_dots_reduce_moments_u8_(nk_u8_t const *data, nk_size_t count, nk_u32_t *sum, nk_u32_t *norm) {
    nk_u64_t s, sq;
    nk_reduce_moments_u8(data, count, sizeof(nk_u8_t), &s, &sq);
    *sum = (nk_u32_t)s;
    *norm = (nk_u32_t)sq;
}
NK_INTERNAL void nk_dots_reduce_moments_i4_(nk_i4x2_t const *data, nk_size_t count, nk_i32_t *sum, nk_u32_t *norm) {
    nk_i64_t s;
    nk_u64_t sq;
    nk_reduce_moments_i4(data, count, sizeof(nk_i4x2_t), &s, &sq);
    *sum = (nk_i32_t)s;
    *norm = (nk_u32_t)sq;
}

/* A-row sum helpers for compensated GEMM finalization.
 * i8/u8: no A-side correction needed, stubs return 0.
 * i4: needs A-side sum for correction term.
 */
NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i8_stub_(nk_i8_t const *d, nk_size_t c) {
    nk_unused_(d);
    nk_unused_(c);
    return 0;
}
NK_INTERNAL nk_i32_t nk_dots_reduce_sum_u8_stub_(nk_u8_t const *d, nk_size_t c) {
    nk_unused_(d);
    nk_unused_(c);
    return 0;
}
NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t count) {
    nk_i64_t sum;
    nk_u64_t sumsq;
    nk_reduce_moments_i4(data, count, sizeof(nk_i4x2_t), &sum, &sumsq);
    return (nk_i32_t)sum;
}
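
/* Why i4 needs the A-side sum (an assumed rationale, shown for intuition only;
 * the exact correction lives in the finalization kernels, not here): if signed
 * nibbles are biased by +8 so the inner loop can run on unsigned values,
 *
 *     sum_k (a_k + 8)(b_k + 8) = sum_k a_k b_k + 8 * sum_k a_k + 8 * sum_k b_k + 64 * depth
 *
 * so recovering the true dot product needs both the per-column B sums stored by
 * the compensated pack and a per-row A sum like `nk_dots_reduce_sum_i4_`. */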

/**
 * @brief Generates function to calculate packed B matrix buffer size for GEMM micro-kernels.
 *
 * Memory layout: B_packed[column_count, depth_padded] with header storing metadata.
 * Buffer size: sizeof(header) + column_count × depth_padded × sizeof(packed_value_type) + column_count × sizeof(norm)
 * Depth padding logic: Round up to `depth_simd_dimensions` multiple, then add `depth_simd_dimensions`
 * if stride is power-of-2.
 *
 * @param api_name Operation name (hammings, dots)
 * @param input_type_name Original type's name of B matrix values (i4, f16, bf16, e4m3, e5m2, f32, etc.)
 * @param isa_suffix Platform Instruction Set Architecture suffix (serial, haswell, icelake, etc.)
 * @param input_value_type Original type of B matrix values (i4x2, f16, bf16, e4m3, e5m2, f32, etc.)
 * @param packed_value_type Internal storage type in packed buffer (often bf16 or f32 for mixed precision)
 * @param norm_value_type Type of per-column norm values (f32, f64, u32) appended after packed data
 * @param depth_simd_dimensions SIMD vector width in values for this platform/type combination
 * @param dimensions_per_value Number of logical dimensions in a single value of input_value_type.
 */
#define nk_define_cross_pack_size_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, \
                                   norm_value_type, depth_simd_dimensions, dimensions_per_value) \
    NK_PUBLIC nk_size_t nk_##api_name##_packed_size_##input_type_name##_##isa_suffix(nk_size_t column_count, \
                                                                                     nk_size_t depth) { \
        /* depth is always in logical dimensions (nibbles for i4, bytes for i8, etc.) */ \
        /* depth_simd_dimensions is also in logical dimensions */ \
        \
        /* Step 1: Pad depth in dimensions */ \
        nk_size_t depth_dimensions_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions); \
        \
        /* Step 2: Convert dimensions to storage values */ \
        nk_size_t depth_values_padded = nk_size_divide_round_up_(depth_dimensions_padded, dimensions_per_value); \
        \
        /* Step 3: Calculate stride in bytes for power-of-2 check */ \
        nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_##packed_value_type##_t); \
        \
        /* Step 4: Break power-of-2 strides for cache associativity */ \
        if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) { \
            /* Add one SIMD step worth of storage values */ \
            depth_values_padded += nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
        } \
        \
        /* Step 5: Return total buffer size (packed data + per-column norms) */ \
        return sizeof(nk_cross_packed_buffer_header_t) + \
               column_count * depth_values_padded * sizeof(nk_##packed_value_type##_t) + \
               column_count * sizeof(nk_##norm_value_type##_t); \
    }
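
/* Illustrative non-macro expansion of the sizing logic above (hypothetical
 * instantiation; the f32 type, 8-dimension SIMD step, and helper name are
 * assumptions): for depth = 8192 the 32,768-byte stride is a power of two, so
 * one extra SIMD step bumps the padded depth to 8,200 values (32,800 bytes). */
NK_INTERNAL nk_size_t nk_dots_packed_size_f32_sketch_(nk_size_t column_count, nk_size_t depth) {
    nk_size_t depth_values_padded = nk_size_round_up_to_multiple_(depth, 8); // dimensions_per_value == 1 for f32
    nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_f32_t);
    if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) depth_values_padded += 8;
    return sizeof(nk_cross_packed_buffer_header_t) +
           column_count * depth_values_padded * sizeof(nk_f32_t) + // packed payload
           column_count * sizeof(nk_f32_t);                        // per-column norms
}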

/**
 * @brief Generates function to pack and optionally convert B matrix for efficient GEMM inner loops.
 *
 * Packing serves two performance-critical purposes:
 *
 * 1. Type conversion (input_value_type → packed_value_type): For mixed-precision GEMM, convert B values
 *    once during packing rather than repeatedly in tight inner loops. Example: F16 → F32 conversion
 *    happens once per value instead of once per (row of A × value of B) access. This amortizes
 *    conversion cost across all rows of A.
 *
 * 2. Cache optimization: Pad depth to break power-of-2 byte strides that cause cache associativity
 *    conflicts. Example: depth = 8192, F32 → stride = 32,768 bytes (power-of-2) maps to same cache sets,
 *    causing conflict misses. Padding to 8200 → stride = 32,800 bytes (non-power-of-2) distributes
 *    accesses across more cache sets.
 *
 * Input layout: B[column_count, depth] stored row-major with b_stride_in_bytes between rows
 * Output layout: B_packed[column_count, depth_padded] - simple column-major, no grouping
 * Addressing: B_packed[j, k] = packed_data[j × depth_padded + k]
 *
 * Depth padding: Round up to `depth_simd_dimensions` multiple, then add `depth_simd_dimensions`
 * if stride is power-of-2. Zero-initializes entire buffer before copying to handle padding safely.
 *
 * @param api_name Operation name (hammings, dots)
 * @param input_type_name Original type's name of B matrix values (i4, f16, bf16, e4m3, e5m2, f32, etc.)
 * @param isa_suffix Platform Instruction Set Architecture suffix (serial, haswell, icelake, etc.)
 * @param input_value_type Original type of B matrix values (i4x2, f16, bf16, e4m3, e5m2, f32, etc.)
 * @param packed_value_type Internal storage type in packed buffer (often bf16 or f32 for mixed precision)
 * @param convert_value_fn Element conversion function: void fn(input_value_type const*, packed_value_type*)
 * @param norm_value_type Type of per-column norm values (f32, f64, u32) appended after packed data
 * @param compute_norm_fn Function: norm_value_type fn(input_value_type const*, nk_size_t count)
 * @param depth_simd_dimensions SIMD vector width in values for depth padding alignment
 * @param dimensions_per_value Number of logical dimensions in a single value of input_value_type.
 */
#define nk_define_cross_pack_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, \
                              convert_value_fn, norm_value_type, compute_norm_fn, depth_simd_dimensions, \
                              dimensions_per_value) \
    NK_PUBLIC void nk_##api_name##_pack_##input_type_name##_##isa_suffix( \
        nk_##input_value_type##_t const *b, nk_size_t column_count, nk_size_t depth, nk_size_t b_stride_in_bytes, \
        void *b_packed) { \
        /* Use identical padding calculation as pack_size */ \
        nk_size_t depth_dimensions_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions); \
        nk_size_t depth_values_padded = nk_size_divide_round_up_(depth_dimensions_padded, dimensions_per_value); \
        \
        /* Power-of-2 breaking (same as pack_size) */ \
        nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_##packed_value_type##_t); \
        if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) { \
            depth_values_padded += nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
        } \
        \
        /* Calculate input depth in values */ \
        nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
        \
        /* Store dimensions in header */ \
        nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed; \
        header->column_count = (nk_u32_t)column_count; \
        header->depth_dimensions = (nk_u32_t)depth; /* depth in dimensions (nibbles for i4/u4) */ \
        header->depth_padded_values = (nk_u32_t)depth_values_padded; /* padded depth in VALUES (bytes for i4/u4) */ \
        \
        nk_##packed_value_type##_t *packed = \
            (nk_##packed_value_type##_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t)); \
        \
        /* Zero entire buffer for depth padding */ \
        nk_size_t const total_values = column_count * depth_values_padded; \
        for (nk_size_t i = 0; i < total_values; ++i) packed[i] = 0; \
        \
        /* Copy/convert B[column_count, depth] to packed[column_count, depth_padded] - simple column-major */ \
        for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
            nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
            nk_##input_value_type##_t const *source_row = \
                (nk_##input_value_type##_t const *)((char const *)b + column_index * b_stride_in_bytes); \
            for (nk_size_t depth_index = 0; depth_index < depth_in_values; ++depth_index) { \
                convert_value_fn(&source_row[depth_index], &destination_row[depth_index]); \
            } \
            /* Padding values already zeroed above */ \
        } \
        \
        /* Append per-column norms after packed data */ \
        nk_##norm_value_type##_t *norms = (nk_##norm_value_type##_t *)(packed + total_values); \
        for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
            nk_##input_value_type##_t const *source_row = \
                (nk_##input_value_type##_t const *)((char const *)b + column_index * b_stride_in_bytes); \
            norms[column_index] = compute_norm_fn(source_row, depth); \
        } \
    }
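
/* End-to-end usage sketch for the two generators above (hypothetical
 * instantiation: the f32/serial names follow the token-pasting pattern, but
 * whether this exact pair is instantiated is an assumption):
 *
 *     nk_size_t bytes = nk_dots_packed_size_f32_serial(column_count, depth);
 *     void *b_packed = malloc(bytes);                    // 64-byte alignment preferred
 *     nk_dots_pack_f32_serial(b, column_count, depth,
 *                             depth * sizeof(nk_f32_t),  // row stride of the B matrix
 *                             b_packed);
 *     // b_packed can now be reused against any number of A matrices.
 */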

/**
 * @brief Generates function to calculate packed B matrix buffer size for compensated GEMM.
 *
 * Like nk_define_cross_pack_size_ but the buffer stores BOTH norms AND column sums.
 * Layout: [ Header 64B ] [ Packed data ] [ Norms (norm_type) ] [ Column sums (sum_type) ]
 * Norms first → existing nk_define_cross_normalized_packed_ reads norms at the same offset.
 */
#define nk_define_cross_compensated_pack_size_(api_name, input_type_name, isa_suffix, input_value_type, \
                                               packed_value_type, sum_value_type, norm_value_type, \
                                               depth_simd_dimensions, dimensions_per_value) \
    NK_PUBLIC nk_size_t nk_##api_name##_packed_size_##input_type_name##_##isa_suffix(nk_size_t column_count, \
                                                                                     nk_size_t depth) { \
        nk_size_t depth_dimensions_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions); \
        nk_size_t depth_values_padded = nk_size_divide_round_up_(depth_dimensions_padded, dimensions_per_value); \
        nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_##packed_value_type##_t); \
        if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) { \
            depth_values_padded += nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
        } \
        return sizeof(nk_cross_packed_buffer_header_t) + \
               column_count * depth_values_padded * sizeof(nk_##packed_value_type##_t) + \
               column_count * sizeof(nk_##norm_value_type##_t) + column_count * sizeof(nk_##sum_value_type##_t); \
    }
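
/* Illustrative reader for the compensated layout above (hypothetical helper,
 * assuming the i8 instantiation with u32 norms and i32 sums): the norms start
 * right after the packed payload, the column sums right after the norms. */
NK_INTERNAL void nk_dots_compensated_tail_sketch_(void const *b_packed, nk_u32_t const **norms,
                                                  nk_i32_t const **column_sums) {
    nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed;
    nk_i8_t const *packed = (nk_i8_t const *)((char const *)b_packed + sizeof(*header));
    nk_size_t const total_values = (nk_size_t)header->column_count * header->depth_padded_values;
    *norms = (nk_u32_t const *)(packed + total_values);
    *column_sums = (nk_i32_t const *)(*norms + header->column_count);
}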

/**
 * @brief Generates function to pack B matrix with BOTH norms and column sums for compensated GEMM.
 *
 * Like nk_define_cross_pack_ but uses compute_moments_fn(data, count, &sum, &norm) to compute
 * both sum and norm in a single pass, storing both after the packed data.
 * Layout: [ Header ] [ Packed data ] [ Norms ] [ Column sums ]
 */
#define nk_define_cross_compensated_pack_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, \
                                          convert_value_fn, sum_value_type, norm_value_type, compute_moments_fn, \
                                          depth_simd_dimensions, dimensions_per_value) \
    NK_PUBLIC void nk_##api_name##_pack_##input_type_name##_##isa_suffix( \
        nk_##input_value_type##_t const *b, nk_size_t column_count, nk_size_t depth, nk_size_t b_stride_in_bytes, \
        void *b_packed) { \
        nk_size_t depth_dimensions_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions); \
        nk_size_t depth_values_padded = nk_size_divide_round_up_(depth_dimensions_padded, dimensions_per_value); \
        nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_##packed_value_type##_t); \
        if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) { \
            depth_values_padded += nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
        } \
        nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
        nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed; \
        header->column_count = (nk_u32_t)column_count; \
        header->depth_dimensions = (nk_u32_t)depth; \
        header->depth_padded_values = (nk_u32_t)depth_values_padded; \
        nk_##packed_value_type##_t *packed = \
            (nk_##packed_value_type##_t *)((char *)b_packed + sizeof(nk_cross_packed_buffer_header_t)); \
        nk_size_t const total_values = column_count * depth_values_padded; \
        for (nk_size_t i = 0; i < total_values; ++i) packed[i] = 0; \
        for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
            nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
            nk_##input_value_type##_t const *source_row = \
                (nk_##input_value_type##_t const *)((char const *)b + column_index * b_stride_in_bytes); \
            for (nk_size_t depth_index = 0; depth_index < depth_in_values; ++depth_index) { \
                convert_value_fn(&source_row[depth_index], &destination_row[depth_index]); \
            } \
        } \
        /* Norms first (same offset as non-compensated pack), then column sums */ \
        nk_##norm_value_type##_t *norms = (nk_##norm_value_type##_t *)(packed + total_values); \
        nk_##sum_value_type##_t *col_sums = (nk_##sum_value_type##_t *)(norms + column_count); \
        for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
            nk_##input_value_type##_t const *source_row = \
                (nk_##input_value_type##_t const *)((char const *)b + column_index * b_stride_in_bytes); \
            compute_moments_fn(source_row, depth, &col_sums[column_index], &norms[column_index]); \
        } \
    }
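
/* Scalar sketch of the 4×4 register-tiling idea behind the generator below
 * (illustrative only; the generated kernels run the same dataflow with SIMD
 * accumulator states, a packed B, and two levels of cache blocking): 16
 * independent accumulators let every A and B load feed four multiply-adds. */
NK_INTERNAL void nk_dots_tile_4x4_f32_sketch_(nk_f32_t const *const a_rows[4], nk_f32_t const *const b_columns[4],
                                              nk_f32_t c_tile[4][4], nk_size_t depth) {
    nk_f32_t accumulators[4][4] = {{0}};
    for (nk_size_t k = 0; k < depth; ++k)
        for (nk_size_t i = 0; i < 4; ++i)
            for (nk_size_t j = 0; j < 4; ++j) // 16 FMAs per depth step, all independent
                accumulators[i][j] += a_rows[i][k] * b_columns[j][k];
    for (nk_size_t i = 0; i < 4; ++i)
        for (nk_size_t j = 0; j < 4; ++j) c_tile[i][j] = accumulators[i][j];
}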

/**
 * @brief Generates optimized GEMM implementation: C = A × Bᵀ with pre-packed B matrix.
 *
 * This macro creates a complete batched matrix multiplication kernel with THREE specialized
 * code paths that are automatically selected based on the remaining work at each blocking level.
 * The kernel requires B to be pre-packed using nk_define_cross_pack_ before invocation.
 *
 * @par Mathematical Operation
 * C[row_count, column_count] = A[row_count, depth] × Bᵀ[column_count, depth]
 * where operation can be dot product, Hamming distance, Jaccard similarity, etc.
 *
 * @par Three Kernel Variants for Adaptive Performance
 *
 * 1. @b 4×4 @b register @b tile @b kernel (primary path, ~80% of work):
 *    - Processes 4 rows of A × 4 columns of B simultaneously
 *    - Maintains 16 independent accumulators in registers (state_type[4][4])
 *    - Achieves maximum instruction-level parallelism (16 FMAs per depth iteration)
 *    - Used when: row_count ≥ 4 AND column_count ≥ 4
 *    - Performance: Peak throughput, optimal register utilization
 *
 * 2. @b 1×8 @b register @b tile @b kernel (edge case, ~15% of work):
 *    - Processes 1 row of A × 8 columns of B when remaining rows < 4
 *    - Maintains 8 independent accumulators (state_type[1][8])
 *    - Balances vectorization with low row count
 *    - Used when: row_count < 4 AND column_count ≥ 8
 *    - Performance: Better throughput than generic fallback for wide matrices
 *
 * 3. @b Generic @b fallback @b kernel (edge cases, ~5% of work):
 *    - Handles all irregular cases (row_count < 4 AND column_count < 8)
 *    - Single accumulator, minimal unrolling
 *    - Used for: Small tiles, remainder handling
 *    - Performance: Lower throughput but handles all edge cases correctly
 *
 * @par Cache Blocking Strategy (No Depth Blocking)
 *
 * Unlike traditional GEMM which blocks all three dimensions (M, N, K), this implementation
 * deliberately omits depth (K) blocking for several reasons:
 *
 * 1. @b Streaming @b access @b pattern: A and B are read sequentially along depth dimension
 *    - Prefetcher-friendly access (hardware prefetch works well)
 *    - No cache reuse along depth within a single C[i,j] computation
 *
 * 2. @b Depth @b is @b typically @b small: For ML inference, depth is often 128-4096 values
 *    - Fits in L2/L3 cache for single row of A
 *    - B is pre-packed for optimal spatial locality
 *
 * 3. @b Simplicity @b and @b instruction @b cache @b efficiency:
 *    - Fewer nested loops = better instruction cache utilization
 *    - Simpler control flow = easier for compiler to optimize
 *
 * @par Pre-Packing Benefits
 *
 * B matrix is pre-packed using nk_define_cross_pack_ before kernel invocation:
 * - @b Type @b conversion @b amortization: Convert B values once (e.g., bf16→f32) rather than
 *   per A row access. Saves (row_count - 1) × column_count conversions.
 * - @b Cache @b line @b optimization: Pad depth to break power-of-2 strides that cause cache
 *   associativity conflicts (e.g., 8192 → 8200 values).
 * - @b Spatial @b locality: Transpose B so columns are contiguous, enabling efficient SIMD loads.
 *
 * @par Loop Structure
 *
 *     for column_block in columns (step: varies based on available columns):
 *         for row_block in rows (step: varies based on available rows):
 *             for row_tile in row_block (step: 4 or 1 depending on variant):
 *                 for column_tile in column_block (step: 4 or 8 depending on variant):
 *                     accumulator_tiles[row_tile][column_tile] = init_accumulator_fn()
 *                     for depth_index in depth (step: depth_simd_dimensions):
 *                         a_vectors = load_a_vec_fn(A[row_tile, depth_index])
 *                         b_vectors = load_b_vec_fn(B_packed[column_tile, depth_index])
 *                         accumulator_tiles = inner_product_fn(accumulator_tiles, a_vectors, b_vectors)
 *                     results = reduce_accumulators_fn(accumulator_tiles)
 *                     partial_store_fn(results, C[row_tile, column_tile])
 *
 * @par Generated Function
 *
 *     nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_(
 *         A_matrix, B_packed_buffer, C_matrix, row_count, column_count, depth,
 *         A_stride_bytes, C_stride_bytes)
 *
 * @param api_name Operation family (dots, hammings, jaccards) for codegen namespace
 * @param input_type_name Type identifier for codegen (f32, bf16, i8, u1, etc.)
 * @param isa_suffix ISA backend identifier (serial, haswell, neon, sve, icelake, etc.)
 * @param input_value_type C type of input matrix values (f32, bf16, i8, u1x8, etc.)
 * @param packed_value_type Storage type in packed B buffer (often bf16 or f32 for mixed precision)
 * @param result_value_type C type of output matrix C values (f32, u32, f64, etc.)
 * @param vec_type SIMD vector type for depth dimension (e.g., __m256, nk_f32x8_t)
 * @param state_type Accumulator state type (often vec_type or wider, e.g., __m256 or __m512)
 * @param result_vec_type SIMD vector type for reduction results (e.g., __m128 for 4 f32 results)
 * @param init_accumulator_fn Initialize accumulator: void fn(state_type*)
 * @param load_a_vec_fn Full A vector load: void fn(input_value_type const*, vec_type*)
 * @param partial_load_a_vec_fn Partial A load for remainder
 * @param load_b_vec_fn Full B vector load: void fn(packed_value_type const*, vec_type*)
 * @param partial_load_b_vec_fn Partial B load for remainder
 * @param inner_product_fn Inner product accumulate
 * @param reduce_accumulators_fn Reduce 4 accumulators
 * @param store_fn Full-width store for results
 * @param partial_store_fn Partial store for results
 * @param depth_simd_dimensions SIMD vector width in logical dimensions (e.g., 8 for f32 on AVX2, 128 for u1 on serial)
 * @param dimensions_per_value Packing ratio: dimensions per storage value (1 for f32, 2 for i4x2, 8 for u1x8)
 *
 * @sa nk_define_cross_symmetric_ for symmetric C = A × Aᵀ computation (upper triangle only)
 * @sa nk_define_cross_pack_size_ for calculating B_packed buffer size
 * @sa nk_define_cross_pack_ for packing B matrix into optimized layout
 * @sa include/numkong/set/serial.h for state type definitions
 * @sa include/numkong/cast/serial.h for load/store function implementations
 */
#define nk_define_cross_packed_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, \
                                result_value_type, vec_type, state_type, result_vec_type, init_accumulator_fn, \
                                load_a_vec_fn, partial_load_a_vec_fn, load_b_vec_fn, partial_load_b_vec_fn, \
                                inner_product_fn, reduce_accumulators_fn, store_fn, partial_store_fn, \
                                depth_simd_dimensions, dimensions_per_value) \
    NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_( \
        nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
        nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
        nk_size_t c_stride_in_bytes) { \
        /* Read padded depth from header for correct stride calculation */ \
        nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer; \
        nk_size_t const depth_padded = header->depth_padded_values; \
        \
        nk_##packed_value_type##_t const *packed_data = \
            (nk_##packed_value_type##_t const *)((char const *)b_packed_buffer + \
                                                 sizeof(nk_cross_packed_buffer_header_t)); \
        \
        /* Cache blocking parameters (no depth_block blocking - full depth accumulated per tile) */ \
        nk_size_t const row_block_size = 128;      /* L2 cache blocking over rows */ \
        nk_size_t const column_block_size = 2048;  /* L3 cache blocking over columns */ \
        nk_size_t const register_row_count = 4;    /* Rows per register tile */ \
        nk_size_t const register_column_count = 4; /* Columns per register tile */ \
        /* Correct aligned_depth calculation for sub-byte types */ \
        nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
        nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
        /* Calculate step size in storage values for loop increment */ \
        nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
        \
        /* Zero output matrix */ \
        for (nk_size_t row_index = 0; row_index < row_count; ++row_index) { \
            nk_##result_value_type##_t *c_row = \
                (nk_##result_value_type##_t *)((char *)c_matrix + row_index * c_stride_in_bytes); \
            for (nk_size_t column_index = 0; column_index < column_count; ++column_index) c_row[column_index] = 0; \
        } \
        \
        /* Loop 1: L3 cache blocking over columns */ \
        for (nk_size_t column_block_start_index = 0; column_block_start_index < column_count; \
             column_block_start_index += column_block_size) { \
            nk_size_t column_block_end_index = column_block_start_index + column_block_size; \
            if (column_block_end_index > column_count) column_block_end_index = column_count; \
            \
            /* Loop 2: L2 cache blocking over rows */ \
            for (nk_size_t row_block_start_index = 0; row_block_start_index < row_count; \
                 row_block_start_index += row_block_size) { \
                nk_size_t row_block_end_index = row_block_start_index + row_block_size; \
                if (row_block_end_index > row_count) row_block_end_index = row_count; \
                \
                /* Loop 3: Register tiling over columns (register_column_count columns per batch) */ \
                for (nk_size_t tile_column_start_index = column_block_start_index; \
                     tile_column_start_index < column_block_end_index; \
                     tile_column_start_index += register_column_count) { \
                    \
                    /* Compute B pointers once per column tile - direct column-major addressing */ \
                    nk_##packed_value_type##_t const *b_depth_ptr_0 = \
                        packed_data + (tile_column_start_index + 0) * depth_padded; \
                    nk_##packed_value_type##_t const *b_depth_ptr_1 = \
                        packed_data + (tile_column_start_index + 1) * depth_padded; \
                    nk_##packed_value_type##_t const *b_depth_ptr_2 = \
                        packed_data + (tile_column_start_index + 2) * depth_padded; \
                    nk_##packed_value_type##_t const *b_depth_ptr_3 = \
                        packed_data + (tile_column_start_index + 3) * depth_padded; \
                    \
                    /* Loop 4: Register tiling over rows (register_row_count rows per tile) */ \
                    for (nk_size_t tile_row_start_index = row_block_start_index; \
                         tile_row_start_index < row_block_end_index; tile_row_start_index += register_row_count) { \
                        \
                        /* Initialize register_row_count × register_column_count accumulator states */ \
                        state_type accumulator_tiles[4][4]; \
                        init_accumulator_fn(&accumulator_tiles[0][0]), init_accumulator_fn(&accumulator_tiles[0][1]), \
                            init_accumulator_fn(&accumulator_tiles[0][2]), \
                            init_accumulator_fn(&accumulator_tiles[0][3]); \
                        init_accumulator_fn(&accumulator_tiles[1][0]), init_accumulator_fn(&accumulator_tiles[1][1]), \
                            init_accumulator_fn(&accumulator_tiles[1][2]), \
                            init_accumulator_fn(&accumulator_tiles[1][3]); \
                        init_accumulator_fn(&accumulator_tiles[2][0]), init_accumulator_fn(&accumulator_tiles[2][1]), \
                            init_accumulator_fn(&accumulator_tiles[2][2]), \
                            init_accumulator_fn(&accumulator_tiles[2][3]); \
                        init_accumulator_fn(&accumulator_tiles[3][0]), init_accumulator_fn(&accumulator_tiles[3][1]), \
                            init_accumulator_fn(&accumulator_tiles[3][2]), \
                            init_accumulator_fn(&accumulator_tiles[3][3]); \
                        \
                        /* A row pointers */ \
                        nk_##input_value_type##_t const *a_row_ptr_0 = \
                            (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                                                                (tile_row_start_index + 0) * a_stride_in_bytes); \
                        nk_##input_value_type##_t const *a_row_ptr_1 = \
                            (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                                                                (tile_row_start_index + 1) * a_stride_in_bytes); \
                        nk_##input_value_type##_t const *a_row_ptr_2 = \
                            (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                                                                (tile_row_start_index + 2) * a_stride_in_bytes); \
                        nk_##input_value_type##_t const *a_row_ptr_3 = \
                            (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                                                                (tile_row_start_index + 3) * a_stride_in_bytes); \
                        \
                        /* Tight inner loop: full depth with simple depth_index addressing */ \
                        vec_type a_vector_0, a_vector_1, a_vector_2, a_vector_3; \
                        vec_type b_vector_0, b_vector_1, b_vector_2, b_vector_3; \
                        for (nk_size_t depth_index = 0; depth_index < aligned_depth; \
                             depth_index += depth_step_values) { \
                            /* Load next few values from 4 rows from A (unpacked, may upcast) */ \
                            load_a_vec_fn(a_row_ptr_0 + depth_index, &a_vector_0); \
                            load_a_vec_fn(a_row_ptr_1 + depth_index, &a_vector_1); \
                            load_a_vec_fn(a_row_ptr_2 + depth_index, &a_vector_2); \
                            load_a_vec_fn(a_row_ptr_3 + depth_index, &a_vector_3); \
                            \
                            /* Load next few values from 4 rows from B (packed, already upcasted) */ \
                            load_b_vec_fn(b_depth_ptr_0 + depth_index, &b_vector_0); \
                            load_b_vec_fn(b_depth_ptr_1 + depth_index, &b_vector_1); \
                            load_b_vec_fn(b_depth_ptr_2 + depth_index, &b_vector_2); \
                            load_b_vec_fn(b_depth_ptr_3 + depth_index, &b_vector_3); \
                            \
                            /* 16 FMAs: 4 A rows × 4 B columns */ \
                            inner_product_fn(&accumulator_tiles[0][0], a_vector_0, b_vector_0, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[0][1], a_vector_0, b_vector_1, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[0][2], a_vector_0, b_vector_2, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[0][3], a_vector_0, b_vector_3, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[1][0], a_vector_1, b_vector_0, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[1][1], a_vector_1, b_vector_1, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[1][2], a_vector_1, b_vector_2, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[1][3], a_vector_1, b_vector_3, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[2][0], a_vector_2, b_vector_0, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[2][1], a_vector_2, b_vector_1, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[2][2], a_vector_2, b_vector_2, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[2][3], a_vector_2, b_vector_3, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[3][0], a_vector_3, b_vector_0, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[3][1], a_vector_3, b_vector_1, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[3][2], a_vector_3, b_vector_2, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                            inner_product_fn(&accumulator_tiles[3][3], a_vector_3, b_vector_3, \
                                             depth_index * dimensions_per_value, depth_simd_dimensions); \
                        } \
                        /* Finalize and store register_rows x register_cols results using batched 4-way reduction */ \
                        result_vec_type result_vector; \
                        nk_##result_value_type##_t *c_row_ptr_0 = \
                            (nk_##result_value_type##_t *)((char *)c_matrix + \
                                                           (tile_row_start_index + 0) * c_stride_in_bytes); \
                        reduce_accumulators_fn(&accumulator_tiles[0][0], &accumulator_tiles[0][1], \
                                               &accumulator_tiles[0][2], &accumulator_tiles[0][3], depth, \
                                               &result_vector); \
                        store_fn(&result_vector, c_row_ptr_0 + tile_column_start_index); \
                        nk_##result_value_type##_t *c_row_ptr_1 = \
                            (nk_##result_value_type##_t *)((char *)c_matrix + \
                                                           (tile_row_start_index + 1) * c_stride_in_bytes); \
                        reduce_accumulators_fn(&accumulator_tiles[1][0], &accumulator_tiles[1][1], \
                                               &accumulator_tiles[1][2], &accumulator_tiles[1][3], depth, \
                                               &result_vector); \
                        store_fn(&result_vector, c_row_ptr_1 + tile_column_start_index); \
                        nk_##result_value_type##_t *c_row_ptr_2 = \
                            (nk_##result_value_type##_t *)((char *)c_matrix + \
                                                           (tile_row_start_index + 2) * c_stride_in_bytes); \
                        reduce_accumulators_fn(&accumulator_tiles[2][0], &accumulator_tiles[2][1], \
                                               &accumulator_tiles[2][2], &accumulator_tiles[2][3], depth, \
                                               &result_vector); \
                        store_fn(&result_vector, c_row_ptr_2 + tile_column_start_index); \
                        nk_##result_value_type##_t *c_row_ptr_3 = \
                            (nk_##result_value_type##_t *)((char *)c_matrix + \
                                                           (tile_row_start_index + 3) * c_stride_in_bytes); \
                        reduce_accumulators_fn(&accumulator_tiles[3][0], &accumulator_tiles[3][1], \
                                               &accumulator_tiles[3][2], &accumulator_tiles[3][3], depth, \
                                               &result_vector); \
                        store_fn(&result_vector, c_row_ptr_3 + tile_column_start_index); \
                    } \
                } \
            } \
        } \
    } \
    NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
        nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
        nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
        nk_size_t c_stride_in_bytes) { \
        /* Read padded depth from header for correct stride calculation */ \
        nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer; \
        nk_size_t const depth_padded = header->depth_padded_values; /* in storage values */ \
        \
        nk_##packed_value_type##_t const *packed_data = \
            (nk_##packed_value_type##_t const *)((char const *)b_packed_buffer + \
                                                 sizeof(nk_cross_packed_buffer_header_t)); \
        \
        /* Cache blocking parameters (no depth_block blocking - full depth accumulated per tile) */ \
        nk_size_t const row_block_size = 128;      /* L2 cache blocking over rows */ \
        nk_size_t const column_block_size = 2048;  /* L3 cache blocking over columns */ \
        nk_size_t const register_row_count = 1;    /* Rows per register tile */ \
        nk_size_t const register_column_count = 8; /* Columns per register tile (2 × 4) */ \
        /* Correct aligned_depth calculation for sub-byte types */ \
        nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
|
|
727
|
+
nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
|
|
728
|
+
/* Calculate step size in storage values for loop increment */ \
|
|
729
|
+
nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
|
|
730
|
+
nk_unused_(register_row_count); /* Used in comments, loop uses 1 directly */ \
|
|
731
|
+
\
|
|
732
|
+
/* Zero output matrix */ \
|
|
733
|
+
for (nk_size_t row_index = 0; row_index < row_count; ++row_index) { \
|
|
734
|
+
nk_##result_value_type##_t *c_row = (nk_##result_value_type##_t *)((char *)c_matrix + \
|
|
735
|
+
row_index * c_stride_in_bytes); \
|
|
736
|
+
for (nk_size_t column_index = 0; column_index < column_count; ++column_index) c_row[column_index] = 0; \
|
|
737
|
+
} \
|
|
738
|
+
\
|
|
739
|
+
/* Loop 1: L3 cache blocking over columns */ \
|
|
740
|
+
for (nk_size_t column_block_start_index = 0; column_block_start_index < column_count; \
|
|
741
|
+
column_block_start_index += column_block_size) { \
|
|
742
|
+
nk_size_t column_block_end_index = column_block_start_index + column_block_size; \
|
|
743
|
+
if (column_block_end_index > column_count) column_block_end_index = column_count; \
|
|
744
|
+
\
|
|
745
|
+
/* Loop 2: L2 cache blocking over rows */ \
|
|
746
|
+
for (nk_size_t row_block_start_index = 0; row_block_start_index < row_count; \
|
|
747
|
+
row_block_start_index += row_block_size) { \
|
|
748
|
+
nk_size_t const row_block_end_index = row_block_start_index + row_block_size < row_count \
|
|
749
|
+
? row_block_start_index + row_block_size \
|
|
750
|
+
: row_count; \
|
|
751
|
+
\
|
|
752
|
+
/* Loop 3: Register tiling over columns (register_column_count columns per batch) */ \
|
|
753
|
+
for (nk_size_t tile_column_start_index = column_block_start_index; \
|
|
754
|
+
tile_column_start_index < column_block_end_index; \
|
|
755
|
+
tile_column_start_index += register_column_count) { \
|
|
756
|
+
\
|
|
757
|
+
/* Compute B pointers once per column tile - direct column-major addressing */ \
|
|
758
|
+
nk_##packed_value_type##_t const *b_depth_ptr_0 = packed_data + \
|
|
759
|
+
(tile_column_start_index + 0) * depth_padded; \
|
|
760
|
+
nk_##packed_value_type##_t const *b_depth_ptr_1 = packed_data + \
|
|
761
|
+
(tile_column_start_index + 1) * depth_padded; \
|
|
762
|
+
nk_##packed_value_type##_t const *b_depth_ptr_2 = packed_data + \
|
|
763
|
+
(tile_column_start_index + 2) * depth_padded; \
|
|
764
|
+
nk_##packed_value_type##_t const *b_depth_ptr_3 = packed_data + \
|
|
765
|
+
(tile_column_start_index + 3) * depth_padded; \
|
|
766
|
+
nk_##packed_value_type##_t const *b_depth_ptr_4 = packed_data + \
|
|
767
|
+
(tile_column_start_index + 4) * depth_padded; \
|
|
768
|
+
nk_##packed_value_type##_t const *b_depth_ptr_5 = packed_data + \
|
|
769
|
+
(tile_column_start_index + 5) * depth_padded; \
|
|
770
|
+
nk_##packed_value_type##_t const *b_depth_ptr_6 = packed_data + \
|
|
771
|
+
(tile_column_start_index + 6) * depth_padded; \
|
|
772
|
+
nk_##packed_value_type##_t const *b_depth_ptr_7 = packed_data + \
|
|
773
|
+
(tile_column_start_index + 7) * depth_padded; \
|
|
774
|
+
\
|
|
775
|
+
/* Loop 4: Process 1 row at a time */ \
|
|
776
|
+
for (nk_size_t row_index = row_block_start_index; row_index < row_block_end_index; ++row_index) { \
|
|
777
|
+
\
|
|
778
|
+
/* Initialize 1 × 8 accumulator states */ \
|
|
779
|
+
state_type accumulator_0, accumulator_1, accumulator_2, accumulator_3, accumulator_4, \
|
|
780
|
+
accumulator_5, accumulator_6, accumulator_7; \
|
|
781
|
+
init_accumulator_fn(&accumulator_0), init_accumulator_fn(&accumulator_1), \
|
|
782
|
+
init_accumulator_fn(&accumulator_2), init_accumulator_fn(&accumulator_3), \
|
|
783
|
+
init_accumulator_fn(&accumulator_4), init_accumulator_fn(&accumulator_5), \
|
|
784
|
+
init_accumulator_fn(&accumulator_6), init_accumulator_fn(&accumulator_7); \
|
|
785
|
+
\
|
|
786
|
+
/* A row pointer */ \
|
|
787
|
+
nk_##input_value_type##_t const *a_row_ptr = \
|
|
788
|
+
(nk_##input_value_type##_t const *)((char const *)a_matrix + \
|
|
789
|
+
row_index * a_stride_in_bytes); \
|
|
790
|
+
\
|
|
791
|
+
/* Tight inner loop: full depth with simple depth_index addressing */ \
|
|
792
|
+
vec_type a_vector; \
|
|
793
|
+
vec_type b_vector_0, b_vector_1, b_vector_2, b_vector_3, b_vector_4, b_vector_5, b_vector_6, \
|
|
794
|
+
b_vector_7; \
|
|
795
|
+
for (nk_size_t depth_index = 0; depth_index < aligned_depth; \
|
|
796
|
+
depth_index += depth_step_values) { \
|
|
797
|
+
/* Load A vector (1 row) */ \
|
|
798
|
+
load_a_vec_fn(a_row_ptr + depth_index, &a_vector); \
|
|
799
|
+
\
|
|
800
|
+
/* Load B vectors (8 columns) */ \
|
|
801
|
+
load_b_vec_fn(b_depth_ptr_0 + depth_index, &b_vector_0); \
|
|
802
|
+
load_b_vec_fn(b_depth_ptr_1 + depth_index, &b_vector_1); \
|
|
803
|
+
load_b_vec_fn(b_depth_ptr_2 + depth_index, &b_vector_2); \
|
|
804
|
+
load_b_vec_fn(b_depth_ptr_3 + depth_index, &b_vector_3); \
|
|
805
|
+
load_b_vec_fn(b_depth_ptr_4 + depth_index, &b_vector_4); \
|
|
806
|
+
load_b_vec_fn(b_depth_ptr_5 + depth_index, &b_vector_5); \
|
|
807
|
+
load_b_vec_fn(b_depth_ptr_6 + depth_index, &b_vector_6); \
|
|
808
|
+
load_b_vec_fn(b_depth_ptr_7 + depth_index, &b_vector_7); \
|
|
809
|
+
\
|
|
810
|
+
/* 8 FMAs: 1 A row × 8 B columns */ \
|
|
811
|
+
inner_product_fn(&accumulator_0, a_vector, b_vector_0, depth_index * dimensions_per_value, \
|
|
812
|
+
depth_simd_dimensions); \
|
|
813
|
+
inner_product_fn(&accumulator_1, a_vector, b_vector_1, depth_index * dimensions_per_value, \
|
|
814
|
+
depth_simd_dimensions); \
|
|
815
|
+
inner_product_fn(&accumulator_2, a_vector, b_vector_2, depth_index * dimensions_per_value, \
|
|
816
|
+
depth_simd_dimensions); \
|
|
817
|
+
inner_product_fn(&accumulator_3, a_vector, b_vector_3, depth_index * dimensions_per_value, \
|
|
818
|
+
depth_simd_dimensions); \
|
|
819
|
+
inner_product_fn(&accumulator_4, a_vector, b_vector_4, depth_index * dimensions_per_value, \
|
|
820
|
+
depth_simd_dimensions); \
|
|
821
|
+
inner_product_fn(&accumulator_5, a_vector, b_vector_5, depth_index * dimensions_per_value, \
|
|
822
|
+
depth_simd_dimensions); \
|
|
823
|
+
inner_product_fn(&accumulator_6, a_vector, b_vector_6, depth_index * dimensions_per_value, \
|
|
824
|
+
depth_simd_dimensions); \
|
|
825
|
+
inner_product_fn(&accumulator_7, a_vector, b_vector_7, depth_index * dimensions_per_value, \
|
|
826
|
+
depth_simd_dimensions); \
|
|
827
|
+
} \
|
|
828
|
+
\
|
|
829
|
+
/* Finalize and store 1 × 8 results using two 4-way reductions */ \
|
|
830
|
+
result_vec_type result_vector; \
|
|
831
|
+
nk_##result_value_type##_t *c_row_ptr = \
|
|
832
|
+
(nk_##result_value_type##_t *)((char *)c_matrix + row_index * c_stride_in_bytes); \
|
|
833
|
+
/* First 4 columns */ \
|
|
834
|
+
reduce_accumulators_fn(&accumulator_0, &accumulator_1, &accumulator_2, &accumulator_3, depth, \
|
|
835
|
+
&result_vector); \
|
|
836
|
+
store_fn(&result_vector, c_row_ptr + tile_column_start_index); \
|
|
837
|
+
/* Second 4 columns */ \
|
|
838
|
+
reduce_accumulators_fn(&accumulator_4, &accumulator_5, &accumulator_6, &accumulator_7, depth, \
|
|
839
|
+
&result_vector); \
|
|
840
|
+
store_fn(&result_vector, c_row_ptr + tile_column_start_index + 4); \
|
|
841
|
+
} \
|
|
842
|
+
} \
|
|
843
|
+
} \
|
|
844
|
+
} \
|
|
845
|
+
} \
|
|
846
|
+
NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix( \
|
|
847
|
+
nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
|
|
848
|
+
nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
|
|
849
|
+
nk_size_t c_stride_in_bytes) { \
|
|
850
|
+
/* Read padded depth from header for correct stride calculation */ \
|
|
851
|
+
nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer; \
|
|
852
|
+
nk_size_t const depth_padded = header->depth_padded_values; \
|
|
853
|
+
\
|
|
854
|
+
/* Cache blocking parameters (hardcoded for optimal L1/L2/L3 utilization) */ \
|
|
855
|
+
nk_size_t const row_block_size = 128; /* L2 cache blocking over rows */ \
|
|
856
|
+
nk_size_t const column_block_size = 2048; /* L3 cache blocking over columns */ \
|
|
857
|
+
nk_size_t const register_row_count = 4; /* Rows per register tile */ \
|
|
858
|
+
nk_size_t const register_column_count = 4; /* Columns per register tile */ \
|
|
859
|
+
nk_unused_(register_column_count); /* Suppress unused warnings */ \
|
|
860
|
+
/* Use 1 × 8 kernel when columns are aligned to 8 and many columns relative to rows */ \
|
|
861
|
+
if (column_count % 8 == 0 && column_count >= row_count * 2 && depth % depth_simd_dimensions == 0) { \
|
|
862
|
+
nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
|
|
863
|
+
a_matrix, b_packed_buffer, c_matrix, row_count, column_count, depth, a_stride_in_bytes, \
|
|
864
|
+
c_stride_in_bytes); \
|
|
865
|
+
return; \
|
|
866
|
+
} \
|
|
867
|
+
/* Use 4 × 4 kernel when dimensions are 4-aligned */ \
|
|
868
|
+
if (row_count % 4 == 0 && column_count % 4 == 0 && depth % depth_simd_dimensions == 0) { \
|
|
869
|
+
nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_(a_matrix, b_packed_buffer, c_matrix, \
|
|
870
|
+
row_count, column_count, depth, \
|
|
871
|
+
a_stride_in_bytes, c_stride_in_bytes); \
|
|
872
|
+
return; \
|
|
873
|
+
} \
|
|
874
|
+
\
|
|
875
|
+
/* Zero output matrix */ \
|
|
876
|
+
for (nk_size_t row_index = 0; row_index < row_count; ++row_index) { \
|
|
877
|
+
nk_##result_value_type##_t *c_row = (nk_##result_value_type##_t *)((char *)c_matrix + \
|
|
878
|
+
row_index * c_stride_in_bytes); \
|
|
879
|
+
for (nk_size_t column_index = 0; column_index < column_count; ++column_index) c_row[column_index] = 0; \
|
|
880
|
+
} \
|
|
881
|
+
\
|
|
882
|
+
/* Compute aligned/remainder depth for partial loads (correct for sub-byte types) */ \
|
|
883
|
+
nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
|
|
884
|
+
nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
|
|
885
|
+
nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
|
|
886
|
+
nk_size_t const remainder_depth = depth_in_values - aligned_depth; \
|
|
887
|
+
nk_size_t const remainder_dimensions = depth - depth_dimensions_aligned; \
|
|
888
|
+
/* Calculate step size in storage values for loop increment */ \
|
|
889
|
+
nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
|
|
890
|
+
\
|
|
891
|
+
/* Loop 1: L3 cache blocking over columns */ \
|
|
892
|
+
nk_##packed_value_type##_t const *packed_data = \
|
|
893
|
+
(nk_##packed_value_type##_t const *)((char const *)b_packed_buffer + \
|
|
894
|
+
sizeof(nk_cross_packed_buffer_header_t)); \
|
|
895
|
+
for (nk_size_t column_block_start_index = 0; column_block_start_index < column_count; \
|
|
896
|
+
column_block_start_index += column_block_size) { \
|
|
897
|
+
nk_size_t column_block_end_index = column_block_start_index + column_block_size; \
|
|
898
|
+
if (column_block_end_index > column_count) column_block_end_index = column_count; \
|
|
899
|
+
\
|
|
900
|
+
/* Loop 2: L2 cache blocking over rows */ \
|
|
901
|
+
for (nk_size_t row_block_start_index = 0; row_block_start_index < row_count; \
|
|
902
|
+
row_block_start_index += row_block_size) { \
|
|
903
|
+
nk_size_t row_block_end_index = row_block_start_index + row_block_size; \
|
|
904
|
+
if (row_block_end_index > row_count) row_block_end_index = row_count; \
|
|
905
|
+
\
|
|
906
|
+
/* Loop 4: Register tiling over columns (register_column_count columns per batch) */ \
|
|
907
|
+
for (nk_size_t tile_column_start_index = column_block_start_index; \
|
|
908
|
+
tile_column_start_index < column_block_end_index; \
|
|
909
|
+
tile_column_start_index += register_column_count) { \
|
|
910
|
+
nk_size_t tile_column_count = register_column_count; \
|
|
911
|
+
if (tile_column_start_index + tile_column_count > column_block_end_index) \
|
|
912
|
+
tile_column_count = column_block_end_index - tile_column_start_index; \
|
|
913
|
+
\
|
|
914
|
+
/* Compute B pointers once per column tile - direct column-major addressing */ \
|
|
915
|
+
nk_##packed_value_type##_t const *b_depth_ptr_0 = packed_data + \
|
|
916
|
+
(tile_column_start_index + 0) * depth_padded; \
|
|
917
|
+
nk_##packed_value_type##_t const *b_depth_ptr_1 = \
|
|
918
|
+
(tile_column_count > 1) ? packed_data + (tile_column_start_index + 1) * depth_padded \
|
|
919
|
+
: b_depth_ptr_0; \
|
|
920
|
+
nk_##packed_value_type##_t const *b_depth_ptr_2 = \
|
|
921
|
+
(tile_column_count > 2) ? packed_data + (tile_column_start_index + 2) * depth_padded \
|
|
922
|
+
: b_depth_ptr_0; \
|
|
923
|
+
nk_##packed_value_type##_t const *b_depth_ptr_3 = \
|
|
924
|
+
(tile_column_count > 3) ? packed_data + (tile_column_start_index + 3) * depth_padded \
|
|
925
|
+
: b_depth_ptr_0; \
|
|
926
|
+
\
|
|
927
|
+
/* Loop 5: Register tiling over rows (register_rows rows per tile) */ \
|
|
928
|
+
for (nk_size_t tile_row_start_index = row_block_start_index; \
|
|
929
|
+
tile_row_start_index < row_block_end_index; tile_row_start_index += register_row_count) { \
|
|
930
|
+
nk_size_t tile_row_count = register_row_count; \
|
|
931
|
+
if (tile_row_start_index + tile_row_count > row_block_end_index) \
|
|
932
|
+
tile_row_count = row_block_end_index - tile_row_start_index; \
|
|
933
|
+
\
|
|
934
|
+
/* Initialize register_rows x register_cols accumulator states */ \
|
|
935
|
+
state_type accumulator_tiles[4][4]; \
|
|
936
|
+
for (nk_size_t r = 0; r < tile_row_count; ++r) { \
|
|
937
|
+
init_accumulator_fn(&accumulator_tiles[r][0]); \
|
|
938
|
+
init_accumulator_fn(&accumulator_tiles[r][1]); \
|
|
939
|
+
init_accumulator_fn(&accumulator_tiles[r][2]); \
|
|
940
|
+
init_accumulator_fn(&accumulator_tiles[r][3]); \
|
|
941
|
+
} \
|
|
942
|
+
\
|
|
943
|
+
/* A row pointers */ \
|
|
944
|
+
nk_##input_value_type##_t const *a_row_ptr_0 = \
|
|
945
|
+
(nk_##input_value_type##_t const *)((char const *)a_matrix + \
|
|
946
|
+
(tile_row_start_index + 0) * a_stride_in_bytes); \
|
|
947
|
+
nk_##input_value_type##_t const *a_row_ptr_1 = \
|
|
948
|
+
(tile_row_count > 1) \
|
|
949
|
+
? (nk_##input_value_type##_t const *)((char const *)a_matrix + \
|
|
950
|
+
(tile_row_start_index + 1) * a_stride_in_bytes) \
|
|
951
|
+
: a_row_ptr_0; \
|
|
952
|
+
nk_##input_value_type##_t const *a_row_ptr_2 = \
|
|
953
|
+
(tile_row_count > 2) \
|
|
954
|
+
? (nk_##input_value_type##_t const *)((char const *)a_matrix + \
|
|
955
|
+
(tile_row_start_index + 2) * a_stride_in_bytes) \
|
|
956
|
+
: a_row_ptr_0; \
|
|
957
|
+
nk_##input_value_type##_t const *a_row_ptr_3 = \
|
|
958
|
+
(tile_row_count > 3) \
|
|
959
|
+
? (nk_##input_value_type##_t const *)((char const *)a_matrix + \
|
|
960
|
+
(tile_row_start_index + 3) * a_stride_in_bytes) \
|
|
961
|
+
: a_row_ptr_0; \
|
|
962
|
+
\
|
|
963
|
+
/* Tight inner loop: k values with simple ptr+k addressing */ \
|
|
964
|
+
vec_type a_first_vec, a_second_vec, a_third_vec, a_fourth_vec; \
|
|
965
|
+
vec_type b_first_vec, b_second_vec, b_third_vec, b_fourth_vec; \
|
|
966
|
+
for (nk_size_t k = 0; k < aligned_depth; k += depth_step_values) { \
|
|
967
|
+
/* Load next few values from 4 rows from A */ \
|
|
968
|
+
load_a_vec_fn(a_row_ptr_0 + k, &a_first_vec); \
|
|
969
|
+
load_a_vec_fn(a_row_ptr_1 + k, &a_second_vec); \
|
|
970
|
+
load_a_vec_fn(a_row_ptr_2 + k, &a_third_vec); \
|
|
971
|
+
load_a_vec_fn(a_row_ptr_3 + k, &a_fourth_vec); \
|
|
972
|
+
\
|
|
973
|
+
/* Load next few values from 4 rows from B */ \
|
|
974
|
+
load_b_vec_fn(b_depth_ptr_0 + k, &b_first_vec); \
|
|
975
|
+
load_b_vec_fn(b_depth_ptr_1 + k, &b_second_vec); \
|
|
976
|
+
load_b_vec_fn(b_depth_ptr_2 + k, &b_third_vec); \
|
|
977
|
+
load_b_vec_fn(b_depth_ptr_3 + k, &b_fourth_vec); \
|
|
978
|
+
\
|
|
979
|
+
/* 16 FMAs: 4 A rows × 4 B columns */ \
|
|
980
|
+
inner_product_fn(&accumulator_tiles[0][0], a_first_vec, b_first_vec, \
|
|
981
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
982
|
+
inner_product_fn(&accumulator_tiles[0][1], a_first_vec, b_second_vec, \
|
|
983
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
984
|
+
inner_product_fn(&accumulator_tiles[0][2], a_first_vec, b_third_vec, \
|
|
985
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
986
|
+
inner_product_fn(&accumulator_tiles[0][3], a_first_vec, b_fourth_vec, \
|
|
987
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
988
|
+
inner_product_fn(&accumulator_tiles[1][0], a_second_vec, b_first_vec, \
|
|
989
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
990
|
+
inner_product_fn(&accumulator_tiles[1][1], a_second_vec, b_second_vec, \
|
|
991
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
992
|
+
inner_product_fn(&accumulator_tiles[1][2], a_second_vec, b_third_vec, \
|
|
993
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
994
|
+
inner_product_fn(&accumulator_tiles[1][3], a_second_vec, b_fourth_vec, \
|
|
995
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
996
|
+
inner_product_fn(&accumulator_tiles[2][0], a_third_vec, b_first_vec, \
|
|
997
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
998
|
+
inner_product_fn(&accumulator_tiles[2][1], a_third_vec, b_second_vec, \
|
|
999
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
1000
|
+
inner_product_fn(&accumulator_tiles[2][2], a_third_vec, b_third_vec, \
|
|
1001
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
1002
|
+
inner_product_fn(&accumulator_tiles[2][3], a_third_vec, b_fourth_vec, \
|
|
1003
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
1004
|
+
inner_product_fn(&accumulator_tiles[3][0], a_fourth_vec, b_first_vec, \
|
|
1005
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
1006
|
+
inner_product_fn(&accumulator_tiles[3][1], a_fourth_vec, b_second_vec, \
|
|
1007
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
1008
|
+
inner_product_fn(&accumulator_tiles[3][2], a_fourth_vec, b_third_vec, \
|
|
1009
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
1010
|
+
inner_product_fn(&accumulator_tiles[3][3], a_fourth_vec, b_fourth_vec, \
|
|
1011
|
+
k * dimensions_per_value, depth_simd_dimensions); \
|
|
1012
|
+
} \
|
|
1013
|
+
\
|
|
1014
|
+
/* Handle remainder k positions with partial loads */ \
|
|
1015
|
+
if (remainder_depth > 0) { \
|
|
1016
|
+
/* Load next few values from 4 rows from A */ \
|
|
1017
|
+
partial_load_a_vec_fn(a_row_ptr_0 + aligned_depth, &a_first_vec, remainder_dimensions); \
|
|
1018
|
+
partial_load_a_vec_fn(a_row_ptr_1 + aligned_depth, &a_second_vec, remainder_dimensions); \
|
|
1019
|
+
partial_load_a_vec_fn(a_row_ptr_2 + aligned_depth, &a_third_vec, remainder_dimensions); \
|
|
1020
|
+
partial_load_a_vec_fn(a_row_ptr_3 + aligned_depth, &a_fourth_vec, remainder_dimensions); \
|
|
1021
|
+
\
|
|
1022
|
+
/* Load next few values from 4 rows from B */ \
|
|
1023
|
+
partial_load_b_vec_fn(b_depth_ptr_0 + aligned_depth, &b_first_vec, remainder_dimensions); \
|
|
1024
|
+
partial_load_b_vec_fn(b_depth_ptr_1 + aligned_depth, &b_second_vec, remainder_dimensions); \
|
|
1025
|
+
partial_load_b_vec_fn(b_depth_ptr_2 + aligned_depth, &b_third_vec, remainder_dimensions); \
|
|
1026
|
+
partial_load_b_vec_fn(b_depth_ptr_3 + aligned_depth, &b_fourth_vec, remainder_dimensions); \
|
|
1027
|
+
\
|
|
1028
|
+
/* 16 FMAs: 4 A rows × 4 B columns */ \
|
|
1029
|
+
inner_product_fn(&accumulator_tiles[0][0], a_first_vec, b_first_vec, \
|
|
1030
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1031
|
+
inner_product_fn(&accumulator_tiles[0][1], a_first_vec, b_second_vec, \
|
|
1032
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1033
|
+
inner_product_fn(&accumulator_tiles[0][2], a_first_vec, b_third_vec, \
|
|
1034
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1035
|
+
inner_product_fn(&accumulator_tiles[0][3], a_first_vec, b_fourth_vec, \
|
|
1036
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1037
|
+
inner_product_fn(&accumulator_tiles[1][0], a_second_vec, b_first_vec, \
|
|
1038
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1039
|
+
inner_product_fn(&accumulator_tiles[1][1], a_second_vec, b_second_vec, \
|
|
1040
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1041
|
+
inner_product_fn(&accumulator_tiles[1][2], a_second_vec, b_third_vec, \
|
|
1042
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1043
|
+
inner_product_fn(&accumulator_tiles[1][3], a_second_vec, b_fourth_vec, \
|
|
1044
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1045
|
+
inner_product_fn(&accumulator_tiles[2][0], a_third_vec, b_first_vec, \
|
|
1046
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1047
|
+
inner_product_fn(&accumulator_tiles[2][1], a_third_vec, b_second_vec, \
|
|
1048
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1049
|
+
inner_product_fn(&accumulator_tiles[2][2], a_third_vec, b_third_vec, \
|
|
1050
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1051
|
+
inner_product_fn(&accumulator_tiles[2][3], a_third_vec, b_fourth_vec, \
|
|
1052
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1053
|
+
inner_product_fn(&accumulator_tiles[3][0], a_fourth_vec, b_first_vec, \
|
|
1054
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1055
|
+
inner_product_fn(&accumulator_tiles[3][1], a_fourth_vec, b_second_vec, \
|
|
1056
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1057
|
+
inner_product_fn(&accumulator_tiles[3][2], a_fourth_vec, b_third_vec, \
|
|
1058
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1059
|
+
inner_product_fn(&accumulator_tiles[3][3], a_fourth_vec, b_fourth_vec, \
|
|
1060
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1061
|
+
} \
|
|
1062
|
+
\
|
|
1063
|
+
/* Finalize and store register_rows x register_cols results using batched 4-way reduction */ \
|
|
1064
|
+
for (nk_size_t r = 0; r < tile_row_count; ++r) { \
|
|
1065
|
+
result_vec_type result_vector; \
|
|
1066
|
+
reduce_accumulators_fn(&accumulator_tiles[r][0], &accumulator_tiles[r][1], \
|
|
1067
|
+
&accumulator_tiles[r][2], &accumulator_tiles[r][3], depth, \
|
|
1068
|
+
&result_vector); \
|
|
1069
|
+
\
|
|
1070
|
+
nk_##result_value_type##_t *c_row = \
|
|
1071
|
+
(nk_##result_value_type##_t *)((char *)c_matrix + \
|
|
1072
|
+
(tile_row_start_index + r) * c_stride_in_bytes); \
|
|
1073
|
+
partial_store_fn(&result_vector, c_row + tile_column_start_index, tile_column_count); \
|
|
1074
|
+
} \
|
|
1075
|
+
} \
|
|
1076
|
+
} \
|
|
1077
|
+
} \
|
|
1078
|
+
} \
|
|
1079
|
+
}
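
For orientation, every kernel the macro above expands to computes the same result as a plain scalar loop: C = A × Bᵀ, with B stored column-major (one contiguous run of depth_padded_values entries per column) behind a small header. A minimal sketch of that model, with float standing in for the macro's type parameters, a hypothetical packed_header_t standing in for nk_cross_packed_buffer_header_t (only the one field the kernels read is modeled), and strides given in elements rather than bytes:

#include <stddef.h>

/* Hypothetical stand-in for the packed-buffer header; the real struct may
 * carry more fields, but only this one is read by the kernels above. */
typedef struct {
    size_t depth_padded_values;
} packed_header_t;

/* Scalar model of the packed GEMM: c[i][j] = dot(a_row_i, b_col_j), where
 * each B column is a contiguous run of depth_padded_values entries placed
 * immediately after the header. */
static void packed_gemm_reference(float const *a, void const *b_packed, float *c,
                                  size_t rows, size_t cols, size_t depth,
                                  size_t a_stride, size_t c_stride) {
    packed_header_t const *header = (packed_header_t const *)b_packed;
    float const *packed = (float const *)(header + 1);
    for (size_t i = 0; i != rows; ++i)
        for (size_t j = 0; j != cols; ++j) {
            float const *b_column = packed + j * header->depth_padded_values;
            float sum = 0.0f;
            for (size_t k = 0; k != depth; ++k) sum += a[i * a_stride + k] * b_column[k];
            c[i * c_stride + j] = sum;
        }
}

The SIMD kernels differ from this model only in traversal order (cache blocking plus register tiling) and in how the dot product is vectorized; the padded tail of each B column beyond depth is simply never read here.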

/**
 * @brief Generates compensated GEMM: C = A × Bᵀ with precomputed B column sums.
 *
 * Like nk_define_cross_packed_ but the finalize function receives precomputed B column sums
 * and per-row A sums to apply the algebraic correction inline. This eliminates correction
 * accumulators from the inner loop state, halving register pressure for integer dot products.
 *
 * The compensated_finalize_fn signature differs from the standard reduce_accumulators_fn:
 * compensated_finalize_fn(state_a, state_b, state_c, state_d, depth, a_sum, b_sums_vec, result)
 * where a_sum is a scalar A row sum and b_sums_vec contains 4 B column sums as a SIMD vector.
 *
 * Buffer layout: [ Header ] [ Packed data ] [ Norms ] [ Column sums ]
 * The norms occupy the same position as in non-compensated packs, so spatial functions keep working.
 */
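
To make the correction algebra concrete, here is a standalone check of the kind of identity such a finalizer exploits, using the unsigned-times-signed bias trick as one plausible instance. The package's actual inner_product_fn and compensated_finalize_fn are supplied as macro arguments and may implement a different correction; this only illustrates why a precomputed column sum suffices to fix up the accumulator after the loop, with no extra state inside it:

#include <stdint.h>
#include <stdio.h>

/* If A is biased into u8 range (a' = a + 128) so that an unsigned-by-signed
 * dot instruction can accumulate it, then:
 *   sum((a_k + 128) * b_k) = sum(a_k * b_k) + 128 * sum(b_k),
 * so the true dot is recovered by subtracting 128 * b_sum once at the end. */
int main(void) {
    int8_t a[8] = {-5, 3, 127, -128, 0, 42, -1, 7};
    int8_t b[8] = {9, -2, 1, 4, -7, 0, 66, -3};
    int32_t biased_dot = 0, b_sum = 0, true_dot = 0;
    for (int k = 0; k < 8; ++k) {
        biased_dot += (int32_t)((uint8_t)(a[k] + 128)) * b[k]; /* what the loop accumulates */
        b_sum += b[k];                                         /* precomputed once per column */
        true_dot += (int32_t)a[k] * b[k];                      /* ground truth */
    }
    /* Finalize: subtract the bias contribution 128 * sum(b). */
    printf("%d == %d\n", (int)(biased_dot - 128 * b_sum), (int)true_dot);
    return 0;
}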
#define nk_define_cross_compensated_packed_( \
    api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, result_value_type, sum_value_type, \
    norm_value_type, vec_type, state_type, result_vec_type, init_accumulator_fn, load_a_vec_fn, partial_load_a_vec_fn, \
    load_b_vec_fn, partial_load_b_vec_fn, inner_product_fn, compensated_finalize_fn, store_fn, partial_store_fn, \
    load_sum_fn, partial_load_sum_fn, compute_a_sum_fn, depth_simd_dimensions, dimensions_per_value) \
NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_( \
    nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
    nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
    nk_size_t c_stride_in_bytes) { \
  nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer; \
  nk_size_t const depth_padded = header->depth_padded_values; \
  nk_##packed_value_type##_t const *packed_data = \
      (nk_##packed_value_type##_t const *)((char const *)b_packed_buffer + \
          sizeof(nk_cross_packed_buffer_header_t)); \
  /* Locate column sums: after packed data + norms */ \
  nk_size_t const total_packed_values = column_count * depth_padded; \
  nk_##norm_value_type##_t const *b_norms = (nk_##norm_value_type##_t const *)(packed_data + \
      total_packed_values); \
  nk_##sum_value_type##_t const *b_sums = (nk_##sum_value_type##_t const *)(b_norms + column_count); \
  nk_unused_(b_norms); \
  nk_size_t const row_block_size = 128; \
  nk_size_t const column_block_size = 2048; \
  nk_size_t const register_row_count = 4; \
  nk_size_t const register_column_count = 4; \
  nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
  nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
  nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
  for (nk_size_t row_index = 0; row_index < row_count; ++row_index) { \
    nk_##result_value_type##_t *c_row = (nk_##result_value_type##_t *)((char *)c_matrix + \
        row_index * c_stride_in_bytes); \
    for (nk_size_t ci = 0; ci < column_count; ++ci) c_row[ci] = 0; \
  } \
  for (nk_size_t cb = 0; cb < column_count; cb += column_block_size) { \
    nk_size_t ce = cb + column_block_size; \
    if (ce > column_count) ce = column_count; \
    for (nk_size_t rb = 0; rb < row_count; rb += row_block_size) { \
      nk_size_t re = rb + row_block_size; \
      if (re > row_count) re = row_count; \
      for (nk_size_t tc = cb; tc < ce; tc += register_column_count) { \
        nk_##packed_value_type##_t const *b_depth_ptr_0 = packed_data + (tc + 0) * depth_padded; \
        nk_##packed_value_type##_t const *b_depth_ptr_1 = packed_data + (tc + 1) * depth_padded; \
        nk_##packed_value_type##_t const *b_depth_ptr_2 = packed_data + (tc + 2) * depth_padded; \
        nk_##packed_value_type##_t const *b_depth_ptr_3 = packed_data + (tc + 3) * depth_padded; \
        /* Load 4 B column sums as a SIMD vector */ \
        result_vec_type b_sum_vec; \
        load_sum_fn(b_sums + tc, &b_sum_vec); \
        for (nk_size_t tr = rb; tr < re; tr += register_row_count) { \
          state_type acc[4][4]; \
          init_accumulator_fn(&acc[0][0]), init_accumulator_fn(&acc[0][1]), \
              init_accumulator_fn(&acc[0][2]), init_accumulator_fn(&acc[0][3]); \
          init_accumulator_fn(&acc[1][0]), init_accumulator_fn(&acc[1][1]), \
              init_accumulator_fn(&acc[1][2]), init_accumulator_fn(&acc[1][3]); \
          init_accumulator_fn(&acc[2][0]), init_accumulator_fn(&acc[2][1]), \
              init_accumulator_fn(&acc[2][2]), init_accumulator_fn(&acc[2][3]); \
          init_accumulator_fn(&acc[3][0]), init_accumulator_fn(&acc[3][1]), \
              init_accumulator_fn(&acc[3][2]), init_accumulator_fn(&acc[3][3]); \
          nk_##input_value_type##_t const *a_row_ptr_0 = \
              (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                  (tr + 0) * a_stride_in_bytes); \
          nk_##input_value_type##_t const *a_row_ptr_1 = \
              (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                  (tr + 1) * a_stride_in_bytes); \
          nk_##input_value_type##_t const *a_row_ptr_2 = \
              (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                  (tr + 2) * a_stride_in_bytes); \
          nk_##input_value_type##_t const *a_row_ptr_3 = \
              (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                  (tr + 3) * a_stride_in_bytes); \
          /* Precompute A row sums (no-op for i8/u8, real for i4) */ \
          nk_##sum_value_type##_t a_sums[4]; \
          a_sums[0] = compute_a_sum_fn(a_row_ptr_0, depth); \
          a_sums[1] = compute_a_sum_fn(a_row_ptr_1, depth); \
          a_sums[2] = compute_a_sum_fn(a_row_ptr_2, depth); \
          a_sums[3] = compute_a_sum_fn(a_row_ptr_3, depth); \
          vec_type av0, av1, av2, av3, bv0, bv1, bv2, bv3; \
          for (nk_size_t di = 0; di < aligned_depth; di += depth_step_values) { \
            load_a_vec_fn(a_row_ptr_0 + di, &av0); \
            load_a_vec_fn(a_row_ptr_1 + di, &av1); \
            load_a_vec_fn(a_row_ptr_2 + di, &av2); \
            load_a_vec_fn(a_row_ptr_3 + di, &av3); \
            load_b_vec_fn(b_depth_ptr_0 + di, &bv0); \
            load_b_vec_fn(b_depth_ptr_1 + di, &bv1); \
            load_b_vec_fn(b_depth_ptr_2 + di, &bv2); \
            load_b_vec_fn(b_depth_ptr_3 + di, &bv3); \
            inner_product_fn(&acc[0][0], av0, bv0, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[0][1], av0, bv1, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[0][2], av0, bv2, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[0][3], av0, bv3, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[1][0], av1, bv0, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[1][1], av1, bv1, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[1][2], av1, bv2, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[1][3], av1, bv3, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[2][0], av2, bv0, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[2][1], av2, bv1, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[2][2], av2, bv2, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[2][3], av2, bv3, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[3][0], av3, bv0, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[3][1], av3, bv1, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[3][2], av3, bv2, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[3][3], av3, bv3, di * dimensions_per_value, depth_simd_dimensions); \
          } \
          /* Compensated finalize: apply correction with precomputed sums */ \
          result_vec_type result_vector; \
          for (nk_size_t r = 0; r < register_row_count; ++r) { \
            compensated_finalize_fn(&acc[r][0], &acc[r][1], &acc[r][2], &acc[r][3], depth, a_sums[r], \
                b_sum_vec, &result_vector); \
            nk_##result_value_type##_t *c_row = \
                (nk_##result_value_type##_t *)((char *)c_matrix + (tr + r) * c_stride_in_bytes); \
            store_fn(&result_vector, c_row + tc); \
          } \
        } \
      } \
    } \
  } \
} \
NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
    nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
    nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
    nk_size_t c_stride_in_bytes) { \
  nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer; \
  nk_size_t const depth_padded = header->depth_padded_values; \
  nk_##packed_value_type##_t const *packed_data = \
      (nk_##packed_value_type##_t const *)((char const *)b_packed_buffer + \
          sizeof(nk_cross_packed_buffer_header_t)); \
  nk_size_t const total_packed_values = column_count * depth_padded; \
  nk_##norm_value_type##_t const *b_norms = (nk_##norm_value_type##_t const *)(packed_data + \
      total_packed_values); \
  nk_##sum_value_type##_t const *b_sums = (nk_##sum_value_type##_t const *)(b_norms + column_count); \
  nk_unused_(b_norms); \
  nk_size_t const row_block_size = 128; \
  nk_size_t const column_block_size = 2048; \
  nk_size_t const register_column_count = 8; \
  nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
  nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
  nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
  for (nk_size_t row_index = 0; row_index < row_count; ++row_index) { \
    nk_##result_value_type##_t *c_row = (nk_##result_value_type##_t *)((char *)c_matrix + \
        row_index * c_stride_in_bytes); \
    for (nk_size_t ci = 0; ci < column_count; ++ci) c_row[ci] = 0; \
  } \
  for (nk_size_t cb = 0; cb < column_count; cb += column_block_size) { \
    nk_size_t ce = cb + column_block_size; \
    if (ce > column_count) ce = column_count; \
    for (nk_size_t rb2 = 0; rb2 < row_count; rb2 += row_block_size) { \
      nk_size_t re2 = rb2 + row_block_size < row_count ? rb2 + row_block_size : row_count; \
      for (nk_size_t tc = cb; tc < ce; tc += register_column_count) { \
        nk_##packed_value_type##_t const *bp0 = packed_data + (tc + 0) * depth_padded; \
        nk_##packed_value_type##_t const *bp1 = packed_data + (tc + 1) * depth_padded; \
        nk_##packed_value_type##_t const *bp2 = packed_data + (tc + 2) * depth_padded; \
        nk_##packed_value_type##_t const *bp3 = packed_data + (tc + 3) * depth_padded; \
        nk_##packed_value_type##_t const *bp4 = packed_data + (tc + 4) * depth_padded; \
        nk_##packed_value_type##_t const *bp5 = packed_data + (tc + 5) * depth_padded; \
        nk_##packed_value_type##_t const *bp6 = packed_data + (tc + 6) * depth_padded; \
        nk_##packed_value_type##_t const *bp7 = packed_data + (tc + 7) * depth_padded; \
        result_vec_type b_sum_lo, b_sum_hi; \
        load_sum_fn(b_sums + tc, &b_sum_lo); \
        load_sum_fn(b_sums + tc + 4, &b_sum_hi); \
        for (nk_size_t ri = rb2; ri < re2; ++ri) { \
          state_type s0, s1, s2, s3, s4, s5, s6, s7; \
          init_accumulator_fn(&s0), init_accumulator_fn(&s1), init_accumulator_fn(&s2), \
              init_accumulator_fn(&s3), init_accumulator_fn(&s4), init_accumulator_fn(&s5), \
              init_accumulator_fn(&s6), init_accumulator_fn(&s7); \
          nk_##input_value_type##_t const *a_row = \
              (nk_##input_value_type##_t const *)((char const *)a_matrix + ri * a_stride_in_bytes); \
          nk_##sum_value_type##_t a_sum_val = compute_a_sum_fn(a_row, depth); \
          vec_type av; \
          vec_type bv0, bv1, bv2, bv3, bv4, bv5, bv6, bv7; \
          for (nk_size_t di = 0; di < aligned_depth; di += depth_step_values) { \
            load_a_vec_fn(a_row + di, &av); \
            load_b_vec_fn(bp0 + di, &bv0), load_b_vec_fn(bp1 + di, &bv1); \
            load_b_vec_fn(bp2 + di, &bv2), load_b_vec_fn(bp3 + di, &bv3); \
            load_b_vec_fn(bp4 + di, &bv4), load_b_vec_fn(bp5 + di, &bv5); \
            load_b_vec_fn(bp6 + di, &bv6), load_b_vec_fn(bp7 + di, &bv7); \
            inner_product_fn(&s0, av, bv0, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&s1, av, bv1, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&s2, av, bv2, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&s3, av, bv3, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&s4, av, bv4, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&s5, av, bv5, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&s6, av, bv6, di * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&s7, av, bv7, di * dimensions_per_value, depth_simd_dimensions); \
          } \
          result_vec_type rv; \
          nk_##result_value_type##_t *c_row = (nk_##result_value_type##_t *)((char *)c_matrix + \
              ri * c_stride_in_bytes); \
          compensated_finalize_fn(&s0, &s1, &s2, &s3, depth, a_sum_val, b_sum_lo, &rv); \
          store_fn(&rv, c_row + tc); \
          compensated_finalize_fn(&s4, &s5, &s6, &s7, depth, a_sum_val, b_sum_hi, &rv); \
          store_fn(&rv, c_row + tc + 4); \
        } \
      } \
    } \
  } \
} \
NK_PUBLIC void nk_##api_name##_packed_##input_type_name##_##isa_suffix( \
    nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##result_value_type##_t *c_matrix, \
    nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
    nk_size_t c_stride_in_bytes) { \
  nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer; \
  nk_size_t const depth_padded = header->depth_padded_values; \
  nk_size_t const row_block_size = 128; \
  nk_size_t const column_block_size = 2048; \
  nk_size_t const register_row_count = 4; \
  nk_size_t const register_column_count = 4; \
  nk_unused_(register_column_count); \
  if (column_count % 8 == 0 && column_count >= row_count * 2 && depth % depth_simd_dimensions == 0) { \
    nk_##api_name##_packed_##input_type_name##_##isa_suffix##_1x8_aligned_( \
        a_matrix, b_packed_buffer, c_matrix, row_count, column_count, depth, a_stride_in_bytes, \
        c_stride_in_bytes); \
    return; \
  } \
  if (row_count % 4 == 0 && column_count % 4 == 0 && depth % depth_simd_dimensions == 0) { \
    nk_##api_name##_packed_##input_type_name##_##isa_suffix##_aligned_(a_matrix, b_packed_buffer, c_matrix, \
        row_count, column_count, depth, \
        a_stride_in_bytes, c_stride_in_bytes); \
    return; \
  } \
  /* Generic fallback with partial loads and compensated finalize */ \
  nk_##packed_value_type##_t const *packed_data = \
      (nk_##packed_value_type##_t const *)((char const *)b_packed_buffer + \
          sizeof(nk_cross_packed_buffer_header_t)); \
  nk_size_t const total_packed_values = column_count * depth_padded; \
  nk_##norm_value_type##_t const *b_norms = (nk_##norm_value_type##_t const *)(packed_data + \
      total_packed_values); \
  nk_##sum_value_type##_t const *b_sums = (nk_##sum_value_type##_t const *)(b_norms + column_count); \
  nk_unused_(b_norms); \
  nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
  nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
  nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
  nk_size_t const remainder_depth = depth_in_values - aligned_depth; \
  nk_size_t const remainder_dimensions = depth - depth_dimensions_aligned; \
  nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
  for (nk_size_t row_index = 0; row_index < row_count; ++row_index) { \
    nk_##result_value_type##_t *c_row = (nk_##result_value_type##_t *)((char *)c_matrix + \
        row_index * c_stride_in_bytes); \
    for (nk_size_t ci = 0; ci < column_count; ++ci) c_row[ci] = 0; \
  } \
  for (nk_size_t cb = 0; cb < column_count; cb += column_block_size) { \
    nk_size_t ce = cb + column_block_size; \
    if (ce > column_count) ce = column_count; \
    for (nk_size_t rb = 0; rb < row_count; rb += row_block_size) { \
      nk_size_t re = rb + row_block_size; \
      if (re > row_count) re = row_count; \
      for (nk_size_t tc = cb; tc < ce; tc += register_column_count) { \
        nk_size_t tile_col_count = register_column_count; \
        if (tc + tile_col_count > ce) tile_col_count = ce - tc; \
        nk_##packed_value_type##_t const *bdp0 = packed_data + (tc + 0) * depth_padded; \
        nk_##packed_value_type##_t const *bdp1 = (tile_col_count > 1) \
            ? packed_data + (tc + 1) * depth_padded \
            : bdp0; \
        nk_##packed_value_type##_t const *bdp2 = (tile_col_count > 2) \
            ? packed_data + (tc + 2) * depth_padded \
            : bdp0; \
        nk_##packed_value_type##_t const *bdp3 = (tile_col_count > 3) \
            ? packed_data + (tc + 3) * depth_padded \
            : bdp0; \
        result_vec_type b_sum_vec; \
        partial_load_sum_fn(b_sums + tc, &b_sum_vec, tile_col_count); \
        for (nk_size_t tr = rb; tr < re; tr += register_row_count) { \
          nk_size_t tile_row_count = register_row_count; \
          if (tr + tile_row_count > re) tile_row_count = re - tr; \
          state_type acc[4][4]; \
          for (nk_size_t rr = 0; rr < tile_row_count; ++rr) { \
            init_accumulator_fn(&acc[rr][0]); \
            init_accumulator_fn(&acc[rr][1]); \
            init_accumulator_fn(&acc[rr][2]); \
            init_accumulator_fn(&acc[rr][3]); \
          } \
          nk_##input_value_type##_t const *arp0 = \
              (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                  (tr + 0) * a_stride_in_bytes); \
          nk_##input_value_type##_t const *arp1 = \
              (tile_row_count > 1) ? (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                                         (tr + 1) * a_stride_in_bytes) \
                                   : arp0; \
          nk_##input_value_type##_t const *arp2 = \
              (tile_row_count > 2) ? (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                                         (tr + 2) * a_stride_in_bytes) \
                                   : arp0; \
          nk_##input_value_type##_t const *arp3 = \
              (tile_row_count > 3) ? (nk_##input_value_type##_t const *)((char const *)a_matrix + \
                                         (tr + 3) * a_stride_in_bytes) \
                                   : arp0; \
          nk_##sum_value_type##_t a_sums[4]; \
          a_sums[0] = compute_a_sum_fn(arp0, depth); \
          a_sums[1] = (tile_row_count > 1) ? compute_a_sum_fn(arp1, depth) : 0; \
          a_sums[2] = (tile_row_count > 2) ? compute_a_sum_fn(arp2, depth) : 0; \
          a_sums[3] = (tile_row_count > 3) ? compute_a_sum_fn(arp3, depth) : 0; \
          vec_type av0, av1, av2, av3, bv0, bv1, bv2, bv3; \
          for (nk_size_t k = 0; k < aligned_depth; k += depth_step_values) { \
            load_a_vec_fn(arp0 + k, &av0); \
            load_a_vec_fn(arp1 + k, &av1); \
            load_a_vec_fn(arp2 + k, &av2); \
            load_a_vec_fn(arp3 + k, &av3); \
            load_b_vec_fn(bdp0 + k, &bv0); \
            load_b_vec_fn(bdp1 + k, &bv1); \
            load_b_vec_fn(bdp2 + k, &bv2); \
            load_b_vec_fn(bdp3 + k, &bv3); \
            inner_product_fn(&acc[0][0], av0, bv0, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[0][1], av0, bv1, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[0][2], av0, bv2, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[0][3], av0, bv3, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[1][0], av1, bv0, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[1][1], av1, bv1, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[1][2], av1, bv2, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[1][3], av1, bv3, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[2][0], av2, bv0, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[2][1], av2, bv1, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[2][2], av2, bv2, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[2][3], av2, bv3, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[3][0], av3, bv0, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[3][1], av3, bv1, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[3][2], av3, bv2, k * dimensions_per_value, depth_simd_dimensions); \
            inner_product_fn(&acc[3][3], av3, bv3, k * dimensions_per_value, depth_simd_dimensions); \
          } \
          if (remainder_depth > 0) { \
            partial_load_a_vec_fn(arp0 + aligned_depth, &av0, remainder_dimensions); \
            partial_load_a_vec_fn(arp1 + aligned_depth, &av1, remainder_dimensions); \
            partial_load_a_vec_fn(arp2 + aligned_depth, &av2, remainder_dimensions); \
            partial_load_a_vec_fn(arp3 + aligned_depth, &av3, remainder_dimensions); \
            partial_load_b_vec_fn(bdp0 + aligned_depth, &bv0, remainder_dimensions); \
            partial_load_b_vec_fn(bdp1 + aligned_depth, &bv1, remainder_dimensions); \
            partial_load_b_vec_fn(bdp2 + aligned_depth, &bv2, remainder_dimensions); \
            partial_load_b_vec_fn(bdp3 + aligned_depth, &bv3, remainder_dimensions); \
            inner_product_fn(&acc[0][0], av0, bv0, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[0][1], av0, bv1, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[0][2], av0, bv2, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[0][3], av0, bv3, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[1][0], av1, bv0, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[1][1], av1, bv1, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[1][2], av1, bv2, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[1][3], av1, bv3, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[2][0], av2, bv0, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[2][1], av2, bv1, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[2][2], av2, bv2, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[2][3], av2, bv3, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[3][0], av3, bv0, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[3][1], av3, bv1, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[3][2], av3, bv2, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
            inner_product_fn(&acc[3][3], av3, bv3, aligned_depth * dimensions_per_value, \
                remainder_dimensions); \
          } \
          for (nk_size_t rr = 0; rr < tile_row_count; ++rr) { \
            result_vec_type rv; \
            compensated_finalize_fn(&acc[rr][0], &acc[rr][1], &acc[rr][2], &acc[rr][3], depth, \
                a_sums[rr], b_sum_vec, &rv); \
            nk_##result_value_type##_t *c_row = \
                (nk_##result_value_type##_t *)((char *)c_matrix + (tr + rr) * c_stride_in_bytes); \
            partial_store_fn(&rv, c_row + tc, tile_col_count); \
          } \
        } \
      } \
    } \
  } \
}

/**
 * @brief Generates compensated symmetric Gram matrix: C = A × Aᵀ with inline correction.
 *
 * Like nk_define_cross_symmetric_ but the finalize function receives precomputed sums.
 * For symmetric computation, both row and column vectors come from the same matrix A,
 * so A sums serve as both row and column sums.
 *
 * The off-diagonal helper uses 4×4 tiling (matching nk_define_cross_symmetric_) with
 * progressive sum accumulation: SAD runs on port 5 alongside DPBUSD on ports 0+1 for
 * zero throughput overhead on Alder Lake and Ice Lake.
 */
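
The port-level claim can be pictured with a small standalone AVX-512 sketch (illustrative only, not the package's kernels, which are assembled from the macro arguments; assumes AVX-512BW plus VNNI and a depth that is a multiple of 64): the same pass feeds VPDPBUSD for the products and VPSADBW for the byte sums the compensated finalize needs.

#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

/* Progressive sum accumulation: the dot product and the A byte-sum are
 * accumulated in the same loop iteration on disjoint execution ports. */
static inline void dot_and_sum_u8i8(uint8_t const *a, int8_t const *b, size_t depth,
                                    int32_t *dot_out, int64_t *a_sum_out) {
    __m512i dot = _mm512_setzero_si512();
    __m512i sum = _mm512_setzero_si512();
    __m512i const zero = _mm512_setzero_si512();
    for (size_t k = 0; k < depth; k += 64) {
        __m512i a_vec = _mm512_loadu_si512(a + k);
        __m512i b_vec = _mm512_loadu_si512(b + k);
        /* u8 × i8 products accumulated into 16 i32 lanes (VNNI, ports 0/1) */
        dot = _mm512_dpbusd_epi32(dot, a_vec, b_vec);
        /* Byte sums of A via SAD against zero (port 5), one partial sum per 8-byte group */
        sum = _mm512_add_epi64(sum, _mm512_sad_epu8(a_vec, zero));
    }
    *dot_out = _mm512_reduce_add_epi32(dot);
    *a_sum_out = _mm512_reduce_add_epi64(sum);
}

Because the two instructions dispatch to disjoint ports on the cores named above, accumulating the sums adds no throughput cost to the dot-product loop, which is the point of the progressive scheme.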
#define nk_define_cross_compensated_symmetric_( \
    api_name, input_type_name, isa_suffix, input_value_type, result_value_type, sum_value_type, norm_value_type, \
    vec_type, state_type, result_vec_type, init_accumulator_fn, load_vec_fn, partial_load_vec_fn, inner_product_fn, \
    compensated_finalize_fn, store_fn, partial_store_fn, load_sum_fn, partial_load_sum_fn, sum_state_type, \
    init_sum_fn, update_sum_fn, finalize_sum_fn, depth_simd_dimensions, dimensions_per_value) \
NK_INTERNAL void nk_##api_name##_symmetric_diagonal_##input_type_name##_##isa_suffix##_( \
    nk_##input_value_type##_t const **vector_base_ptrs, nk_size_t i_macro, nk_size_t macro_size, \
    nk_size_t aligned_depth, nk_size_t remainder_depth, nk_size_t remainder_dimensions, \
    nk_size_t depth_step_values, nk_size_t dimensions_per_value_runtime, nk_##result_value_type##_t *result, \
    nk_size_t result_stride_values, nk_size_t finalizer_batch_size, nk_size_t depth) { \
  nk_unused_(finalizer_batch_size); \
  nk_unused_(dimensions_per_value_runtime); \
  /* Compute sums via stateful helpers — separate loop is fine since diagonal is ~1.6% of work */ \
  nk_size_t padded_depth_dimensions = aligned_depth * dimensions_per_value + \
      (remainder_depth > 0 ? depth_simd_dimensions : 0); \
  nk_##sum_value_type##_t precomputed_sums[32]; \
  for (nk_size_t s = 0; s < macro_size; s++) { \
    sum_state_type ss; \
    init_sum_fn(&ss); \
    for (nk_size_t di = 0; di < aligned_depth; di += depth_step_values) { \
      vec_type v; \
      load_vec_fn(vector_base_ptrs[s] + di, &v); \
      update_sum_fn(&ss, v); \
    } \
    if (remainder_depth > 0) { \
      vec_type v; \
      partial_load_vec_fn(vector_base_ptrs[s] + aligned_depth, &v, remainder_dimensions); \
      update_sum_fn(&ss, v); \
    } \
    precomputed_sums[s] = finalize_sum_fn(&ss, padded_depth_dimensions); \
  } \
  for (nk_size_t tile_row_start = 0; tile_row_start < macro_size; tile_row_start += 4) { \
    for (nk_size_t tile_col_start = tile_row_start; tile_col_start < macro_size; tile_col_start += 4) { \
      nk_size_t tile_rows = (tile_row_start + 4 <= macro_size) ? 4 : (macro_size - tile_row_start); \
      nk_size_t tile_cols = (tile_col_start + 4 <= macro_size) ? 4 : (macro_size - tile_col_start); \
      int is_diag = (tile_row_start == tile_col_start); \
      NK_ALIGN64 state_type accumulators[4][7]; \
      for (nk_size_t row = 0; row < tile_rows; row++) { \
        nk_size_t init_start = is_diag ? row : 0; \
        nk_size_t init_end = is_diag ? (row + 4) : tile_cols; \
        for (nk_size_t col = init_start; col < init_end; col++) { \
          init_accumulator_fn(&accumulators[row][col]); \
        } \
      } \
      nk_##input_value_type##_t const *row_ptrs[4], *col_ptrs[4]; \
      row_ptrs[0] = vector_base_ptrs[tile_row_start + 0]; \
      row_ptrs[1] = (tile_rows > 1) ? vector_base_ptrs[tile_row_start + 1] : row_ptrs[0]; \
      row_ptrs[2] = (tile_rows > 2) ? vector_base_ptrs[tile_row_start + 2] : row_ptrs[0]; \
      row_ptrs[3] = (tile_rows > 3) ? vector_base_ptrs[tile_row_start + 3] : row_ptrs[0]; \
      if (is_diag) { \
        col_ptrs[0] = row_ptrs[0]; \
        col_ptrs[1] = row_ptrs[1]; \
        col_ptrs[2] = row_ptrs[2]; \
        col_ptrs[3] = row_ptrs[3]; \
      } \
      else { \
        col_ptrs[0] = vector_base_ptrs[tile_col_start + 0]; \
        col_ptrs[1] = (tile_cols > 1) ? vector_base_ptrs[tile_col_start + 1] : col_ptrs[0]; \
        col_ptrs[2] = (tile_cols > 2) ? vector_base_ptrs[tile_col_start + 2] : col_ptrs[0]; \
        col_ptrs[3] = (tile_cols > 3) ? vector_base_ptrs[tile_col_start + 3] : col_ptrs[0]; \
      } \
      vec_type row_vecs[4], col_vecs[4]; \
      for (nk_size_t di = 0; di < aligned_depth; di += depth_step_values) { \
        load_vec_fn(row_ptrs[0] + di, &row_vecs[0]); \
        load_vec_fn(row_ptrs[1] + di, &row_vecs[1]); \
        load_vec_fn(row_ptrs[2] + di, &row_vecs[2]); \
        load_vec_fn(row_ptrs[3] + di, &row_vecs[3]); \
        if (!is_diag) { \
          load_vec_fn(col_ptrs[0] + di, &col_vecs[0]); \
          load_vec_fn(col_ptrs[1] + di, &col_vecs[1]); \
          load_vec_fn(col_ptrs[2] + di, &col_vecs[2]); \
          load_vec_fn(col_ptrs[3] + di, &col_vecs[3]); \
        } \
        else { \
          col_vecs[0] = row_vecs[0]; \
          col_vecs[1] = row_vecs[1]; \
          col_vecs[2] = row_vecs[2]; \
          col_vecs[3] = row_vecs[3]; \
        } \
        if (tile_rows == 4 && tile_cols == 4 && is_diag) { \
          /* Upper triangle: 10 FMAs */ \
          inner_product_fn(&accumulators[0][0], row_vecs[0], col_vecs[0], di * dimensions_per_value, \
              depth_simd_dimensions); \
          inner_product_fn(&accumulators[0][1], row_vecs[0], col_vecs[1], di * dimensions_per_value, \
              depth_simd_dimensions); \
          inner_product_fn(&accumulators[0][2], row_vecs[0], col_vecs[2], di * dimensions_per_value, \
              depth_simd_dimensions); \
          inner_product_fn(&accumulators[0][3], row_vecs[0], col_vecs[3], di * dimensions_per_value, \
              depth_simd_dimensions); \
          inner_product_fn(&accumulators[1][1], row_vecs[1], col_vecs[1], di * dimensions_per_value, \
              depth_simd_dimensions); \
          inner_product_fn(&accumulators[1][2], row_vecs[1], col_vecs[2], di * dimensions_per_value, \
|
|
1569
|
+
depth_simd_dimensions); \
|
|
1570
|
+
inner_product_fn(&accumulators[1][3], row_vecs[1], col_vecs[3], di * dimensions_per_value, \
|
|
1571
|
+
depth_simd_dimensions); \
|
|
1572
|
+
inner_product_fn(&accumulators[2][2], row_vecs[2], col_vecs[2], di * dimensions_per_value, \
|
|
1573
|
+
depth_simd_dimensions); \
|
|
1574
|
+
inner_product_fn(&accumulators[2][3], row_vecs[2], col_vecs[3], di * dimensions_per_value, \
|
|
1575
|
+
depth_simd_dimensions); \
|
|
1576
|
+
inner_product_fn(&accumulators[3][3], row_vecs[3], col_vecs[3], di * dimensions_per_value, \
|
|
1577
|
+
depth_simd_dimensions); \
|
|
1578
|
+
} \
|
|
1579
|
+
else if (tile_rows == 4 && tile_cols == 4) { \
|
|
1580
|
+
/* Full 4×4 rectangle: 16 FMAs */ \
|
|
1581
|
+
inner_product_fn(&accumulators[0][0], row_vecs[0], col_vecs[0], di * dimensions_per_value, \
|
|
1582
|
+
depth_simd_dimensions); \
|
|
1583
|
+
inner_product_fn(&accumulators[0][1], row_vecs[0], col_vecs[1], di * dimensions_per_value, \
|
|
1584
|
+
depth_simd_dimensions); \
|
|
1585
|
+
inner_product_fn(&accumulators[0][2], row_vecs[0], col_vecs[2], di * dimensions_per_value, \
|
|
1586
|
+
depth_simd_dimensions); \
|
|
1587
|
+
inner_product_fn(&accumulators[0][3], row_vecs[0], col_vecs[3], di * dimensions_per_value, \
|
|
1588
|
+
depth_simd_dimensions); \
|
|
1589
|
+
inner_product_fn(&accumulators[1][0], row_vecs[1], col_vecs[0], di * dimensions_per_value, \
|
|
1590
|
+
depth_simd_dimensions); \
|
|
1591
|
+
inner_product_fn(&accumulators[1][1], row_vecs[1], col_vecs[1], di * dimensions_per_value, \
|
|
1592
|
+
depth_simd_dimensions); \
|
|
1593
|
+
inner_product_fn(&accumulators[1][2], row_vecs[1], col_vecs[2], di * dimensions_per_value, \
|
|
1594
|
+
depth_simd_dimensions); \
|
|
1595
|
+
inner_product_fn(&accumulators[1][3], row_vecs[1], col_vecs[3], di * dimensions_per_value, \
|
|
1596
|
+
depth_simd_dimensions); \
|
|
1597
|
+
inner_product_fn(&accumulators[2][0], row_vecs[2], col_vecs[0], di * dimensions_per_value, \
|
|
1598
|
+
depth_simd_dimensions); \
|
|
1599
|
+
inner_product_fn(&accumulators[2][1], row_vecs[2], col_vecs[1], di * dimensions_per_value, \
|
|
1600
|
+
depth_simd_dimensions); \
|
|
1601
|
+
inner_product_fn(&accumulators[2][2], row_vecs[2], col_vecs[2], di * dimensions_per_value, \
|
|
1602
|
+
depth_simd_dimensions); \
|
|
1603
|
+
inner_product_fn(&accumulators[2][3], row_vecs[2], col_vecs[3], di * dimensions_per_value, \
|
|
1604
|
+
depth_simd_dimensions); \
|
|
1605
|
+
inner_product_fn(&accumulators[3][0], row_vecs[3], col_vecs[0], di * dimensions_per_value, \
|
|
1606
|
+
depth_simd_dimensions); \
|
|
1607
|
+
inner_product_fn(&accumulators[3][1], row_vecs[3], col_vecs[1], di * dimensions_per_value, \
|
|
1608
|
+
depth_simd_dimensions); \
|
|
1609
|
+
inner_product_fn(&accumulators[3][2], row_vecs[3], col_vecs[2], di * dimensions_per_value, \
|
|
1610
|
+
depth_simd_dimensions); \
|
|
1611
|
+
inner_product_fn(&accumulators[3][3], row_vecs[3], col_vecs[3], di * dimensions_per_value, \
|
|
1612
|
+
depth_simd_dimensions); \
|
|
1613
|
+
} \
|
|
1614
|
+
else { \
|
|
1615
|
+
for (nk_size_t row = 0; row < tile_rows; row++) { \
|
|
1616
|
+
nk_size_t col_start = is_diag ? row : 0; \
|
|
1617
|
+
nk_size_t col_end = is_diag ? (row < 4 ? 4 : tile_cols) : tile_cols; \
|
|
1618
|
+
for (nk_size_t col = col_start; col < col_end; col++) \
|
|
1619
|
+
inner_product_fn(&accumulators[row][col], row_vecs[row], col_vecs[col], \
|
|
1620
|
+
di * dimensions_per_value, depth_simd_dimensions); \
|
|
1621
|
+
} \
|
|
1622
|
+
} \
|
|
1623
|
+
} \
|
|
1624
|
+
if (remainder_depth > 0) { \
|
|
1625
|
+
partial_load_vec_fn(row_ptrs[0] + aligned_depth, &row_vecs[0], remainder_dimensions); \
|
|
1626
|
+
partial_load_vec_fn(row_ptrs[1] + aligned_depth, &row_vecs[1], remainder_dimensions); \
|
|
1627
|
+
partial_load_vec_fn(row_ptrs[2] + aligned_depth, &row_vecs[2], remainder_dimensions); \
|
|
1628
|
+
partial_load_vec_fn(row_ptrs[3] + aligned_depth, &row_vecs[3], remainder_dimensions); \
|
|
1629
|
+
if (!is_diag) { \
|
|
1630
|
+
partial_load_vec_fn(col_ptrs[0] + aligned_depth, &col_vecs[0], remainder_dimensions); \
|
|
1631
|
+
partial_load_vec_fn(col_ptrs[1] + aligned_depth, &col_vecs[1], remainder_dimensions); \
|
|
1632
|
+
partial_load_vec_fn(col_ptrs[2] + aligned_depth, &col_vecs[2], remainder_dimensions); \
|
|
1633
|
+
partial_load_vec_fn(col_ptrs[3] + aligned_depth, &col_vecs[3], remainder_dimensions); \
|
|
1634
|
+
} \
|
|
1635
|
+
else { \
|
|
1636
|
+
col_vecs[0] = row_vecs[0]; \
|
|
1637
|
+
col_vecs[1] = row_vecs[1]; \
|
|
1638
|
+
col_vecs[2] = row_vecs[2]; \
|
|
1639
|
+
col_vecs[3] = row_vecs[3]; \
|
|
1640
|
+
} \
|
|
1641
|
+
if (tile_rows == 4 && tile_cols == 4 && is_diag) { \
|
|
1642
|
+
inner_product_fn(&accumulators[0][0], row_vecs[0], col_vecs[0], \
|
|
1643
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1644
|
+
inner_product_fn(&accumulators[0][1], row_vecs[0], col_vecs[1], \
|
|
1645
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1646
|
+
inner_product_fn(&accumulators[0][2], row_vecs[0], col_vecs[2], \
|
|
1647
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1648
|
+
inner_product_fn(&accumulators[0][3], row_vecs[0], col_vecs[3], \
|
|
1649
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1650
|
+
inner_product_fn(&accumulators[1][1], row_vecs[1], col_vecs[1], \
|
|
1651
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1652
|
+
inner_product_fn(&accumulators[1][2], row_vecs[1], col_vecs[2], \
|
|
1653
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1654
|
+
inner_product_fn(&accumulators[1][3], row_vecs[1], col_vecs[3], \
|
|
1655
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1656
|
+
inner_product_fn(&accumulators[2][2], row_vecs[2], col_vecs[2], \
|
|
1657
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1658
|
+
inner_product_fn(&accumulators[2][3], row_vecs[2], col_vecs[3], \
|
|
1659
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1660
|
+
inner_product_fn(&accumulators[3][3], row_vecs[3], col_vecs[3], \
|
|
1661
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1662
|
+
} \
|
|
1663
|
+
else if (tile_rows == 4 && tile_cols == 4) { \
|
|
1664
|
+
inner_product_fn(&accumulators[0][0], row_vecs[0], col_vecs[0], \
|
|
1665
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1666
|
+
inner_product_fn(&accumulators[0][1], row_vecs[0], col_vecs[1], \
|
|
1667
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1668
|
+
inner_product_fn(&accumulators[0][2], row_vecs[0], col_vecs[2], \
|
|
1669
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1670
|
+
inner_product_fn(&accumulators[0][3], row_vecs[0], col_vecs[3], \
|
|
1671
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1672
|
+
inner_product_fn(&accumulators[1][0], row_vecs[1], col_vecs[0], \
|
|
1673
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1674
|
+
inner_product_fn(&accumulators[1][1], row_vecs[1], col_vecs[1], \
|
|
1675
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1676
|
+
inner_product_fn(&accumulators[1][2], row_vecs[1], col_vecs[2], \
|
|
1677
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1678
|
+
inner_product_fn(&accumulators[1][3], row_vecs[1], col_vecs[3], \
|
|
1679
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1680
|
+
inner_product_fn(&accumulators[2][0], row_vecs[2], col_vecs[0], \
|
|
1681
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1682
|
+
inner_product_fn(&accumulators[2][1], row_vecs[2], col_vecs[1], \
|
|
1683
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1684
|
+
inner_product_fn(&accumulators[2][2], row_vecs[2], col_vecs[2], \
|
|
1685
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1686
|
+
inner_product_fn(&accumulators[2][3], row_vecs[2], col_vecs[3], \
|
|
1687
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1688
|
+
inner_product_fn(&accumulators[3][0], row_vecs[3], col_vecs[0], \
|
|
1689
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1690
|
+
inner_product_fn(&accumulators[3][1], row_vecs[3], col_vecs[1], \
|
|
1691
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1692
|
+
inner_product_fn(&accumulators[3][2], row_vecs[3], col_vecs[2], \
|
|
1693
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1694
|
+
inner_product_fn(&accumulators[3][3], row_vecs[3], col_vecs[3], \
|
|
1695
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1696
|
+
} \
|
|
1697
|
+
else { \
|
|
1698
|
+
for (nk_size_t row = 0; row < tile_rows; row++) { \
|
|
1699
|
+
nk_size_t col_start = is_diag ? row : 0; \
|
|
1700
|
+
nk_size_t col_end = is_diag ? (row < 4 ? 4 : tile_cols) : tile_cols; \
|
|
1701
|
+
for (nk_size_t col = col_start; col < col_end; col++) \
|
|
1702
|
+
inner_product_fn(&accumulators[row][col], row_vecs[row], col_vecs[col], \
|
|
1703
|
+
aligned_depth * dimensions_per_value, remainder_dimensions); \
|
|
1704
|
+
} \
|
|
1705
|
+
} \
|
|
1706
|
+
} \
|
|
1707
|
+
nk_##sum_value_type##_t row_sums[4] = {0}, col_sums_arr[4] = {0}; \
|
|
1708
|
+
for (nk_size_t r = 0; r < tile_rows; r++) row_sums[r] = precomputed_sums[tile_row_start + r]; \
|
|
1709
|
+
for (nk_size_t c = 0; c < tile_cols; c++) \
|
|
1710
|
+
col_sums_arr[c] = is_diag ? row_sums[c] : precomputed_sums[tile_col_start + c]; \
|
|
1711
|
+
/* Build column sums as SIMD vector — for diagonal tiles, shift per row */ \
|
|
1712
|
+
result_vec_type col_sum_vec; \
|
|
1713
|
+
if (!is_diag) partial_load_sum_fn(col_sums_arr, &col_sum_vec, tile_cols); \
|
|
1714
|
+
/* Finalize with compensation */ \
|
|
1715
|
+
for (nk_size_t row = 0; row < tile_rows; row++) { \
|
|
1716
|
+
if (is_diag) { \
|
|
1717
|
+
nk_##sum_value_type##_t shifted[4] = {0}; \
|
|
1718
|
+
for (nk_size_t c = 0; c < 4 && (row + c) < tile_cols; c++) shifted[c] = col_sums_arr[row + c]; \
|
|
1719
|
+
partial_load_sum_fn(shifted, &col_sum_vec, 4); \
|
|
1720
|
+
} \
|
|
1721
|
+
result_vec_type rv; \
|
|
1722
|
+
compensated_finalize_fn( \
|
|
1723
|
+
&accumulators[row][is_diag ? row : 0], &accumulators[row][(is_diag ? row : 0) + 1], \
|
|
1724
|
+
&accumulators[row][(is_diag ? row : 0) + 2], &accumulators[row][(is_diag ? row : 0) + 3], \
|
|
1725
|
+
depth, row_sums[row], col_sum_vec, &rv); \
|
|
1726
|
+
nk_size_t global_row = i_macro + tile_row_start + row; \
|
|
1727
|
+
nk_size_t global_col_start = i_macro + tile_col_start + (is_diag ? row : 0); \
|
|
1728
|
+
nk_size_t store_count = is_diag ? (tile_cols - row) : tile_cols; \
|
|
1729
|
+
nk_##result_value_type##_t *dest = result + global_row * result_stride_values + global_col_start; \
|
|
1730
|
+
partial_store_fn(&rv, dest, store_count); \
|
|
1731
|
+
} \
|
|
1732
|
+
} \
|
|
1733
|
+
} \
|
|
1734
|
+
} \
|
|
1735
|
+
/* Off-diagonal helper: 4×4 tiling with inline sum accumulation (16 FMAs + up to 8 SADs per depth step) */ \
|
|
1736
|
+
NK_INTERNAL void nk_##api_name##_symmetric_offdiagonal_##input_type_name##_##isa_suffix##_( \
|
|
1737
|
+
nk_##input_value_type##_t const **row_ptrs_macro, nk_##input_value_type##_t const **col_ptrs_macro, \
|
|
1738
|
+
nk_size_t i_macro, nk_size_t j_macro, nk_size_t macro_i_size, nk_size_t macro_j_size, nk_size_t aligned_depth, \
|
|
1739
|
+
nk_size_t remainder_depth, nk_size_t remainder_dimensions, nk_size_t depth_step_values, \
|
|
1740
|
+
nk_size_t dimensions_per_value_runtime, nk_##result_value_type##_t *result, nk_size_t result_stride_values, \
|
|
1741
|
+
nk_size_t finalizer_batch_size, nk_size_t depth) { \
|
|
1742
|
+
nk_unused_(finalizer_batch_size); \
|
|
1743
|
+
nk_unused_(dimensions_per_value_runtime); \
|
|
1744
|
+
nk_size_t padded_depth_dimensions = aligned_depth * dimensions_per_value + \
|
|
1745
|
+
(remainder_depth > 0 ? depth_simd_dimensions : 0); \
|
|
1746
|
+
/* Sum caches for this macro-tile pair — computed once, reused across tiles */ \
|
|
1747
|
+
nk_##sum_value_type##_t row_sums[32], col_sums[32]; \
|
|
1748
|
+
for (nk_size_t tile_row_start = 0; tile_row_start < macro_i_size; tile_row_start += 4) { \
|
|
1749
|
+
for (nk_size_t tile_col_start = 0; tile_col_start < macro_j_size; tile_col_start += 4) { \
|
|
1750
|
+
nk_size_t tile_rows = (tile_row_start + 4 <= macro_i_size) ? 4 : (macro_i_size - tile_row_start); \
|
|
1751
|
+
nk_size_t tile_cols = (tile_col_start + 4 <= macro_j_size) ? 4 : (macro_j_size - tile_col_start); \
|
|
1752
|
+
/* Determine if this tile should compute sums — predictable branches */ \
|
|
1753
|
+
int compute_row_sums_flag = (tile_col_start == 0); \
|
|
1754
|
+
int compute_col_sums_flag = (tile_row_start == 0); \
|
|
1755
|
+
/* Initialize 4×4 dot accumulators */ \
|
|
1756
|
+
NK_ALIGN64 state_type accumulators[4][4]; \
|
|
1757
|
+
for (nk_size_t row = 0; row < tile_rows; row++) \
|
|
1758
|
+
for (nk_size_t col = 0; col < tile_cols; col++) init_accumulator_fn(&accumulators[row][col]); \
|
|
1759
|
+
/* Initialize sum accumulators (only when needed) */ \
|
|
1760
|
+
sum_state_type rsum[4], csum[4]; \
|
|
1761
|
+
if (compute_row_sums_flag) \
|
|
1762
|
+
for (nk_size_t r = 0; r < tile_rows; r++) init_sum_fn(&rsum[r]); \
|
|
1763
|
+
if (compute_col_sums_flag) \
|
|
1764
|
+
for (nk_size_t c = 0; c < tile_cols; c++) init_sum_fn(&csum[c]); \
|
|
1765
|
+
/* Setup pointers (hoist outside depth loop) */ \
|
|
1766
|
+
nk_##input_value_type##_t const *row_ptrs[4], *col_ptrs[4]; \
|
|
1767
|
+
row_ptrs[0] = row_ptrs_macro[tile_row_start + 0]; \
|
|
1768
|
+
row_ptrs[1] = (tile_rows > 1) ? row_ptrs_macro[tile_row_start + 1] : row_ptrs[0]; \
|
|
1769
|
+
row_ptrs[2] = (tile_rows > 2) ? row_ptrs_macro[tile_row_start + 2] : row_ptrs[0]; \
|
|
1770
|
+
row_ptrs[3] = (tile_rows > 3) ? row_ptrs_macro[tile_row_start + 3] : row_ptrs[0]; \
|
|
1771
|
+
col_ptrs[0] = col_ptrs_macro[tile_col_start + 0]; \
|
|
1772
|
+
col_ptrs[1] = (tile_cols > 1) ? col_ptrs_macro[tile_col_start + 1] : col_ptrs[0]; \
|
|
1773
|
+
col_ptrs[2] = (tile_cols > 2) ? col_ptrs_macro[tile_col_start + 2] : col_ptrs[0]; \
|
|
1774
|
+
col_ptrs[3] = (tile_cols > 3) ? col_ptrs_macro[tile_col_start + 3] : col_ptrs[0]; \
|
|
1775
|
+
/* Depth loop — innermost, 16 FMAs + up to 8 SADs per iteration */ \
|
|
1776
|
+
vec_type row_vecs[4], col_vecs[4]; \
|
|
1777
|
+
for (nk_size_t di = 0; di < aligned_depth; di += depth_step_values) { \
|
|
1778
|
+
load_vec_fn(row_ptrs[0] + di, &row_vecs[0]); \
|
|
1779
|
+
load_vec_fn(row_ptrs[1] + di, &row_vecs[1]); \
|
|
1780
|
+
load_vec_fn(row_ptrs[2] + di, &row_vecs[2]); \
|
|
1781
|
+
load_vec_fn(row_ptrs[3] + di, &row_vecs[3]); \
|
|
1782
|
+
load_vec_fn(col_ptrs[0] + di, &col_vecs[0]); \
|
|
1783
|
+
load_vec_fn(col_ptrs[1] + di, &col_vecs[1]); \
|
|
1784
|
+
load_vec_fn(col_ptrs[2] + di, &col_vecs[2]); \
|
|
1785
|
+
load_vec_fn(col_ptrs[3] + di, &col_vecs[3]); \
|
|
1786
|
+
nk_size_t vector_offset = di * dimensions_per_value; \
|
|
1787
|
+
if (tile_rows == 4 && tile_cols == 4) { \
|
|
1788
|
+
inner_product_fn(&accumulators[0][0], row_vecs[0], col_vecs[0], vector_offset, \
|
|
1789
|
+
depth_simd_dimensions); \
|
|
1790
|
+
inner_product_fn(&accumulators[0][1], row_vecs[0], col_vecs[1], vector_offset, \
|
|
1791
|
+
depth_simd_dimensions); \
|
|
1792
|
+
inner_product_fn(&accumulators[0][2], row_vecs[0], col_vecs[2], vector_offset, \
|
|
1793
|
+
depth_simd_dimensions); \
|
|
1794
|
+
inner_product_fn(&accumulators[0][3], row_vecs[0], col_vecs[3], vector_offset, \
|
|
1795
|
+
depth_simd_dimensions); \
|
|
1796
|
+
inner_product_fn(&accumulators[1][0], row_vecs[1], col_vecs[0], vector_offset, \
|
|
1797
|
+
depth_simd_dimensions); \
|
|
1798
|
+
inner_product_fn(&accumulators[1][1], row_vecs[1], col_vecs[1], vector_offset, \
|
|
1799
|
+
depth_simd_dimensions); \
|
|
1800
|
+
inner_product_fn(&accumulators[1][2], row_vecs[1], col_vecs[2], vector_offset, \
|
|
1801
|
+
depth_simd_dimensions); \
|
|
1802
|
+
inner_product_fn(&accumulators[1][3], row_vecs[1], col_vecs[3], vector_offset, \
|
|
1803
|
+
depth_simd_dimensions); \
|
|
1804
|
+
inner_product_fn(&accumulators[2][0], row_vecs[2], col_vecs[0], vector_offset, \
|
|
1805
|
+
depth_simd_dimensions); \
|
|
1806
|
+
inner_product_fn(&accumulators[2][1], row_vecs[2], col_vecs[1], vector_offset, \
|
|
1807
|
+
depth_simd_dimensions); \
|
|
1808
|
+
inner_product_fn(&accumulators[2][2], row_vecs[2], col_vecs[2], vector_offset, \
|
|
1809
|
+
depth_simd_dimensions); \
|
|
1810
|
+
inner_product_fn(&accumulators[2][3], row_vecs[2], col_vecs[3], vector_offset, \
|
|
1811
|
+
depth_simd_dimensions); \
|
|
1812
|
+
inner_product_fn(&accumulators[3][0], row_vecs[3], col_vecs[0], vector_offset, \
|
|
1813
|
+
depth_simd_dimensions); \
|
|
1814
|
+
inner_product_fn(&accumulators[3][1], row_vecs[3], col_vecs[1], vector_offset, \
|
|
1815
|
+
depth_simd_dimensions); \
|
|
1816
|
+
inner_product_fn(&accumulators[3][2], row_vecs[3], col_vecs[2], vector_offset, \
|
|
1817
|
+
depth_simd_dimensions); \
|
|
1818
|
+
inner_product_fn(&accumulators[3][3], row_vecs[3], col_vecs[3], vector_offset, \
|
|
1819
|
+
depth_simd_dimensions); \
|
|
1820
|
+
} \
|
|
1821
|
+
else { \
|
|
1822
|
+
for (nk_size_t row = 0; row < tile_rows; row++) \
|
|
1823
|
+
for (nk_size_t col = 0; col < tile_cols; col++) \
|
|
1824
|
+
inner_product_fn(&accumulators[row][col], row_vecs[row], col_vecs[col], vector_offset, \
|
|
1825
|
+
depth_simd_dimensions); \
|
|
1826
|
+
} \
|
|
1827
|
+
/* Progressive sum accumulation (SADs on port 5, parallel with DPBUSD on ports 0+1) */ \
|
|
1828
|
+
if (compute_row_sums_flag) { \
|
|
1829
|
+
update_sum_fn(&rsum[0], row_vecs[0]); \
|
|
1830
|
+
if (tile_rows > 1) update_sum_fn(&rsum[1], row_vecs[1]); \
|
|
1831
|
+
if (tile_rows > 2) update_sum_fn(&rsum[2], row_vecs[2]); \
|
|
1832
|
+
if (tile_rows > 3) update_sum_fn(&rsum[3], row_vecs[3]); \
|
|
1833
|
+
} \
|
|
1834
|
+
if (compute_col_sums_flag) { \
|
|
1835
|
+
update_sum_fn(&csum[0], col_vecs[0]); \
|
|
1836
|
+
if (tile_cols > 1) update_sum_fn(&csum[1], col_vecs[1]); \
|
|
1837
|
+
if (tile_cols > 2) update_sum_fn(&csum[2], col_vecs[2]); \
|
|
1838
|
+
if (tile_cols > 3) update_sum_fn(&csum[3], col_vecs[3]); \
|
|
1839
|
+
} \
|
|
1840
|
+
} \
|
|
1841
|
+
/* Handle remainder depth */ \
|
|
1842
|
+
if (remainder_depth > 0) { \
|
|
1843
|
+
partial_load_vec_fn(row_ptrs[0] + aligned_depth, &row_vecs[0], remainder_dimensions); \
|
|
1844
|
+
partial_load_vec_fn(row_ptrs[1] + aligned_depth, &row_vecs[1], remainder_dimensions); \
|
|
1845
|
+
partial_load_vec_fn(row_ptrs[2] + aligned_depth, &row_vecs[2], remainder_dimensions); \
|
|
1846
|
+
partial_load_vec_fn(row_ptrs[3] + aligned_depth, &row_vecs[3], remainder_dimensions); \
|
|
1847
|
+
partial_load_vec_fn(col_ptrs[0] + aligned_depth, &col_vecs[0], remainder_dimensions); \
|
|
1848
|
+
partial_load_vec_fn(col_ptrs[1] + aligned_depth, &col_vecs[1], remainder_dimensions); \
|
|
1849
|
+
partial_load_vec_fn(col_ptrs[2] + aligned_depth, &col_vecs[2], remainder_dimensions); \
|
|
1850
|
+
partial_load_vec_fn(col_ptrs[3] + aligned_depth, &col_vecs[3], remainder_dimensions); \
|
|
1851
|
+
nk_size_t vector_offset = aligned_depth * dimensions_per_value; \
|
|
1852
|
+
for (nk_size_t row = 0; row < tile_rows; row++) \
|
|
1853
|
+
for (nk_size_t col = 0; col < tile_cols; col++) \
|
|
1854
|
+
inner_product_fn(&accumulators[row][col], row_vecs[row], col_vecs[col], vector_offset, \
|
|
1855
|
+
remainder_dimensions); \
|
|
1856
|
+
if (compute_row_sums_flag) { \
|
|
1857
|
+
update_sum_fn(&rsum[0], row_vecs[0]); \
|
|
1858
|
+
if (tile_rows > 1) update_sum_fn(&rsum[1], row_vecs[1]); \
|
|
1859
|
+
if (tile_rows > 2) update_sum_fn(&rsum[2], row_vecs[2]); \
|
|
1860
|
+
if (tile_rows > 3) update_sum_fn(&rsum[3], row_vecs[3]); \
|
|
1861
|
+
} \
|
|
1862
|
+
if (compute_col_sums_flag) { \
|
|
1863
|
+
update_sum_fn(&csum[0], col_vecs[0]); \
|
|
1864
|
+
if (tile_cols > 1) update_sum_fn(&csum[1], col_vecs[1]); \
|
|
1865
|
+
if (tile_cols > 2) update_sum_fn(&csum[2], col_vecs[2]); \
|
|
1866
|
+
if (tile_cols > 3) update_sum_fn(&csum[3], col_vecs[3]); \
|
|
1867
|
+
} \
|
|
1868
|
+
} \
|
|
1869
|
+
/* Finalize and cache sums */ \
|
|
1870
|
+
if (compute_row_sums_flag) \
|
|
1871
|
+
for (nk_size_t r = 0; r < tile_rows; r++) \
|
|
1872
|
+
row_sums[tile_row_start + r] = finalize_sum_fn(&rsum[r], padded_depth_dimensions); \
|
|
1873
|
+
if (compute_col_sums_flag) \
|
|
1874
|
+
for (nk_size_t c = 0; c < tile_cols; c++) \
|
|
1875
|
+
col_sums[tile_col_start + c] = finalize_sum_fn(&csum[c], padded_depth_dimensions); \
|
|
1876
|
+
/* Build col_sum SIMD vector once (constant across rows) */ \
|
|
1877
|
+
nk_##sum_value_type##_t cs_arr[4] = {0}; \
|
|
1878
|
+
for (nk_size_t c = 0; c < tile_cols; c++) cs_arr[c] = col_sums[tile_col_start + c]; \
|
|
1879
|
+
result_vec_type cs_vec; \
|
|
1880
|
+
partial_load_sum_fn(cs_arr, &cs_vec, tile_cols); \
|
|
1881
|
+
/* Compensated finalize + store */ \
|
|
1882
|
+
for (nk_size_t row = 0; row < tile_rows; row++) { \
|
|
1883
|
+
result_vec_type rv; \
|
|
1884
|
+
compensated_finalize_fn(&accumulators[row][0], &accumulators[row][1], &accumulators[row][2], \
|
|
1885
|
+
&accumulators[row][3], depth, row_sums[tile_row_start + row], cs_vec, \
|
|
1886
|
+
&rv); \
|
|
1887
|
+
nk_##result_value_type##_t *dest = result + \
|
|
1888
|
+
(i_macro + tile_row_start + row) * result_stride_values + \
|
|
1889
|
+
(j_macro + tile_col_start); \
|
|
1890
|
+
partial_store_fn(&rv, dest, tile_cols); \
|
|
1891
|
+
} \
|
|
1892
|
+
} \
|
|
1893
|
+
} \
|
|
1894
|
+
} \
|
|
1895
|
+
NK_PUBLIC void nk_##api_name##_symmetric_##input_type_name##_##isa_suffix( \
|
|
1896
|
+
nk_##input_value_type##_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, \
|
|
1897
|
+
nk_##result_value_type##_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) { \
|
|
1898
|
+
nk_size_t const macro_tile_size = 32; \
|
|
1899
|
+
nk_size_t const row_block_size = 128; /* L2 cache blocking */ \
|
|
1900
|
+
nk_size_t const column_block_size = 2048; /* L3 cache blocking */ \
|
|
1901
|
+
nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
|
|
1902
|
+
nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
|
|
1903
|
+
nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
|
|
1904
|
+
nk_size_t const remainder_depth = depth_in_values - aligned_depth; \
|
|
1905
|
+
nk_size_t const remainder_dimensions = depth - depth_dimensions_aligned; \
|
|
1906
|
+
nk_size_t const depth_step = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
|
|
1907
|
+
nk_size_t const result_stride_values = result_stride / sizeof(nk_##result_value_type##_t); \
|
|
1908
|
+
nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors; \
|
|
1909
|
+
\
|
|
1910
|
+
/* Process upper triangle with L3/L2/L1 blocking (column blocks → row blocks → 32×32 macro-tiles) */ \
|
|
1911
|
+
for (nk_size_t j_block = 0; j_block < n_vectors; j_block += column_block_size) { \
|
|
1912
|
+
nk_size_t j_block_end = (j_block + column_block_size < n_vectors) ? j_block + column_block_size \
|
|
1913
|
+
: n_vectors; \
|
|
1914
|
+
\
|
|
1915
|
+
for (nk_size_t i_block = row_start; i_block < row_end; i_block += row_block_size) { \
|
|
1916
|
+
nk_size_t i_block_end = (i_block + row_block_size < row_end) ? i_block + row_block_size : row_end; \
|
|
1917
|
+
\
|
|
1918
|
+
/* Skip blocks entirely below diagonal. Blocks fully above the diagonal are still part of the upper \
|
|
1919
|
+
* triangle and must be computed. */ \
|
|
1920
|
+
if (i_block >= j_block_end) continue; \
|
|
1921
|
+
\
|
|
1922
|
+
for (nk_size_t i_macro = i_block; i_macro < i_block_end; i_macro += macro_tile_size) { \
|
|
1923
|
+
/* Upper triangle: j_macro starts at max(i_macro, j_block) */ \
|
|
1924
|
+
nk_size_t j_start = (i_macro > j_block) ? i_macro : j_block; \
|
|
1925
|
+
for (nk_size_t j_macro = j_start; j_macro < j_block_end; j_macro += macro_tile_size) { \
|
|
1926
|
+
nk_size_t macro_i_size = (i_macro + macro_tile_size <= i_block_end) ? macro_tile_size \
|
|
1927
|
+
: (i_block_end - i_macro); \
|
|
1928
|
+
nk_size_t macro_j_size = (j_macro + macro_tile_size <= j_block_end) ? macro_tile_size \
|
|
1929
|
+
: (j_block_end - j_macro); \
|
|
1930
|
+
\
|
|
1931
|
+
/* Build pointer arrays */ \
|
|
1932
|
+
nk_##input_value_type##_t const *vec_ptrs_i[32]; \
|
|
1933
|
+
nk_##input_value_type##_t const *vec_ptrs_j[32]; \
|
|
1934
|
+
for (nk_size_t k = 0; k < macro_i_size; k++) \
|
|
1935
|
+
vec_ptrs_i[k] = (nk_##input_value_type##_t const *)((char const *)vectors + \
|
|
1936
|
+
(i_macro + k) * stride); \
|
|
1937
|
+
for (nk_size_t k = macro_i_size; k < 32; k++) vec_ptrs_i[k] = vec_ptrs_i[0]; \
|
|
1938
|
+
\
|
|
1939
|
+
if (i_macro == j_macro && macro_i_size == macro_j_size) { \
|
|
1940
|
+
/* Diagonal macro-tile */ \
|
|
1941
|
+
nk_##api_name##_symmetric_diagonal_##input_type_name##_##isa_suffix##_( \
|
|
1942
|
+
vec_ptrs_i, i_macro, macro_i_size, aligned_depth, remainder_depth, \
|
|
1943
|
+
remainder_dimensions, depth_step, dimensions_per_value, result, result_stride_values, \
|
|
1944
|
+
4, depth); \
|
|
1945
|
+
} \
|
|
1946
|
+
else { \
|
|
1947
|
+
/* Off-diagonal macro-tile */ \
|
|
1948
|
+
for (nk_size_t k = 0; k < macro_j_size; k++) \
|
|
1949
|
+
vec_ptrs_j[k] = (nk_##input_value_type##_t const *)((char const *)vectors + \
|
|
1950
|
+
(j_macro + k) * stride); \
|
|
1951
|
+
for (nk_size_t k = macro_j_size; k < 32; k++) vec_ptrs_j[k] = vec_ptrs_j[0]; \
|
|
1952
|
+
nk_##api_name##_symmetric_offdiagonal_##input_type_name##_##isa_suffix##_( \
|
|
1953
|
+
vec_ptrs_i, vec_ptrs_j, i_macro, j_macro, macro_i_size, macro_j_size, aligned_depth, \
|
|
1954
|
+
remainder_depth, remainder_dimensions, depth_step, dimensions_per_value, result, \
|
|
1955
|
+
result_stride_values, 4, depth); \
|
|
1956
|
+
} \
|
|
1957
|
+
} \
|
|
1958
|
+
} \
|
|
1959
|
+
} \
|
|
1960
|
+
} \
|
|
1961
|
+
}
|
|
1962
|
+
|
|
1963
|
+
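/* Editor's note: a hedged usage sketch for the wrappers these macros expand to. The concrete
 * symbol below (nk_dots_symmetric_f32_haswell) is assumed from the api_name/input_type_name/
 * isa_suffix naming scheme and from the dot/haswell.h backend shipped in this package; the
 * exact instantiation may differ. Strides are passed in bytes, as the wrappers divide them
 * back into element counts internally. */
#if 0 /* illustrative only, excluded from compilation */
void example_gram_upper_triangle(nk_f32_t const *vectors, nk_size_t n, nk_size_t d, nk_f32_t *c) {
    /* Fills the upper triangle of C = A × Aᵀ for all n rows in one call. */
    nk_dots_symmetric_f32_haswell(vectors, /*n_vectors=*/n, /*depth=*/d, /*stride=*/d * sizeof(nk_f32_t), //
                                  c, /*result_stride=*/n * sizeof(nk_f32_t), /*row_start=*/0, /*row_count=*/n);
}
#endif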
/**
 * @brief Generates optimized symmetric Gram matrix computation: C = A × Aᵀ (upper triangle only).
 *
 * This macro creates a complete symmetric cross-product implementation with two specialized
 * internal helper functions (diagonal and off-diagonal) that are called by a public wrapper.
 * Symmetric computation exploits the property that C[i,j] = C[j,i], computing only the upper
 * triangle and avoiding redundant computation and storage.
 *
 * @par Mathematical Operation
 * For each pair (i,j) where i ≤ j:
 *     C[i,j] = operation(A[i,:], A[j,:])
 * where operation can be dot product, Hamming distance, Jaccard similarity, etc.
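 * A plain scalar reference for the dot-product case (an editor's sketch of the semantics only;
 * the generated code is tiled and vectorized as described below):
 * @code
 * for (nk_size_t i = 0; i < n_vectors; i++)
 *     for (nk_size_t j = i; j < n_vectors; j++) {
 *         double acc = 0;
 *         for (nk_size_t k = 0; k < depth; k++) acc += a[i * depth + k] * a[j * depth + k];
 *         c[i * result_stride_values + j] = acc; // upper triangle only
 *     }
 * @endcode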
 *
 * @par Architecture - Three-Level Tiling Hierarchy
 *
 * 1. @b 32×32 @b macro-tiles (outermost): Divides the upper triangle into 32×32 blocks
 *    - Rationale: Fits well in L1 cache (32 vectors × depth × value_size)
 *    - Enables diagonal vs off-diagonal specialization
 *    - Amortizes vector loads across all depth iterations
 *    - Pre-loads and upcasts ALL 32 vectors ONCE per depth iteration (not per FMA)
 *
 * 2. @b 4×4 @b register @b tiles (middle): Within each macro-tile, process 4×4 sub-blocks
 *    - Rationale: Maximizes register reuse (4 A vectors × 4 A vectors = 16 accumulators)
 *    - Enables full FMA unrolling (16 FMAs for off-diagonal, 10 for diagonal)
 *    - Balances register pressure with instruction-level parallelism
 *
 * 3. @b Depth @b loop (innermost): For each depth chunk, accumulate outer products
 *    - Depth loop is INSIDE macro-tile, OUTSIDE register tiles
 *    - Type conversion (e.g., bf16→f32) happens at macro-tile level (once per vector)
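 *
 * Schematically, the generated loop nest looks like this (editor's sketch, eliding remainder
 * handling and the diagonal specialization):
 * @code
 * for (i_macro = 0; i_macro < n; i_macro += 32)             // macro-tiles (L1)
 *     for (j_macro = i_macro; j_macro < n; j_macro += 32)   // upper triangle
 *         for (tile_row = 0; tile_row < 32; tile_row += 4)  // register tiles
 *             for (tile_col = 0; tile_col < 32; tile_col += 4)
 *                 for (di = 0; di < depth; di += simd_width) // depth innermost
 *                     ; // load/upcast 4+4 vectors, then 16 (or 10) FMAs
 * @endcode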
 *
 * @par Diagonal vs Off-Diagonal Optimization
 *
 * - @b Diagonal @b macro-tiles (i_macro == j_macro): Computes C[i:i+32, i:i+32]
 *   - Loads 32 vectors ONCE (50% load reduction vs off-diagonal)
 *   - Computes upper triangle only within the tile (10 FMAs per 4×4 block)
 *   - Uses nk_##api_name##_symmetric_diagonal_##input_type_name##_##isa_suffix##_ helper
 *
 * - @b Off-diagonal @b macro-tiles (i_macro < j_macro): Computes C[i:i+32, j:j+32]
 *   - Loads vec_i[32] + vec_j[32] (full 64 vectors for two sets)
 *   - Computes full 32×32 block (16 FMAs per 4×4 block)
 *   - Uses nk_##api_name##_symmetric_##input_type_name##_##isa_suffix##_offdiagonal_ helper
 *
 * @par When to Use Symmetric vs Packed Variant
 *
 * - Use symmetric (this macro) when: A is the SAME matrix for both sides (C = A × Aᵀ)
 *   - Saves ~50% computation and storage (upper triangle only)
 *   - Automatic diagonal optimization (50% fewer loads on diagonal tiles)
 *   - Ideal for: distance matrices, correlation matrices, Gram matrices
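 *   - For example, at n_vectors = 4096 that is n(n+1)/2 = 8,390,656 pair evaluations
 *     instead of n² = 16,777,216, just over half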
 *
 * - Use packed variant when: Computing C = A × Bᵀ where A ≠ B
 *   - Full matrix computation (no symmetry to exploit)
 *   - B can be pre-packed for cache efficiency
 *
 * @par Generated Functions
 *
 * This macro generates THREE functions:
 * 1. nk_##api_name##_symmetric_diagonal_##input_type_name##_##isa_suffix##_ (NK_INTERNAL)
 * 2. nk_##api_name##_symmetric_##input_type_name##_##isa_suffix##_offdiagonal_ (NK_INTERNAL)
 * 3. nk_##api_name##_symmetric_##input_type_name##_##isa_suffix (NK_PUBLIC wrapper)
 *
 * @param api_name Operation family (dots, hammings, jaccards) for codegen namespace
 * @param input_type_name Type identifier for codegen (f32, bf16, i8, u1, etc.)
 * @param isa_suffix ISA backend identifier (serial, haswell, neon, sve, icelake, etc.)
 * @param input_value_type C type of input matrix values (f32, bf16, i8, u1x8, etc.)
 * @param result_value_type C type of output matrix values (f32, u32, f64, etc.)
 * @param vec_type SIMD vector type for input vectors (e.g., __m256, nk_f32x8_t)
 * @param state_type Accumulator state type (often vec_type or wider, e.g., __m256 or __m512)
 * @param result_vec_type SIMD vector type for reduction results (e.g., __m128 for 4 f32 results)
 * @param init_accumulator_fn Initialize accumulator: void fn(state_type *)
 * @param load_vec_fn Full vector load: void fn(input value pointer, vec_type *)
 * @param partial_load_vec_fn Partial vector load for remainder
 * @param inner_product_fn Inner product accumulate
 * @param reduce_accumulators_fn Reduce 4 accumulators
 * @param store_fn Full store for results
 * @param partial_store_fn Partial store for results
 * @param depth_simd_dimensions SIMD vector width in logical dimensions (e.g., 8 for f32 on AVX2, 128 for u1 on serial)
 * @param dimensions_per_value Packing ratio: dimensions per storage value (1 for f32, 2 for i4x2, 8 for u1x8)
 *
 * @sa nk_define_cross_packed_ for asymmetric C = A × Bᵀ computation
 * @sa nk_define_cross_pack_size_ for calculating packed buffer size
 * @sa nk_define_cross_pack_ for packing B matrix
 * @sa include/numkong/set/serial.h for state type definitions
 * @sa include/numkong/cast/serial.h for load/store function implementations
 */
#define nk_define_cross_symmetric_(api_name, input_type_name, isa_suffix, input_value_type, result_value_type, \
                                   vec_type, state_type, result_vec_type, init_accumulator_fn, load_vec_fn, \
                                   partial_load_vec_fn, inner_product_fn, reduce_accumulators_fn, store_fn, \
                                   partial_store_fn, depth_simd_dimensions, dimensions_per_value) \
NK_INTERNAL void nk_##api_name##_symmetric_diagonal_##input_type_name##_##isa_suffix##_( \
    nk_##input_value_type##_t const **vector_base_ptrs, nk_size_t i_macro, nk_size_t macro_size, \
    nk_size_t aligned_depth, nk_size_t remainder_depth, nk_size_t remainder_dimensions, \
    nk_size_t depth_step_values, nk_size_t dimensions_per_value_runtime, nk_##result_value_type##_t *result, \
    nk_size_t result_stride_values, nk_size_t finalizer_batch_size, nk_size_t depth) { \
    \
    nk_unused_(dimensions_per_value_runtime); \
    nk_unused_(finalizer_batch_size); \
    /* Tile-first architecture: Process 32×32 macro-tile as 4×4 register tiles (depth innermost) */ \
    for (nk_size_t tile_row_start = 0; tile_row_start < macro_size; tile_row_start += 4) { \
        for (nk_size_t tile_column_start = tile_row_start; tile_column_start < macro_size; tile_column_start += 4) { \
            \
            nk_size_t tile_rows = (tile_row_start + 4 <= macro_size) ? 4 : (macro_size - tile_row_start); \
            nk_size_t tile_columns = (tile_column_start + 4 <= macro_size) ? 4 : (macro_size - tile_column_start); \
            int is_diagonal_tile = (tile_row_start == tile_column_start); \
            \
            /* Initialize register-resident accumulators — padded to [4][7] so that the reduce call */ \
            /* (which always reads 4 consecutive entries starting at column_start) stays in bounds */ \
            NK_ALIGN64 state_type accumulators[4][7]; \
            for (nk_size_t row = 0; row < tile_rows; row++) { \
                nk_size_t init_start = is_diagonal_tile ? row : 0; \
                nk_size_t init_end = is_diagonal_tile ? (row + 4) : tile_columns; \
                for (nk_size_t column = init_start; column < init_end; column++) { \
                    init_accumulator_fn(&accumulators[row][column]); \
                } \
            } \
            \
            /* Setup pointers (hoist outside depth loop) - always safe even for partial tiles */ \
            nk_##input_value_type##_t const *row_ptrs[4]; \
            nk_##input_value_type##_t const *column_ptrs[4]; \
            row_ptrs[0] = vector_base_ptrs[tile_row_start + 0]; \
            row_ptrs[1] = (tile_rows > 1) ? vector_base_ptrs[tile_row_start + 1] : row_ptrs[0]; \
            row_ptrs[2] = (tile_rows > 2) ? vector_base_ptrs[tile_row_start + 2] : row_ptrs[0]; \
            row_ptrs[3] = (tile_rows > 3) ? vector_base_ptrs[tile_row_start + 3] : row_ptrs[0]; \
            \
            if (is_diagonal_tile) { \
                column_ptrs[0] = row_ptrs[0]; \
                column_ptrs[1] = row_ptrs[1]; \
                column_ptrs[2] = row_ptrs[2]; \
                column_ptrs[3] = row_ptrs[3]; \
            } \
            else { \
                column_ptrs[0] = vector_base_ptrs[tile_column_start + 0]; \
                column_ptrs[1] = (tile_columns > 1) ? vector_base_ptrs[tile_column_start + 1] : column_ptrs[0]; \
                column_ptrs[2] = (tile_columns > 2) ? vector_base_ptrs[tile_column_start + 2] : column_ptrs[0]; \
                column_ptrs[3] = (tile_columns > 3) ? vector_base_ptrs[tile_column_start + 3] : column_ptrs[0]; \
            } \
            \
            /* Depth loop is now innermost - key optimization */ \
            vec_type row_vecs[4]; \
            vec_type column_vecs[4]; \
            \
            for (nk_size_t depth_offset = 0; depth_offset < aligned_depth; depth_offset += depth_step_values) { \
                /* Always load all 4 vectors - aliasing is cheaper than branches */ \
                load_vec_fn(row_ptrs[0] + depth_offset, &row_vecs[0]); \
                load_vec_fn(row_ptrs[1] + depth_offset, &row_vecs[1]); \
                load_vec_fn(row_ptrs[2] + depth_offset, &row_vecs[2]); \
                load_vec_fn(row_ptrs[3] + depth_offset, &row_vecs[3]); \
                \
                /* For diagonal tiles, column vectors alias row vectors (same memory) */ \
                load_vec_fn(column_ptrs[0] + depth_offset, &column_vecs[0]); \
                load_vec_fn(column_ptrs[1] + depth_offset, &column_vecs[1]); \
                load_vec_fn(column_ptrs[2] + depth_offset, &column_vecs[2]); \
                load_vec_fn(column_ptrs[3] + depth_offset, &column_vecs[3]); \
                \
                nk_size_t vector_offset = depth_offset * dimensions_per_value; \
                \
                /* Compute: always unroll for full 4×4, use loops only for partial tiles */ \
                if (tile_rows == 4 && tile_columns == 4) { \
                    if (is_diagonal_tile) { \
                        /* Full 4×4 diagonal tile - upper triangle only (10 FMAs) */ \
                        inner_product_fn(&accumulators[0][0], row_vecs[0], column_vecs[0], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[0][1], row_vecs[0], column_vecs[1], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[0][2], row_vecs[0], column_vecs[2], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[0][3], row_vecs[0], column_vecs[3], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[1][1], row_vecs[1], column_vecs[1], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[1][2], row_vecs[1], column_vecs[2], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[1][3], row_vecs[1], column_vecs[3], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[2][2], row_vecs[2], column_vecs[2], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[2][3], row_vecs[2], column_vecs[3], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[3][3], row_vecs[3], column_vecs[3], vector_offset, depth_simd_dimensions); \
                    } \
                    else { \
                        /* Full 4×4 off-diagonal tile (16 FMAs) */ \
                        inner_product_fn(&accumulators[0][0], row_vecs[0], column_vecs[0], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[0][1], row_vecs[0], column_vecs[1], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[0][2], row_vecs[0], column_vecs[2], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[0][3], row_vecs[0], column_vecs[3], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[1][0], row_vecs[1], column_vecs[0], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[1][1], row_vecs[1], column_vecs[1], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[1][2], row_vecs[1], column_vecs[2], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[1][3], row_vecs[1], column_vecs[3], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[2][0], row_vecs[2], column_vecs[0], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[2][1], row_vecs[2], column_vecs[1], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[2][2], row_vecs[2], column_vecs[2], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[2][3], row_vecs[2], column_vecs[3], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[3][0], row_vecs[3], column_vecs[0], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[3][1], row_vecs[3], column_vecs[1], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[3][2], row_vecs[3], column_vecs[2], vector_offset, depth_simd_dimensions); \
                        inner_product_fn(&accumulators[3][3], row_vecs[3], column_vecs[3], vector_offset, depth_simd_dimensions); \
                    } \
                } \
                else { \
                    /* Partial tile - use loops (rare edge case) */ \
                    for (nk_size_t row = 0; row < tile_rows; row++) { \
                        nk_size_t column_start = is_diagonal_tile ? row : 0; \
                        for (nk_size_t column = column_start; column < tile_columns; column++) { \
                            inner_product_fn(&accumulators[row][column], row_vecs[row], column_vecs[column], \
                                             vector_offset, depth_simd_dimensions); \
                        } \
                    } \
                } \
            } \
            \
            /* Handle remainder depth (happens once per tile, not in hot loop) */ \
            if (remainder_depth > 0) { \
                partial_load_vec_fn(row_ptrs[0] + aligned_depth, &row_vecs[0], remainder_dimensions); \
                partial_load_vec_fn(row_ptrs[1] + aligned_depth, &row_vecs[1], remainder_dimensions); \
                partial_load_vec_fn(row_ptrs[2] + aligned_depth, &row_vecs[2], remainder_dimensions); \
                partial_load_vec_fn(row_ptrs[3] + aligned_depth, &row_vecs[3], remainder_dimensions); \
                partial_load_vec_fn(column_ptrs[0] + aligned_depth, &column_vecs[0], remainder_dimensions); \
                partial_load_vec_fn(column_ptrs[1] + aligned_depth, &column_vecs[1], remainder_dimensions); \
                partial_load_vec_fn(column_ptrs[2] + aligned_depth, &column_vecs[2], remainder_dimensions); \
                partial_load_vec_fn(column_ptrs[3] + aligned_depth, &column_vecs[3], remainder_dimensions); \
                \
                nk_size_t vector_offset = aligned_depth * dimensions_per_value; \
                for (nk_size_t row = 0; row < tile_rows; row++) { \
                    nk_size_t column_start = is_diagonal_tile ? row : 0; \
                    for (nk_size_t column = column_start; column < tile_columns; column++) { \
                        inner_product_fn(&accumulators[row][column], row_vecs[row], column_vecs[column], \
                                         vector_offset, remainder_dimensions); \
                    } \
                } \
            } \
            \
            /* Direct finalization and store (no intermediate buffer) */ \
            for (nk_size_t row = 0; row < tile_rows; row++) { \
                nk_size_t column_start = is_diagonal_tile ? row : 0; \
                nk_size_t columns_remaining = tile_columns - column_start; \
                result_vec_type result_vec; \
                \
                /* Always reduce 4 accumulators (partial_store handles actual count) */ \
                reduce_accumulators_fn(&accumulators[row][column_start], &accumulators[row][column_start + 1], \
                                       &accumulators[row][column_start + 2], &accumulators[row][column_start + 3], \
                                       depth, &result_vec); \
                \
                nk_##result_value_type##_t *output_ptr = \
                    &result[(i_macro + tile_row_start + row) * result_stride_values + \
                            (i_macro + tile_column_start + column_start)]; \
                partial_store_fn(&result_vec, output_ptr, columns_remaining); \
            } \
        } \
    } \
} \
NK_INTERNAL void nk_##api_name##_symmetric_##input_type_name##_##isa_suffix##_offdiagonal_( \
    nk_##input_value_type##_t const **vector_base_ptrs_i, nk_##input_value_type##_t const **vector_base_ptrs_j, \
    nk_size_t i_macro, nk_size_t j_macro, nk_size_t macro_i_size, nk_size_t macro_j_size, nk_size_t aligned_depth, \
    nk_size_t remainder_depth, nk_size_t remainder_dimensions, nk_size_t depth_step_values, \
    nk_size_t dimensions_per_value_runtime, nk_##result_value_type##_t *result, nk_size_t result_stride_values, \
    nk_size_t finalizer_batch_size, nk_size_t depth) { \
    \
    nk_unused_(dimensions_per_value_runtime); \
    nk_unused_(finalizer_batch_size); \
    /* Tile-first architecture: Process 32×32 macro-tile as 4×4 register tiles (depth innermost) */ \
    for (nk_size_t tile_row_start = 0; tile_row_start < macro_i_size; tile_row_start += 4) { \
        for (nk_size_t tile_column_start = 0; tile_column_start < macro_j_size; tile_column_start += 4) { \
            \
            nk_size_t tile_rows = (tile_row_start + 4 <= macro_i_size) ? 4 : (macro_i_size - tile_row_start); \
            nk_size_t tile_columns = (tile_column_start + 4 <= macro_j_size) ? 4 : (macro_j_size - tile_column_start); \
            \
            /* Initialize 4×4 register-resident accumulators (full rectangle for off-diagonal) */ \
            NK_ALIGN64 state_type accumulators[4][4]; \
            for (nk_size_t row = 0; row < tile_rows; row++) { \
                for (nk_size_t column = 0; column < tile_columns; column++) { \
                    init_accumulator_fn(&accumulators[row][column]); \
                } \
            } \
            \
            /* Setup pointers (hoist outside depth loop) - always safe even for partial tiles */ \
            nk_##input_value_type##_t const *row_ptrs[4]; \
            nk_##input_value_type##_t const *column_ptrs[4]; \
            row_ptrs[0] = vector_base_ptrs_i[tile_row_start + 0]; \
            row_ptrs[1] = (tile_rows > 1) ? vector_base_ptrs_i[tile_row_start + 1] : row_ptrs[0]; \
            row_ptrs[2] = (tile_rows > 2) ? vector_base_ptrs_i[tile_row_start + 2] : row_ptrs[0]; \
            row_ptrs[3] = (tile_rows > 3) ? vector_base_ptrs_i[tile_row_start + 3] : row_ptrs[0]; \
            column_ptrs[0] = vector_base_ptrs_j[tile_column_start + 0]; \
            column_ptrs[1] = (tile_columns > 1) ? vector_base_ptrs_j[tile_column_start + 1] : column_ptrs[0]; \
            column_ptrs[2] = (tile_columns > 2) ? vector_base_ptrs_j[tile_column_start + 2] : column_ptrs[0]; \
            column_ptrs[3] = (tile_columns > 3) ? vector_base_ptrs_j[tile_column_start + 3] : column_ptrs[0]; \
            \
            /* Depth loop is now innermost - key optimization */ \
            vec_type row_vecs[4]; \
            vec_type column_vecs[4]; \
            \
            for (nk_size_t depth_offset = 0; depth_offset < aligned_depth; depth_offset += depth_step_values) { \
                /* Always load all 8 vectors - aliasing is cheaper than branches */ \
                load_vec_fn(row_ptrs[0] + depth_offset, &row_vecs[0]); \
                load_vec_fn(row_ptrs[1] + depth_offset, &row_vecs[1]); \
                load_vec_fn(row_ptrs[2] + depth_offset, &row_vecs[2]); \
                load_vec_fn(row_ptrs[3] + depth_offset, &row_vecs[3]); \
                load_vec_fn(column_ptrs[0] + depth_offset, &column_vecs[0]); \
                load_vec_fn(column_ptrs[1] + depth_offset, &column_vecs[1]); \
                load_vec_fn(column_ptrs[2] + depth_offset, &column_vecs[2]); \
                load_vec_fn(column_ptrs[3] + depth_offset, &column_vecs[3]); \
                \
                nk_size_t vector_offset = depth_offset * dimensions_per_value; \
                \
                /* Compute: always unroll for full 4×4, use loops only for partial tiles */ \
                if (tile_rows == 4 && tile_columns == 4) { \
                    /* Full 4×4 off-diagonal tile (16 FMAs) */ \
                    inner_product_fn(&accumulators[0][0], row_vecs[0], column_vecs[0], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[0][1], row_vecs[0], column_vecs[1], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[0][2], row_vecs[0], column_vecs[2], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[0][3], row_vecs[0], column_vecs[3], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[1][0], row_vecs[1], column_vecs[0], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[1][1], row_vecs[1], column_vecs[1], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[1][2], row_vecs[1], column_vecs[2], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[1][3], row_vecs[1], column_vecs[3], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[2][0], row_vecs[2], column_vecs[0], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[2][1], row_vecs[2], column_vecs[1], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[2][2], row_vecs[2], column_vecs[2], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[2][3], row_vecs[2], column_vecs[3], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[3][0], row_vecs[3], column_vecs[0], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[3][1], row_vecs[3], column_vecs[1], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[3][2], row_vecs[3], column_vecs[2], vector_offset, depth_simd_dimensions); \
                    inner_product_fn(&accumulators[3][3], row_vecs[3], column_vecs[3], vector_offset, depth_simd_dimensions); \
                } \
                else { \
                    /* Partial tile - use loops (rare edge case) */ \
                    for (nk_size_t row = 0; row < tile_rows; row++) { \
                        for (nk_size_t column = 0; column < tile_columns; column++) { \
                            inner_product_fn(&accumulators[row][column], row_vecs[row], column_vecs[column], \
                                             vector_offset, depth_simd_dimensions); \
                        } \
                    } \
                } \
            } \
            \
            /* Handle remainder depth (happens once per tile, not in hot loop) */ \
            if (remainder_depth > 0) { \
                partial_load_vec_fn(row_ptrs[0] + aligned_depth, &row_vecs[0], remainder_dimensions); \
                partial_load_vec_fn(row_ptrs[1] + aligned_depth, &row_vecs[1], remainder_dimensions); \
                partial_load_vec_fn(row_ptrs[2] + aligned_depth, &row_vecs[2], remainder_dimensions); \
                partial_load_vec_fn(row_ptrs[3] + aligned_depth, &row_vecs[3], remainder_dimensions); \
                partial_load_vec_fn(column_ptrs[0] + aligned_depth, &column_vecs[0], remainder_dimensions); \
                partial_load_vec_fn(column_ptrs[1] + aligned_depth, &column_vecs[1], remainder_dimensions); \
                partial_load_vec_fn(column_ptrs[2] + aligned_depth, &column_vecs[2], remainder_dimensions); \
                partial_load_vec_fn(column_ptrs[3] + aligned_depth, &column_vecs[3], remainder_dimensions); \
                \
                nk_size_t vector_offset = aligned_depth * dimensions_per_value; \
                for (nk_size_t row = 0; row < tile_rows; row++) { \
                    for (nk_size_t column = 0; column < tile_columns; column++) { \
                        inner_product_fn(&accumulators[row][column], row_vecs[row], column_vecs[column], \
                                         vector_offset, remainder_dimensions); \
                    } \
                } \
            } \
            \
            /* Direct finalization and store (no intermediate buffer) */ \
            for (nk_size_t row = 0; row < tile_rows; row++) { \
                result_vec_type result_vec; \
                \
                /* Always reduce 4 accumulators (partial_store handles actual count) */ \
                reduce_accumulators_fn(&accumulators[row][0], &accumulators[row][1], &accumulators[row][2], \
                                       &accumulators[row][3], depth, &result_vec); \
                \
                nk_##result_value_type##_t *output_ptr = \
                    &result[(i_macro + tile_row_start + row) * result_stride_values + \
                            (j_macro + tile_column_start)]; \
                partial_store_fn(&result_vec, output_ptr, tile_columns); \
            } \
        } \
    } \
} \
NK_PUBLIC void nk_##api_name##_symmetric_##input_type_name##_##isa_suffix( \
    nk_##input_value_type##_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, \
    nk_##result_value_type##_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) { \
    nk_size_t const macro_tile_size = 32; \
    nk_size_t const finalizer_batch_size = 4; \
    nk_size_t const row_block_size = 128;     /* L2 cache blocking */ \
    nk_size_t const column_block_size = 2048; /* L3 cache blocking */ \
    \
    /* Stride and depth calculations */ \
    nk_size_t const vectors_stride_values = stride / sizeof(nk_##input_value_type##_t); \
    nk_size_t const result_stride_values = result_stride / sizeof(nk_##result_value_type##_t); \
    nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
    nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
    nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
    nk_size_t const remainder_depth = depth_in_values - aligned_depth; \
    nk_size_t const remainder_dimensions = depth - depth_dimensions_aligned; \
    nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
    nk_size_t const row_end = (row_start + row_count < n_vectors) ? (row_start + row_count) : n_vectors; \
    \
    /* Process upper triangle with L3/L2/L1 blocking (column blocks → row blocks → 32×32 macro-tiles) */ \
    for (nk_size_t j_block = 0; j_block < n_vectors; j_block += column_block_size) { \
        nk_size_t j_block_end = (j_block + column_block_size < n_vectors) ? j_block + column_block_size : n_vectors; \
        \
        for (nk_size_t i_block = row_start; i_block < row_end; i_block += row_block_size) { \
            nk_size_t i_block_end = (i_block + row_block_size < row_end) ? i_block + row_block_size : row_end; \
            \
            /* Skip blocks entirely below diagonal. Blocks fully above the diagonal are still part of the upper \
             * triangle and must be computed. */ \
            if (i_block >= j_block_end) continue; \
            \
            for (nk_size_t i_macro = i_block; i_macro < i_block_end; i_macro += macro_tile_size) { \
                /* Upper triangle: j_macro starts at max(i_macro, j_block) */ \
                nk_size_t j_start = (i_macro > j_block) ? i_macro : j_block; \
                for (nk_size_t j_macro = j_start; j_macro < j_block_end; j_macro += macro_tile_size) { \
                    nk_size_t macro_i_size = (i_macro + macro_tile_size <= i_block_end) ? macro_tile_size \
                                                                                        : (i_block_end - i_macro); \
                    nk_size_t macro_j_size = (j_macro + macro_tile_size <= j_block_end) ? macro_tile_size \
                                                                                        : (j_block_end - j_macro); \
                    \
                    /* Hoist pointer computation outside depth loop */ \
                    nk_##input_value_type##_t const *vector_base_ptrs_i[32]; \
                    nk_##input_value_type##_t const *vector_base_ptrs_j[32]; \
                    for (nk_size_t i = 0; i < macro_i_size; i++) { \
                        vector_base_ptrs_i[i] = vectors + (i_macro + i) * vectors_stride_values; \
                    } \
                    if (i_macro != j_macro || macro_i_size != macro_j_size) { \
                        for (nk_size_t j = 0; j < macro_j_size; j++) { \
                            vector_base_ptrs_j[j] = vectors + (j_macro + j) * vectors_stride_values; \
                        } \
                    } \
                    \
                    if (i_macro == j_macro && macro_i_size == macro_j_size) { \
                        /* Diagonal macro-tile: symmetric, upper triangle only */ \
                        nk_##api_name##_symmetric_diagonal_##input_type_name##_##isa_suffix##_( \
|
|
2422
|
+
vector_base_ptrs_i, i_macro, macro_i_size, aligned_depth, remainder_depth, \
|
|
2423
|
+
remainder_dimensions, depth_step_values, dimensions_per_value, result, \
|
|
2424
|
+
result_stride_values, finalizer_batch_size, depth); \
|
|
2425
|
+
} \
|
|
2426
|
+
else { \
|
|
2427
|
+
/* Off-diagonal macro-tile: full rectangle */ \
|
|
2428
|
+
nk_##api_name##_symmetric_##input_type_name##_##isa_suffix##_offdiagonal##_( \
|
|
2429
|
+
vector_base_ptrs_i, vector_base_ptrs_j, i_macro, j_macro, macro_i_size, macro_j_size, \
|
|
2430
|
+
aligned_depth, remainder_depth, remainder_dimensions, depth_step_values, \
|
|
2431
|
+
dimensions_per_value, result, result_stride_values, finalizer_batch_size, depth); \
|
|
2432
|
+
} \
|
|
2433
|
+
} \
|
|
2434
|
+
} \
|
|
2435
|
+
} \
|
|
2436
|
+
} \
|
|
2437
|
+
}
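
/* Illustrative sketch of the traversal above, stripped of the macro plumbing and
 * per-tile machinery. The function name and the scalar inner loop are hypothetical
 * and not part of this header; the real kernels dispatch 32×32 macro-tiles to the
 * diagonal and off-diagonal helpers, and honor row_start/row_count partitioning. */
static inline void nk_symmetric_blocking_sketch_(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                 nk_size_t stride_values, nk_f32_t *result,
                                                 nk_size_t result_stride_values) {
    nk_size_t const tile = 32, row_block = 128, column_block = 2048;
    for (nk_size_t j_block = 0; j_block < n_vectors; j_block += column_block) {
        nk_size_t j_block_end = j_block + column_block < n_vectors ? j_block + column_block : n_vectors;
        for (nk_size_t i_block = 0; i_block < n_vectors; i_block += row_block) {
            nk_size_t i_block_end = i_block + row_block < n_vectors ? i_block + row_block : n_vectors;
            if (i_block >= j_block_end) continue; /* entirely below the diagonal */
            for (nk_size_t i_macro = i_block; i_macro < i_block_end; i_macro += tile) {
                for (nk_size_t j_macro = (i_macro > j_block ? i_macro : j_block); j_macro < j_block_end;
                     j_macro += tile) {
                    nk_size_t i_end = i_macro + tile < i_block_end ? i_macro + tile : i_block_end;
                    nk_size_t j_end = j_macro + tile < j_block_end ? j_macro + tile : j_block_end;
                    for (nk_size_t i = i_macro; i < i_end; i++)
                        for (nk_size_t j = (i > j_macro ? i : j_macro); j < j_end; j++) { /* upper triangle */
                            nk_f32_t sum = 0;
                            for (nk_size_t k = 0; k < depth; k++)
                                sum += vectors[i * stride_values + k] * vectors[j * stride_values + k];
                            result[i * result_stride_values + j] = sum;
                        }
                }
            }
        }
    }
}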

/* Optimize serial GEMM instantiations for size rather than speed.
 * These fallback kernels are only used when no SIMD backend is available, so aggressive inlining/unrolling from -O3
 * wastes ~1.3 MB of binary space with negligible performance benefit on the serial path. Sadly, a scoped application
 * of `__attribute__((optimize("Os")))` isn't supported on Clang, so this flag only applies to GCC builds.
 */
#if defined(NDEBUG)
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC push_options
#pragma GCC optimize("Os")
#endif
#endif
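
/* A hedged sketch of a Clang-side alternative: Clang does support a per-function
 * `minsize` attribute, so a wrapper macro along these lines could extend the size
 * optimization to Clang builds. `NK_OPTIMIZE_SIZE_` is a hypothetical name and is
 * not used by the instantiations below. */
#if defined(__clang__)
#define NK_OPTIMIZE_SIZE_ __attribute__((minsize))
#elif defined(__GNUC__)
#define NK_OPTIMIZE_SIZE_ __attribute__((optimize("Os")))
#else
#define NK_OPTIMIZE_SIZE_
#endif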

/* F64 GEMM: depth_simd_dimensions=2 (2 f64s = 16 bytes) */
nk_define_cross_pack_size_(dots, f64, serial, f64, f64, /*norm_value_type=*/f64, /*depth_simd_dimensions=*/2,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, f64, serial, f64, f64, nk_assign_from_to_, /*norm_value_type=*/f64,
                      nk_dots_reduce_sumsq_f64_,
                      /*depth_simd_dimensions=*/2, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, f64, serial, f64, f64, nk_b128_vec_t, nk_dot_f64x2_state_serial_t, nk_b256_vec_t,
                           nk_dot_f64x2_init_serial, nk_load_b128_serial_, nk_partial_load_b64x2_serial_,
                           nk_dot_f64x2_update_serial, nk_dot_f64x2_finalize_serial, nk_store_b256_serial_,
                           nk_partial_store_b64x4_serial_,
                           /*depth_simd_dimensions=*/2, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, f64, serial, f64, f64, f64, nk_b128_vec_t, nk_dot_f64x2_state_serial_t, nk_b256_vec_t,
                        nk_dot_f64x2_init_serial, nk_load_b128_serial_, nk_partial_load_b64x2_serial_,
                        nk_load_b128_serial_, nk_partial_load_b64x2_serial_, nk_dot_f64x2_update_serial,
                        nk_dot_f64x2_finalize_serial, nk_store_b256_serial_, nk_partial_store_b64x4_serial_,
                        /*depth_simd_dimensions=*/2, /*dimensions_per_value=*/1)

/* F32 GEMM: depth_simd_dimensions=4 (4 f32s = 16 bytes) */
nk_define_cross_pack_size_(dots, f32, serial, f32, f32, /*norm_value_type=*/f64, /*depth_simd_dimensions=*/4,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, f32, serial, f32, f32, nk_assign_from_to_, /*norm_value_type=*/f64,
                      nk_dots_reduce_sumsq_f32_,
                      /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, f32, serial, f32, f64, nk_b128_vec_t, nk_dot_f32x4_state_serial_t, nk_b256_vec_t,
                           nk_dot_f32x4_init_serial, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
                           nk_dot_f32x4_update_serial, nk_dot_f32x4_finalize_serial, nk_store_b256_serial_,
                           nk_partial_store_b64x4_serial_,
                           /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, f32, serial, f32, f32, f64, nk_b128_vec_t, nk_dot_f32x4_state_serial_t, nk_b256_vec_t,
                        nk_dot_f32x4_init_serial, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
                        nk_load_b128_serial_, nk_partial_load_b32x4_serial_, nk_dot_f32x4_update_serial,
                        nk_dot_f32x4_finalize_serial, nk_store_b256_serial_, nk_partial_store_b64x4_serial_,
                        /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)

/* F16 GEMM: depth_simd_dimensions=8 (8 f16s = 16 bytes), F32 accumulator */
nk_define_cross_pack_size_(dots, f16, serial, f16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, f16, serial, f16, f16, nk_assign_from_to_, /*norm_value_type=*/f32,
                      nk_dots_reduce_sumsq_f16_,
                      /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, f16, serial, f16, f32, nk_b128_vec_t, nk_dot_f16x8_state_serial_t, nk_b128_vec_t,
                           nk_dot_f16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
                           nk_dot_f16x8_update_serial, nk_dot_f16x8_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, f16, serial, f16, f16, f32, nk_b128_vec_t, nk_dot_f16x8_state_serial_t, nk_b128_vec_t,
                        nk_dot_f16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
                        nk_load_b128_serial_, nk_partial_load_b16x8_serial_, nk_dot_f16x8_update_serial,
                        nk_dot_f16x8_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)

/* BF16 GEMM: depth_simd_dimensions=8 (8 bf16s = 16 bytes), F32 accumulator */
nk_define_cross_pack_size_(dots, bf16, serial, bf16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, bf16, serial, bf16, bf16, nk_assign_from_to_, /*norm_value_type=*/f32,
                      nk_dots_reduce_sumsq_bf16_,
                      /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, bf16, serial, bf16, f32, nk_b128_vec_t, nk_dot_bf16x8_state_serial_t, nk_b128_vec_t,
                           nk_dot_bf16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
                           nk_dot_bf16x8_update_serial, nk_dot_bf16x8_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, bf16, serial, bf16, bf16, f32, nk_b128_vec_t, nk_dot_bf16x8_state_serial_t, nk_b128_vec_t,
                        nk_dot_bf16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
                        nk_load_b128_serial_, nk_partial_load_b16x8_serial_, nk_dot_bf16x8_update_serial,
                        nk_dot_bf16x8_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)

/* I8 GEMM: depth_simd_dimensions=16 (16 i8s = 16 bytes), I32 accumulator */
nk_define_cross_pack_size_(dots, i8, serial, i8, i8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, i8, serial, i8, i8, nk_assign_from_to_, /*norm_value_type=*/u32, nk_dots_reduce_sumsq_i8_,
                      /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, i8, serial, i8, i32, nk_b128_vec_t, nk_dot_i8x16_state_serial_t, nk_b128_vec_t,
                           nk_dot_i8x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                           nk_dot_i8x16_update_serial, nk_dot_i8x16_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, i8, serial, i8, i8, i32, nk_b128_vec_t, nk_dot_i8x16_state_serial_t, nk_b128_vec_t,
                        nk_dot_i8x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                        nk_load_b128_serial_, nk_partial_load_b8x16_serial_, nk_dot_i8x16_update_serial,
                        nk_dot_i8x16_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)

/* U8 GEMM: depth_simd_dimensions=16 (16 u8s = 16 bytes), U32 accumulator */
nk_define_cross_pack_size_(dots, u8, serial, u8, u8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, u8, serial, u8, u8, nk_assign_from_to_, /*norm_value_type=*/u32, nk_dots_reduce_sumsq_u8_,
                      /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, u8, serial, u8, u32, nk_b128_vec_t, nk_dot_u8x16_state_serial_t, nk_b128_vec_t,
                           nk_dot_u8x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                           nk_dot_u8x16_update_serial, nk_dot_u8x16_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, u8, serial, u8, u8, u32, nk_b128_vec_t, nk_dot_u8x16_state_serial_t, nk_b128_vec_t,
                        nk_dot_u8x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                        nk_load_b128_serial_, nk_partial_load_b8x16_serial_, nk_dot_u8x16_update_serial,
                        nk_dot_u8x16_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)

/* E4M3 GEMM: depth_simd_dimensions=16 (16 e4m3s = 16 bytes), F32 accumulator */
nk_define_cross_pack_size_(dots, e4m3, serial, e4m3, e4m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, e4m3, serial, e4m3, e4m3, nk_assign_from_to_, /*norm_value_type=*/f32,
                      nk_dots_reduce_sumsq_e4m3_,
                      /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, e4m3, serial, e4m3, f32, nk_b128_vec_t, nk_dot_e4m3x16_state_serial_t, nk_b128_vec_t,
                           nk_dot_e4m3x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                           nk_dot_e4m3x16_update_serial, nk_dot_e4m3x16_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, e4m3, serial, e4m3, e4m3, f32, nk_b128_vec_t, nk_dot_e4m3x16_state_serial_t,
                        nk_b128_vec_t, nk_dot_e4m3x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                        nk_load_b128_serial_, nk_partial_load_b8x16_serial_, nk_dot_e4m3x16_update_serial,
                        nk_dot_e4m3x16_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)

/* E5M2 GEMM: depth_simd_dimensions=16 (16 e5m2s = 16 bytes), F32 accumulator */
nk_define_cross_pack_size_(dots, e5m2, serial, e5m2, e5m2, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, e5m2, serial, e5m2, e5m2, nk_assign_from_to_, /*norm_value_type=*/f32,
                      nk_dots_reduce_sumsq_e5m2_,
                      /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, e5m2, serial, e5m2, f32, nk_b128_vec_t, nk_dot_e5m2x16_state_serial_t, nk_b128_vec_t,
                           nk_dot_e5m2x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                           nk_dot_e5m2x16_update_serial, nk_dot_e5m2x16_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, e5m2, serial, e5m2, e5m2, f32, nk_b128_vec_t, nk_dot_e5m2x16_state_serial_t,
                        nk_b128_vec_t, nk_dot_e5m2x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                        nk_load_b128_serial_, nk_partial_load_b8x16_serial_, nk_dot_e5m2x16_update_serial,
                        nk_dot_e5m2x16_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)

/* E2M3 GEMM: depth_simd_dimensions=16 (16 e2m3s = 16 bytes), F32 accumulator */
nk_define_cross_pack_size_(dots, e2m3, serial, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, e2m3, serial, e2m3, e2m3, nk_assign_from_to_, /*norm_value_type=*/f32,
                      nk_dots_reduce_sumsq_e2m3_,
                      /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, e2m3, serial, e2m3, f32, nk_b128_vec_t, nk_dot_e2m3x16_state_serial_t, nk_b128_vec_t,
                           nk_dot_e2m3x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                           nk_dot_e2m3x16_update_serial, nk_dot_e2m3x16_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, e2m3, serial, e2m3, e2m3, f32, nk_b128_vec_t, nk_dot_e2m3x16_state_serial_t,
                        nk_b128_vec_t, nk_dot_e2m3x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                        nk_load_b128_serial_, nk_partial_load_b8x16_serial_, nk_dot_e2m3x16_update_serial,
                        nk_dot_e2m3x16_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)

/* E3M2 GEMM: depth_simd_dimensions=16 (16 e3m2s = 16 bytes), F32 accumulator */
nk_define_cross_pack_size_(dots, e3m2, serial, e3m2, e3m2, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
                           /*dimensions_per_value=*/1)
nk_define_cross_pack_(dots, e3m2, serial, e3m2, e3m2, nk_assign_from_to_, /*norm_value_type=*/f32,
                      nk_dots_reduce_sumsq_e3m2_,
                      /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_symmetric_(dots, e3m2, serial, e3m2, f32, nk_b128_vec_t, nk_dot_e3m2x16_state_serial_t, nk_b128_vec_t,
                           nk_dot_e3m2x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                           nk_dot_e3m2x16_update_serial, nk_dot_e3m2x16_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
nk_define_cross_packed_(dots, e3m2, serial, e3m2, e3m2, f32, nk_b128_vec_t, nk_dot_e3m2x16_state_serial_t,
                        nk_b128_vec_t, nk_dot_e3m2x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
                        nk_load_b128_serial_, nk_partial_load_b8x16_serial_, nk_dot_e3m2x16_update_serial,
                        nk_dot_e3m2x16_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)

/* U4 GEMM: u4x2 for both A and B */
nk_define_cross_pack_size_(dots, u4, serial, u4x2, u4x2, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
                           /*dimensions_per_value=*/2)
nk_define_cross_pack_(dots, u4, serial, u4x2, u4x2, nk_assign_from_to_, /*norm_value_type=*/u32,
                      nk_dots_reduce_sumsq_u4_,
                      /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/2)
nk_define_cross_symmetric_(dots, u4, serial, u4x2, u32, nk_b64_vec_t, nk_dot_u4x16_state_serial_t, nk_b128_vec_t,
                           nk_dot_u4x16_init_serial, nk_load_b64_serial_, nk_partial_load_b4x16_serial_,
                           nk_dot_u4x16_update_serial, nk_dot_u4x16_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/2)
nk_define_cross_packed_(dots, u4, serial, u4x2, u4x2, u32, nk_b64_vec_t, nk_dot_u4x16_state_serial_t, nk_b128_vec_t,
                        nk_dot_u4x16_init_serial, nk_load_b64_serial_, nk_partial_load_b4x16_serial_,
                        nk_load_b64_serial_, nk_partial_load_b4x16_serial_, nk_dot_u4x16_update_serial,
                        nk_dot_u4x16_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/2)

/* I4 GEMM: i4x2 for both A and B */
nk_define_cross_pack_size_(dots, i4, serial, i4x2, i4x2, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
                           /*dimensions_per_value=*/2)
nk_define_cross_pack_(dots, i4, serial, i4x2, i4x2, nk_assign_from_to_, /*norm_value_type=*/u32,
                      nk_dots_reduce_sumsq_i4_,
                      /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/2)
nk_define_cross_symmetric_(dots, i4, serial, i4x2, i32, nk_b64_vec_t, nk_dot_i4x16_state_serial_t, nk_b128_vec_t,
                           nk_dot_i4x16_init_serial, nk_load_b64_serial_, nk_partial_load_b4x16_serial_,
                           nk_dot_i4x16_update_serial, nk_dot_i4x16_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/2)
nk_define_cross_packed_(dots, i4, serial, i4x2, i4x2, i32, nk_b64_vec_t, nk_dot_i4x16_state_serial_t, nk_b128_vec_t,
                        nk_dot_i4x16_init_serial, nk_load_b64_serial_, nk_partial_load_b4x16_serial_,
                        nk_load_b64_serial_, nk_partial_load_b4x16_serial_, nk_dot_i4x16_update_serial,
                        nk_dot_i4x16_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/16, /*dimensions_per_value=*/2)

/* U1 GEMM: u1x8 for both A and B */
nk_define_cross_pack_size_(dots, u1, serial, u1x8, u1x8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/128,
                           /*dimensions_per_value=*/8)
nk_define_cross_pack_(dots, u1, serial, u1x8, u1x8, nk_assign_from_to_, /*norm_value_type=*/u32, nk_dots_reduce_sum_u1_,
                      /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/8)
nk_define_cross_symmetric_(dots, u1, serial, u1x8, u32, nk_b128_vec_t, nk_dot_u1x128_state_serial_t, nk_b128_vec_t,
                           nk_dot_u1x128_init_serial, nk_load_b128_serial_, nk_partial_load_b1x128_serial_,
                           nk_dot_u1x128_update_serial, nk_dot_u1x128_finalize_serial, nk_store_b128_serial_,
                           nk_partial_store_b32x4_serial_,
                           /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/8)
nk_define_cross_packed_(dots, u1, serial, u1x8, u1x8, u32, nk_b128_vec_t, nk_dot_u1x128_state_serial_t, nk_b128_vec_t,
                        nk_dot_u1x128_init_serial, nk_load_b128_serial_, nk_partial_load_b1x128_serial_,
                        nk_load_b128_serial_, nk_partial_load_b1x128_serial_, nk_dot_u1x128_update_serial,
                        nk_dot_u1x128_finalize_serial, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
                        /*depth_simd_dimensions=*/128, /*dimensions_per_value=*/8)
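
/* For the sub-byte instantiations above, `depth` counts logical dimensions while
 * loads advance in packed values: a u4x2 byte carries 2 dimensions and a u1x8 byte
 * carries 8. A worked example of the symmetric kernel's arithmetic for u4 with
 * depth = 37 and depth_simd_dimensions = 16:
 *   depth_dimensions_aligned = (37 / 16) * 16 = 32  (dimensions covered by full SIMD steps)
 *   aligned_depth            = ceil(32 / 2)   = 16  (packed values consumed by the SIMD loop)
 *   depth_in_values          = ceil(37 / 2)   = 19  (total packed values per vector)
 *   remainder_depth          = 19 - 16        = 3   (packed values left for the partial load)
 *   remainder_dimensions     = 37 - 32        = 5   (dimensions handled by the partial step)
 */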

#if defined(NDEBUG)
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC pop_options
#endif
#endif

/* BF16 compact: truncate F32 → BF16 in-place.
 * Reads F32 matrix with c_stride_in_bytes, writes BF16 tightly packed (stride = column_count × sizeof(bf16)).
 */
NK_PUBLIC void nk_dots_compact_bf16_serial(void *c, nk_size_t row_count, nk_size_t column_count,
                                           nk_size_t c_stride_in_bytes) {
    nk_size_t const c_stride_in_values = c_stride_in_bytes / sizeof(nk_f32_t);
    nk_f32_t const *c_f32 = (nk_f32_t const *)c;
    nk_bf16_t *c_bf16 = (nk_bf16_t *)c;

    for (nk_size_t row_index = 0; row_index < row_count; row_index++) {
        nk_f32_t const *source_row = c_f32 + row_index * c_stride_in_values;
        nk_bf16_t *destination_row = c_bf16 + row_index * column_count;
        for (nk_size_t column_index = 0; column_index < column_count; column_index++) {
            nk_f32_to_bf16_serial(source_row + column_index, destination_row + column_index);
        }
    }
}
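
/* A minimal usage sketch with illustrative names: run any of the `dots` kernels
 * above into an F32 buffer with a widened row stride, then narrow it in place.
 * Forward iteration is safe here because each BF16 destination never lands past
 * its F32 source. */
static inline void nk_dots_compact_bf16_example_(nk_f32_t *c_buffer, nk_size_t rows, nk_size_t cols,
                                                 nk_size_t c_stride_in_bytes) {
    /* ... assume c_buffer holds rows × cols F32 dot products with row stride c_stride_in_bytes ... */
    nk_dots_compact_bf16_serial(c_buffer, rows, cols, c_stride_in_bytes);
    /* The same memory now holds BF16 values, tightly packed with a row stride
     * of cols * sizeof(nk_bf16_t) bytes. */
}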

/* I8 compact: re-normalize I32 → I8 using precomputed squared norms.
 * Formula: c_i8[i][j] = c_i32[i][j] × 127 / sqrt(a_norm[i] × b_norm[j])
 * Output is tightly packed (stride_in_bytes = column_count × sizeof(i8)).
 */
NK_PUBLIC void nk_dots_compact_i8_serial(void *c, nk_size_t row_count, nk_size_t column_count,
                                         nk_size_t c_stride_in_bytes, nk_i32_t const *a_squared_norms,
                                         nk_i32_t const *b_squared_norms) {
    nk_size_t const c_stride_in_values = c_stride_in_bytes / sizeof(nk_i32_t);
    nk_i32_t const *c_i32 = (nk_i32_t const *)c;
    nk_i8_t *c_i8 = (nk_i8_t *)c;

    for (nk_size_t row_index = 0; row_index < row_count; row_index++) {
        nk_i32_t const *source_row = c_i32 + row_index * c_stride_in_values;
        nk_i8_t *destination_row = c_i8 + row_index * column_count;

        nk_f32_t a_norm_f32_value = (nk_f32_t)a_squared_norms[row_index];
        nk_f32_t a_rsqrt_value = (a_norm_f32_value > 0) ? (1.0f / nk_f32_sqrt_serial(a_norm_f32_value)) : 0.0f;

        for (nk_size_t column_index = 0; column_index < column_count; column_index++) {
            nk_f32_t b_norm_f32_value = (nk_f32_t)b_squared_norms[column_index];
            nk_f32_t b_rsqrt_value = (b_norm_f32_value > 0) ? (1.0f / nk_f32_sqrt_serial(b_norm_f32_value)) : 0.0f;

            nk_f32_t normalized_value = (nk_f32_t)source_row[column_index] * 127.0f * a_rsqrt_value * b_rsqrt_value;
            nk_i32_t clamped_value = (nk_i32_t)normalized_value;
            if (clamped_value < -128) clamped_value = -128;
            if (clamped_value > 127) clamped_value = 127;
            destination_row[column_index] = (nk_i8_t)clamped_value;
        }
    }
}
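
/* Worked example of the renormalization above: with c_i32[i][j] = 100,
 * a_squared_norms[i] = 25, and b_squared_norms[j] = 16, the scaled value is
 * 100 × 127 / sqrt(25 × 16) = 100 × 127 / 20 = 635, which saturates to 127.
 * A pair with perfect cosine similarity (c_i32 = 20 for those same norms)
 * maps to 20 × 127 / 20 = 127 exactly. */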

#define nk_define_cross_normalized_packed_(metric_name, input_type_name, isa_suffix, input_value_type, \
                                           packed_value_type, dot_result_type, norm_value_type, final_result_type, \
                                           vec_type, dots_packed_fn, from_dot_fn, compute_norm_fn, load_fn, \
                                           partial_load_fn, store_fn, partial_store_fn, dimensions_per_value) \
    NK_PUBLIC void nk_##metric_name##s_packed_##input_type_name##_##isa_suffix( \
        nk_##input_value_type##_t const *a_matrix, void const *b_packed_buffer, nk_##final_result_type##_t *c_matrix, \
        nk_size_t row_count, nk_size_t column_count, nk_size_t depth, nk_size_t a_stride_in_bytes, \
        nk_size_t c_stride_in_bytes) { \
        \
        dots_packed_fn(a_matrix, b_packed_buffer, (nk_##dot_result_type##_t *)c_matrix, row_count, column_count, \
                       depth, a_stride_in_bytes, c_stride_in_bytes); \
        \
        nk_cross_packed_buffer_header_t const *header = (nk_cross_packed_buffer_header_t const *)b_packed_buffer; \
        nk_size_t depth_padded = header->depth_padded_values; \
        nk_##norm_value_type##_t const *b_norms = \
            (nk_##norm_value_type##_t const *)((char const *)b_packed_buffer + \
                                               sizeof(nk_cross_packed_buffer_header_t) + \
                                               column_count * depth_padded * sizeof(nk_##packed_value_type##_t)); \
        \
        for (nk_size_t row_index = 0; row_index < row_count; ++row_index) { \
            nk_##input_value_type##_t const *a_row = \
                (nk_##input_value_type##_t const *)((char const *)a_matrix + row_index * a_stride_in_bytes); \
            nk_##dot_result_type##_t query_norm = compute_norm_fn(a_row, depth); \
            nk_##dot_result_type##_t *r_row_dots = (nk_##dot_result_type##_t *)((char *)c_matrix + \
                                                                                row_index * c_stride_in_bytes); \
            nk_##final_result_type##_t *r_row_out = (nk_##final_result_type##_t *)((char *)c_matrix + \
                                                                                   row_index * c_stride_in_bytes); \
            \
            nk_size_t column_index = 0; \
            for (; column_index + 4 <= column_count; column_index += 4) { \
                vec_type dots_vec, norms_vec, results_vec; \
                load_fn(r_row_dots + column_index, &dots_vec); \
                load_fn(b_norms + column_index, &norms_vec); \
                from_dot_fn(dots_vec, query_norm, norms_vec, &results_vec); \
                store_fn(&results_vec, r_row_out + column_index); \
            } \
            if (column_index < column_count) { \
                vec_type dots_vec = {0}, norms_vec = {0}, results_vec; \
                partial_load_fn(r_row_dots + column_index, &dots_vec, column_count - column_index); \
                partial_load_fn(b_norms + column_index, &norms_vec, column_count - column_index); \
                from_dot_fn(dots_vec, query_norm, norms_vec, &results_vec); \
                partial_store_fn(&results_vec, r_row_out + column_index, column_count - column_index); \
            } \
        } \
    }
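
/* The b_norms arithmetic above implies the following layout for b_packed_buffer
 * (an inference from the pointer math, with depth_padded read from the header):
 *
 *   [ nk_cross_packed_buffer_header_t                                           ]
 *   [ packed B values: column_count × depth_padded × sizeof(packed_value_type)  ]
 *   [ B column norms:  column_count × sizeof(norm_value_type)                   ]
 */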

#define nk_define_cross_normalized_symmetric_(metric_name, input_type_name, isa_suffix, input_value_type, \
                                              dot_result_type, norm_value_type, final_result_type, vec_type, \
                                              dots_symmetric_fn, from_dot_fn, compute_norm_fn, load_fn, \
                                              partial_load_fn, store_fn, partial_store_fn, dimensions_per_value) \
    NK_PUBLIC void nk_##metric_name##s_symmetric_##input_type_name##_##isa_suffix( \
        nk_##input_value_type##_t const *vectors, nk_size_t n_vectors, nk_size_t depth, nk_size_t stride, \
        nk_##final_result_type##_t *result, nk_size_t result_stride, nk_size_t row_start, nk_size_t row_count) { \
        \
        dots_symmetric_fn(vectors, n_vectors, depth, stride, (nk_##dot_result_type##_t *)result, result_stride, \
                          row_start, row_count); \
        \
        /* Phase 1 — cache row norms in the result diagonal (O(row_count) calls) */ \
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
            nk_##input_value_type##_t const *row_vector = (nk_##input_value_type##_t const *)((char const *)vectors + \
                                                                                              row_index * stride); \
            nk_##norm_value_type##_t *row_diag = (nk_##norm_value_type##_t *)((char *)result + \
                                                                              row_index * result_stride); \
            row_diag[row_index] = compute_norm_fn(row_vector, depth); \
        } \
        \
        /* Phase 2 — column-first post-processing with 256-element norm cache */ \
        nk_##norm_value_type##_t column_norms[256]; \
        for (nk_size_t column_chunk_start = 0; column_chunk_start < n_vectors; column_chunk_start += 256) { \
            nk_size_t column_chunk_end = column_chunk_start + 256 < n_vectors ? column_chunk_start + 256 : n_vectors; \
            \
            /* Pre-compute norms for this column chunk — each column visited exactly once */ \
            for (nk_size_t col = column_chunk_start; col < column_chunk_end; ++col) { \
                nk_##input_value_type##_t const *column_vector = \
                    (nk_##input_value_type##_t const *)((char const *)vectors + col * stride); \
                column_norms[col - column_chunk_start] = compute_norm_fn(column_vector, depth); \
            } \
            \
            /* Sweep assigned rows against this column chunk */ \
            for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
                nk_size_t j_start = row_index + 1 > column_chunk_start ? row_index + 1 : column_chunk_start; \
                if (j_start >= column_chunk_end) continue; \
                char *row_ptr = (char *)result + row_index * result_stride; \
                nk_##norm_value_type##_t sumsq_i = ((nk_##norm_value_type##_t *)row_ptr)[row_index]; \
                nk_##dot_result_type##_t *r_dots = (nk_##dot_result_type##_t *)row_ptr; \
                nk_##final_result_type##_t *r_out = (nk_##final_result_type##_t *)row_ptr; \
                \
                /* 4-wide vectorized loop */ \
                nk_size_t j = j_start; \
                for (; j + 4 <= column_chunk_end; j += 4) { \
                    vec_type target_norms_vec; \
                    load_fn(&column_norms[j - column_chunk_start], &target_norms_vec); \
                    vec_type dots_vec, results_vec; \
                    load_fn(r_dots + j, &dots_vec); \
                    from_dot_fn(dots_vec, sumsq_i, target_norms_vec, &results_vec); \
                    store_fn(&results_vec, r_out + j); \
                } \
                /* Remainder */ \
                if (j < column_chunk_end) { \
                    vec_type dots_vec = {0}, norms_vec = {0}, results_vec; \
                    partial_load_fn(r_dots + j, &dots_vec, column_chunk_end - j); \
                    partial_load_fn(&column_norms[j - column_chunk_start], &norms_vec, column_chunk_end - j); \
                    from_dot_fn(dots_vec, sumsq_i, norms_vec, &results_vec); \
                    partial_store_fn(&results_vec, r_out + j, column_chunk_end - j); \
                } \
            } \
        } \
        \
        /* Phase 3 — zero diagonals */ \
        for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
            nk_##final_result_type##_t *r_out = (nk_##final_result_type##_t *)((char *)result + \
                                                                               row_index * result_stride); \
            r_out[row_index] = 0; \
        } \
    }
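
/* Scalar sketch of the three phases above for an f32 metric, with a plain
 * cosine-style normalization standing in for `from_dot_fn`. All names here are
 * illustrative: the real macro keeps a 256-wide column-norm cache, honors
 * row_start/row_count, and goes through the vec_type load/store hooks instead
 * of scalar code. Assumes the strict upper triangle of `result` already holds
 * raw dot products. */
static inline void nk_normalized_symmetric_sketch_(nk_f32_t const *vectors, nk_size_t n_vectors, nk_size_t depth,
                                                   nk_size_t stride_values, nk_f32_t *result,
                                                   nk_size_t result_stride_values) {
    /* Phase 1: stash each row's squared norm on the result diagonal. */
    for (nk_size_t i = 0; i < n_vectors; i++) {
        nk_f32_t sumsq = 0;
        for (nk_size_t k = 0; k < depth; k++) {
            nk_f32_t v = vectors[i * stride_values + k];
            sumsq += v * v;
        }
        result[i * result_stride_values + i] = sumsq;
    }
    /* Phase 2: normalize the strict upper triangle in place; diagonals still
     * hold norms at this point, so both operands remain available. */
    for (nk_size_t i = 0; i < n_vectors; i++)
        for (nk_size_t j = i + 1; j < n_vectors; j++) {
            nk_f32_t denominator = nk_f32_sqrt_serial(result[i * result_stride_values + i]) *
                                   nk_f32_sqrt_serial(result[j * result_stride_values + j]);
            result[i * result_stride_values + j] =
                denominator > 0 ? result[i * result_stride_values + j] / denominator : 0;
        }
    /* Phase 3: the diagonal no longer holds a metric value; zero it out. */
    for (nk_size_t i = 0; i < n_vectors; i++) result[i * result_stride_values + i] = 0;
}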

#if defined(__cplusplus)
} // extern "C"
#endif

#endif // NK_DOTS_SERIAL_H