numkong 7.0.0 → 7.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +197 -124
- package/binding.gyp +34 -484
- package/c/dispatch_bf16.c +59 -1
- package/c/dispatch_e2m3.c +41 -8
- package/c/dispatch_e3m2.c +49 -8
- package/c/dispatch_e4m3.c +51 -9
- package/c/dispatch_e5m2.c +45 -1
- package/c/dispatch_f16.c +79 -26
- package/c/dispatch_f16c.c +5 -5
- package/c/dispatch_f32.c +56 -0
- package/c/dispatch_f64.c +52 -0
- package/c/dispatch_i4.c +3 -0
- package/c/dispatch_i8.c +62 -3
- package/c/dispatch_other.c +18 -0
- package/c/dispatch_u1.c +54 -9
- package/c/dispatch_u4.c +3 -0
- package/c/dispatch_u8.c +64 -3
- package/c/numkong.c +3 -0
- package/include/README.md +79 -9
- package/include/numkong/attention/sapphireamx.h +278 -276
- package/include/numkong/attention/sme.h +983 -977
- package/include/numkong/attention.h +1 -1
- package/include/numkong/capabilities.h +289 -94
- package/include/numkong/cast/README.md +40 -40
- package/include/numkong/cast/diamond.h +64 -0
- package/include/numkong/cast/haswell.h +42 -194
- package/include/numkong/cast/icelake.h +42 -37
- package/include/numkong/cast/loongsonasx.h +252 -0
- package/include/numkong/cast/neon.h +216 -249
- package/include/numkong/cast/powervsx.h +449 -0
- package/include/numkong/cast/rvv.h +223 -274
- package/include/numkong/cast/sapphire.h +18 -18
- package/include/numkong/cast/serial.h +1018 -944
- package/include/numkong/cast/skylake.h +82 -23
- package/include/numkong/cast/v128relaxed.h +462 -105
- package/include/numkong/cast.h +24 -0
- package/include/numkong/cast.hpp +44 -0
- package/include/numkong/curved/README.md +17 -17
- package/include/numkong/curved/neon.h +131 -7
- package/include/numkong/curved/neonbfdot.h +6 -7
- package/include/numkong/curved/rvv.h +26 -26
- package/include/numkong/curved/smef64.h +186 -182
- package/include/numkong/curved.h +14 -18
- package/include/numkong/dot/README.md +154 -137
- package/include/numkong/dot/alder.h +43 -43
- package/include/numkong/dot/diamond.h +158 -0
- package/include/numkong/dot/genoa.h +4 -30
- package/include/numkong/dot/haswell.h +215 -180
- package/include/numkong/dot/icelake.h +190 -76
- package/include/numkong/dot/loongsonasx.h +671 -0
- package/include/numkong/dot/neon.h +124 -73
- package/include/numkong/dot/neonbfdot.h +11 -12
- package/include/numkong/dot/neonfhm.h +44 -46
- package/include/numkong/dot/neonfp8.h +323 -0
- package/include/numkong/dot/neonsdot.h +190 -76
- package/include/numkong/dot/powervsx.h +752 -0
- package/include/numkong/dot/rvv.h +92 -84
- package/include/numkong/dot/rvvbf16.h +12 -12
- package/include/numkong/dot/rvvhalf.h +12 -12
- package/include/numkong/dot/sapphire.h +4 -4
- package/include/numkong/dot/serial.h +66 -30
- package/include/numkong/dot/sierra.h +31 -31
- package/include/numkong/dot/skylake.h +142 -110
- package/include/numkong/dot/sve.h +217 -177
- package/include/numkong/dot/svebfdot.h +10 -10
- package/include/numkong/dot/svehalf.h +85 -41
- package/include/numkong/dot/svesdot.h +89 -0
- package/include/numkong/dot/v128relaxed.h +124 -89
- package/include/numkong/dot.h +114 -48
- package/include/numkong/dots/README.md +203 -203
- package/include/numkong/dots/alder.h +12 -9
- package/include/numkong/dots/diamond.h +86 -0
- package/include/numkong/dots/genoa.h +10 -4
- package/include/numkong/dots/haswell.h +63 -48
- package/include/numkong/dots/icelake.h +27 -18
- package/include/numkong/dots/loongsonasx.h +176 -0
- package/include/numkong/dots/neon.h +14 -11
- package/include/numkong/dots/neonbfdot.h +4 -3
- package/include/numkong/dots/neonfhm.h +11 -9
- package/include/numkong/dots/neonfp8.h +99 -0
- package/include/numkong/dots/neonsdot.h +48 -12
- package/include/numkong/dots/powervsx.h +194 -0
- package/include/numkong/dots/rvv.h +451 -344
- package/include/numkong/dots/sapphireamx.h +1028 -984
- package/include/numkong/dots/serial.h +213 -197
- package/include/numkong/dots/sierra.h +10 -7
- package/include/numkong/dots/skylake.h +47 -36
- package/include/numkong/dots/sme.h +2001 -2364
- package/include/numkong/dots/smebi32.h +175 -162
- package/include/numkong/dots/smef64.h +328 -323
- package/include/numkong/dots/v128relaxed.h +64 -41
- package/include/numkong/dots.h +573 -293
- package/include/numkong/dots.hpp +45 -43
- package/include/numkong/each/README.md +133 -137
- package/include/numkong/each/haswell.h +6 -6
- package/include/numkong/each/icelake.h +7 -7
- package/include/numkong/each/neon.h +76 -42
- package/include/numkong/each/neonbfdot.h +11 -12
- package/include/numkong/each/neonhalf.h +24 -116
- package/include/numkong/each/rvv.h +28 -28
- package/include/numkong/each/sapphire.h +27 -161
- package/include/numkong/each/serial.h +6 -6
- package/include/numkong/each/skylake.h +7 -7
- package/include/numkong/each/v128relaxed.h +562 -0
- package/include/numkong/each.h +148 -62
- package/include/numkong/each.hpp +2 -2
- package/include/numkong/geospatial/README.md +18 -18
- package/include/numkong/geospatial/haswell.h +365 -325
- package/include/numkong/geospatial/neon.h +350 -306
- package/include/numkong/geospatial/rvv.h +4 -4
- package/include/numkong/geospatial/skylake.h +376 -340
- package/include/numkong/geospatial/v128relaxed.h +366 -327
- package/include/numkong/geospatial.h +17 -17
- package/include/numkong/matrix.hpp +4 -4
- package/include/numkong/maxsim/README.md +14 -14
- package/include/numkong/maxsim/alder.h +6 -6
- package/include/numkong/maxsim/genoa.h +4 -4
- package/include/numkong/maxsim/haswell.h +6 -6
- package/include/numkong/maxsim/icelake.h +18 -18
- package/include/numkong/maxsim/neonsdot.h +21 -21
- package/include/numkong/maxsim/sapphireamx.h +14 -14
- package/include/numkong/maxsim/serial.h +6 -6
- package/include/numkong/maxsim/sme.h +221 -196
- package/include/numkong/maxsim/v128relaxed.h +6 -6
- package/include/numkong/mesh/README.md +62 -56
- package/include/numkong/mesh/haswell.h +339 -464
- package/include/numkong/mesh/neon.h +1100 -519
- package/include/numkong/mesh/neonbfdot.h +36 -68
- package/include/numkong/mesh/rvv.h +530 -435
- package/include/numkong/mesh/serial.h +75 -91
- package/include/numkong/mesh/skylake.h +1627 -302
- package/include/numkong/mesh/v128relaxed.h +443 -330
- package/include/numkong/mesh.h +63 -49
- package/include/numkong/mesh.hpp +4 -4
- package/include/numkong/numkong.h +3 -3
- package/include/numkong/numkong.hpp +1 -0
- package/include/numkong/probability/README.md +23 -19
- package/include/numkong/probability/neon.h +82 -52
- package/include/numkong/probability/rvv.h +28 -23
- package/include/numkong/probability/serial.h +51 -39
- package/include/numkong/probability.h +20 -23
- package/include/numkong/random.h +1 -1
- package/include/numkong/reduce/README.md +143 -138
- package/include/numkong/reduce/alder.h +81 -77
- package/include/numkong/reduce/haswell.h +222 -220
- package/include/numkong/reduce/neon.h +629 -519
- package/include/numkong/reduce/neonbfdot.h +7 -218
- package/include/numkong/reduce/neonfhm.h +9 -381
- package/include/numkong/reduce/neonsdot.h +9 -9
- package/include/numkong/reduce/rvv.h +928 -802
- package/include/numkong/reduce/serial.h +23 -27
- package/include/numkong/reduce/sierra.h +20 -20
- package/include/numkong/reduce/skylake.h +326 -324
- package/include/numkong/reduce/v128relaxed.h +52 -52
- package/include/numkong/reduce.h +4 -23
- package/include/numkong/reduce.hpp +156 -11
- package/include/numkong/scalar/README.md +6 -6
- package/include/numkong/scalar/haswell.h +26 -17
- package/include/numkong/scalar/loongsonasx.h +74 -0
- package/include/numkong/scalar/neon.h +9 -9
- package/include/numkong/scalar/powervsx.h +96 -0
- package/include/numkong/scalar/rvv.h +2 -2
- package/include/numkong/scalar/sapphire.h +21 -10
- package/include/numkong/scalar/serial.h +21 -21
- package/include/numkong/scalar.h +13 -0
- package/include/numkong/set/README.md +28 -28
- package/include/numkong/set/haswell.h +12 -12
- package/include/numkong/set/icelake.h +14 -14
- package/include/numkong/set/loongsonasx.h +181 -0
- package/include/numkong/set/neon.h +17 -18
- package/include/numkong/set/powervsx.h +326 -0
- package/include/numkong/set/rvv.h +4 -4
- package/include/numkong/set/serial.h +6 -6
- package/include/numkong/set/sve.h +60 -59
- package/include/numkong/set/v128relaxed.h +6 -6
- package/include/numkong/set.h +21 -7
- package/include/numkong/sets/README.md +26 -26
- package/include/numkong/sets/loongsonasx.h +52 -0
- package/include/numkong/sets/powervsx.h +65 -0
- package/include/numkong/sets/smebi32.h +395 -364
- package/include/numkong/sets.h +83 -40
- package/include/numkong/sparse/README.md +4 -4
- package/include/numkong/sparse/icelake.h +101 -101
- package/include/numkong/sparse/serial.h +1 -1
- package/include/numkong/sparse/sve2.h +137 -141
- package/include/numkong/sparse/turin.h +12 -12
- package/include/numkong/sparse.h +10 -10
- package/include/numkong/spatial/README.md +230 -226
- package/include/numkong/spatial/alder.h +113 -116
- package/include/numkong/spatial/diamond.h +240 -0
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +74 -55
- package/include/numkong/spatial/icelake.h +539 -58
- package/include/numkong/spatial/loongsonasx.h +483 -0
- package/include/numkong/spatial/neon.h +125 -52
- package/include/numkong/spatial/neonbfdot.h +8 -9
- package/include/numkong/spatial/neonfp8.h +258 -0
- package/include/numkong/spatial/neonsdot.h +180 -12
- package/include/numkong/spatial/powervsx.h +738 -0
- package/include/numkong/spatial/rvv.h +146 -139
- package/include/numkong/spatial/rvvbf16.h +17 -12
- package/include/numkong/spatial/rvvhalf.h +13 -10
- package/include/numkong/spatial/serial.h +13 -12
- package/include/numkong/spatial/sierra.h +232 -39
- package/include/numkong/spatial/skylake.h +73 -74
- package/include/numkong/spatial/sve.h +93 -72
- package/include/numkong/spatial/svebfdot.h +29 -29
- package/include/numkong/spatial/svehalf.h +52 -26
- package/include/numkong/spatial/svesdot.h +142 -0
- package/include/numkong/spatial/v128relaxed.h +293 -41
- package/include/numkong/spatial.h +338 -82
- package/include/numkong/spatials/README.md +194 -194
- package/include/numkong/spatials/diamond.h +82 -0
- package/include/numkong/spatials/haswell.h +2 -2
- package/include/numkong/spatials/loongsonasx.h +153 -0
- package/include/numkong/spatials/neonfp8.h +111 -0
- package/include/numkong/spatials/neonsdot.h +34 -0
- package/include/numkong/spatials/powervsx.h +153 -0
- package/include/numkong/spatials/rvv.h +259 -243
- package/include/numkong/spatials/sapphireamx.h +173 -173
- package/include/numkong/spatials/serial.h +2 -2
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +590 -605
- package/include/numkong/spatials/smef64.h +139 -130
- package/include/numkong/spatials/v128relaxed.h +2 -2
- package/include/numkong/spatials.h +820 -500
- package/include/numkong/spatials.hpp +49 -48
- package/include/numkong/tensor.hpp +406 -17
- package/include/numkong/trigonometry/README.md +19 -19
- package/include/numkong/trigonometry/haswell.h +402 -401
- package/include/numkong/trigonometry/neon.h +386 -387
- package/include/numkong/trigonometry/rvv.h +52 -51
- package/include/numkong/trigonometry/serial.h +13 -13
- package/include/numkong/trigonometry/skylake.h +373 -369
- package/include/numkong/trigonometry/v128relaxed.h +375 -374
- package/include/numkong/trigonometry.h +13 -13
- package/include/numkong/trigonometry.hpp +2 -2
- package/include/numkong/types.h +287 -49
- package/include/numkong/types.hpp +436 -12
- package/include/numkong/vector.hpp +82 -14
- package/javascript/dist/cjs/numkong-wasm.js +6 -12
- package/javascript/dist/cjs/numkong.d.ts +7 -1
- package/javascript/dist/cjs/numkong.js +37 -11
- package/javascript/dist/cjs/types.d.ts +9 -0
- package/javascript/dist/cjs/types.js +96 -0
- package/javascript/dist/esm/numkong-browser.d.ts +14 -0
- package/javascript/dist/esm/numkong-browser.js +23 -0
- package/javascript/dist/esm/numkong-wasm.js +6 -12
- package/javascript/dist/esm/numkong.d.ts +7 -1
- package/javascript/dist/esm/numkong.js +37 -11
- package/javascript/dist/esm/types.d.ts +9 -0
- package/javascript/dist/esm/types.js +96 -0
- package/javascript/node-gyp-build.d.ts +4 -1
- package/javascript/numkong-browser.ts +40 -0
- package/javascript/numkong-wasm.ts +7 -13
- package/javascript/numkong.c +5 -26
- package/javascript/numkong.ts +36 -11
- package/javascript/tsconfig-base.json +1 -0
- package/javascript/tsconfig-cjs.json +6 -1
- package/javascript/types.ts +110 -0
- package/numkong.gypi +101 -0
- package/package.json +34 -13
- package/probes/arm_neon.c +8 -0
- package/probes/arm_neon_bfdot.c +9 -0
- package/probes/arm_neon_fhm.c +9 -0
- package/probes/arm_neon_half.c +8 -0
- package/probes/arm_neon_sdot.c +9 -0
- package/probes/arm_neonfp8.c +9 -0
- package/probes/arm_sme.c +16 -0
- package/probes/arm_sme2.c +16 -0
- package/probes/arm_sme2p1.c +16 -0
- package/probes/arm_sme_bf16.c +16 -0
- package/probes/arm_sme_bi32.c +16 -0
- package/probes/arm_sme_f64.c +16 -0
- package/probes/arm_sme_fa64.c +14 -0
- package/probes/arm_sme_half.c +16 -0
- package/probes/arm_sme_lut2.c +15 -0
- package/probes/arm_sve.c +18 -0
- package/probes/arm_sve2.c +20 -0
- package/probes/arm_sve2p1.c +18 -0
- package/probes/arm_sve_bfdot.c +20 -0
- package/probes/arm_sve_half.c +18 -0
- package/probes/arm_sve_sdot.c +21 -0
- package/probes/loongarch_lasx.c +12 -0
- package/probes/power_vsx.c +12 -0
- package/probes/probe.js +127 -0
- package/probes/riscv_rvv.c +14 -0
- package/probes/riscv_rvv_bb.c +15 -0
- package/probes/riscv_rvv_bf16.c +17 -0
- package/probes/riscv_rvv_half.c +14 -0
- package/probes/wasm_v128relaxed.c +11 -0
- package/probes/x86_alder.c +17 -0
- package/probes/x86_diamond.c +17 -0
- package/probes/x86_genoa.c +17 -0
- package/probes/x86_graniteamx.c +19 -0
- package/probes/x86_haswell.c +11 -0
- package/probes/x86_icelake.c +17 -0
- package/probes/x86_sapphire.c +16 -0
- package/probes/x86_sapphireamx.c +18 -0
- package/probes/x86_sierra.c +17 -0
- package/probes/x86_skylake.c +15 -0
- package/probes/x86_turin.c +17 -0
- package/wasm/numkong-emscripten.js +2 -0
- package/wasm/numkong.d.ts +14 -0
- package/wasm/numkong.js +1124 -0
- package/wasm/numkong.wasm +0 -0
- package/include/numkong/curved/neonhalf.h +0 -212
- package/include/numkong/dot/neonhalf.h +0 -198
- package/include/numkong/dots/neonhalf.h +0 -57
- package/include/numkong/mesh/neonhalf.h +0 -616
- package/include/numkong/reduce/neonhalf.h +0 -157
- package/include/numkong/spatial/neonhalf.h +0 -118
- package/include/numkong/spatial/sapphire.h +0 -343
- package/include/numkong/spatials/neonhalf.h +0 -58
- package/javascript/README.md +0 -246
|
@@ -79,6 +79,13 @@
|
|
|
79
79
|
#include "numkong/spatial/serial.h" // `nk_f32_sqrt_serial`
|
|
80
80
|
#include "numkong/reduce.h" // `nk_reduce_moments_*`
|
|
81
81
|
|
|
82
|
+
/* GCC's -Wstringop-overflow produces false positives on the padded accumulator arrays
|
|
83
|
+
* in nk_define_cross_symmetric_ macro expansions (accumulators[4][7] with runtime indexing). */
|
|
84
|
+
#if defined(__GNUC__) && !defined(__clang__)
|
|
85
|
+
#pragma GCC diagnostic push
|
|
86
|
+
#pragma GCC diagnostic ignored "-Wstringop-overflow"
|
|
87
|
+
#endif
|
|
88
|
+
|
|
82
89
|
#if defined(__cplusplus)
|
|
83
90
|
extern "C" {
|
|
84
91
|
#endif
|
|
@@ -264,82 +271,59 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
264
271
|
}
|
|
265
272
|
|
|
266
273
|
/**
|
|
267
|
-
* @brief Generates
|
|
268
|
-
*
|
|
269
|
-
* Packing serves two performance-critical purposes:
|
|
270
|
-
*
|
|
271
|
-
* 1. Type conversion (input_type → intermediate_type): For mixed-precision GEMM, convert B values
|
|
272
|
-
* once during packing rather than repeatedly in tight inner loops. Example: F16 → F32 conversion
|
|
273
|
-
* happens once per value instead of once per (row of A × value of B) access. This amortizes
|
|
274
|
-
* conversion cost across all rows of A.
|
|
274
|
+
* @brief Generates pack function using SIMD load/store helpers.
|
|
275
275
|
*
|
|
276
|
-
*
|
|
277
|
-
*
|
|
278
|
-
* causing conflict misses. Padding to 8200 → stride = 32,800 bytes (non-power-of-2) distributes
|
|
279
|
-
* accesses across more cache sets.
|
|
276
|
+
* Packs the B matrix into padded row-major layout with optional type conversion,
|
|
277
|
+
* using vectorized load/store for the bulk copy and a small scalar tail for padding.
|
|
280
278
|
*
|
|
281
|
-
*
|
|
282
|
-
*
|
|
283
|
-
*
|
|
284
|
-
*
|
|
285
|
-
*
|
|
286
|
-
*
|
|
287
|
-
*
|
|
288
|
-
* @param api_name Operation name (hammings, dots)
|
|
289
|
-
* @param input_type_name Original type's name of B matrix values (i4, f16, bf16, e4m3, e5m2, f32, etc.)
|
|
290
|
-
* @param isa_suffix Platform Instruct Set Architecture suffix (serial, haswell, icelake, etc.)
|
|
291
|
-
* @param input_type Original type of B matrix values (i4x2, f16, bf16, e4m3, e5m2, f32, etc.)
|
|
292
|
-
* @param intermediate_type Internal storage type in packed buffer (often bf16 or f32 for mixed precision)
|
|
293
|
-
* @param convert_value_fn Element conversion function: void fn(input_type const*, intermediate_type*)
|
|
294
|
-
* @param norm_value_type Type of per-column norm values (f32, f64, u32) appended after packed data
|
|
295
|
-
* @param compute_norm_fn Function: norm_value_type fn(input_value_type const*, nk_size_t count)
|
|
296
|
-
* @param depth_simd_dimensions SIMD vector width in values for depth padding alignment
|
|
297
|
-
* @param dimensions_per_value Number of logical dimensions in a single value of input_type.
|
|
279
|
+
* @param vec_type SIMD vector type (nk_b512_vec_t, nk_b256_vec_t, nk_b128_vec_t)
|
|
280
|
+
* @param load_fn Full load: void fn(void const*, vec_type*)
|
|
281
|
+
* @param partial_load_fn Masked/partial load: void fn(void const*, vec_type*, nk_size_t)
|
|
282
|
+
* @param store_fn Full store: void fn(vec_type const*, void*)
|
|
283
|
+
* @param partial_store_fn Masked/partial store: void fn(vec_type const*, void*, nk_size_t)
|
|
284
|
+
* @param simd_width Elements per SIMD load/store operation
|
|
298
285
|
*/
|
|
299
|
-
#define nk_define_cross_pack_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type,
|
|
300
|
-
|
|
301
|
-
dimensions_per_value)
|
|
286
|
+
#define nk_define_cross_pack_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, vec_type, \
|
|
287
|
+
load_fn, partial_load_fn, store_fn, partial_store_fn, simd_width, norm_value_type, \
|
|
288
|
+
compute_norm_fn, depth_simd_dimensions, dimensions_per_value) \
|
|
302
289
|
NK_PUBLIC void nk_##api_name##_pack_##input_type_name##_##isa_suffix( \
|
|
303
290
|
nk_##input_value_type##_t const *b, nk_size_t column_count, nk_size_t depth, nk_size_t b_stride_in_bytes, \
|
|
304
291
|
void *b_packed) { \
|
|
305
|
-
/* Use identical padding calculation as pack_size */ \
|
|
306
292
|
nk_size_t depth_dimensions_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions); \
|
|
307
293
|
nk_size_t depth_values_padded = nk_size_divide_round_up_(depth_dimensions_padded, dimensions_per_value); \
|
|
308
|
-
\
|
|
309
|
-
/* Power-of-2 breaking (same as pack_size) */ \
|
|
310
294
|
nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_##packed_value_type##_t); \
|
|
311
|
-
if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0)
|
|
295
|
+
if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) \
|
|
312
296
|
depth_values_padded += nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
|
|
313
|
-
} \
|
|
314
|
-
\
|
|
315
|
-
/* Calculate input depth in values */ \
|
|
316
297
|
nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
|
|
317
298
|
\
|
|
318
|
-
/* Store dimensions in header */ \
|
|
319
299
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed; \
|
|
320
300
|
header->column_count = (nk_u32_t)column_count; \
|
|
321
|
-
header->depth_dimensions = (nk_u32_t)depth;
|
|
322
|
-
header->depth_padded_values = (nk_u32_t)depth_values_padded;
|
|
301
|
+
header->depth_dimensions = (nk_u32_t)depth; \
|
|
302
|
+
header->depth_padded_values = (nk_u32_t)depth_values_padded; \
|
|
323
303
|
\
|
|
324
304
|
nk_##packed_value_type##_t *packed = (nk_##packed_value_type##_t *)((char *)b_packed + \
|
|
325
305
|
sizeof(nk_cross_packed_buffer_header_t)); \
|
|
306
|
+
nk_size_t const full_chunks = depth_in_values / (simd_width); \
|
|
307
|
+
nk_size_t const remainder = depth_in_values % (simd_width); \
|
|
326
308
|
\
|
|
327
|
-
/* Zero entire buffer for depth padding */ \
|
|
328
|
-
nk_size_t const total_values = column_count * depth_values_padded; \
|
|
329
|
-
for (nk_size_t i = 0; i < total_values; ++i) packed[i] = 0; \
|
|
330
|
-
\
|
|
331
|
-
/* Copy/convert B[column_count, depth] to packed[column_count, depth_padded] - simple column-major */ \
|
|
332
309
|
for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
|
|
333
|
-
nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
|
|
334
310
|
nk_##input_value_type##_t const *source_row = \
|
|
335
311
|
(nk_##input_value_type##_t const *)((char const *)b + column_index * b_stride_in_bytes); \
|
|
336
|
-
|
|
337
|
-
|
|
312
|
+
nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
|
|
313
|
+
for (nk_size_t chunk = 0; chunk < full_chunks; ++chunk) { \
|
|
314
|
+
vec_type vec; \
|
|
315
|
+
load_fn(source_row + chunk * (simd_width), &vec); \
|
|
316
|
+
store_fn(&vec, destination_row + chunk * (simd_width)); \
|
|
338
317
|
} \
|
|
339
|
-
|
|
318
|
+
if (remainder > 0) { \
|
|
319
|
+
vec_type vec; \
|
|
320
|
+
partial_load_fn(source_row + full_chunks * (simd_width), &vec, remainder); \
|
|
321
|
+
partial_store_fn(&vec, destination_row + full_chunks * (simd_width), remainder); \
|
|
322
|
+
} \
|
|
323
|
+
for (nk_size_t pad = depth_in_values; pad < depth_values_padded; ++pad) destination_row[pad] = 0; \
|
|
340
324
|
} \
|
|
341
325
|
\
|
|
342
|
-
|
|
326
|
+
nk_size_t const total_values = column_count * depth_values_padded; \
|
|
343
327
|
nk_##norm_value_type##_t *norms = (nk_##norm_value_type##_t *)(packed + total_values); \
|
|
344
328
|
for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
|
|
345
329
|
nk_##input_value_type##_t const *source_row = \
|
|
@@ -372,42 +356,51 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
372
356
|
}
|
|
373
357
|
|
|
374
358
|
/**
|
|
375
|
-
* @brief
|
|
376
|
-
*
|
|
377
|
-
* Like nk_define_cross_pack_ but uses compute_moments_fn(data, count, &sum, &norm) to compute
|
|
378
|
-
* both sum and norm in a single pass, storing both after the packed data.
|
|
379
|
-
* Layout: [ Header ] [ Packed data ] [ Norms ] [ Column sums ]
|
|
359
|
+
* @brief Like nk_define_cross_pack_ but stores both per-column norms AND column sums.
|
|
360
|
+
* Layout: [ Header 64B ] [ Packed data ] [ Norms (norm_type) ] [ Column sums (sum_type) ]
|
|
380
361
|
*/
|
|
381
362
|
#define nk_define_cross_compensated_pack_(api_name, input_type_name, isa_suffix, input_value_type, packed_value_type, \
|
|
382
|
-
|
|
383
|
-
depth_simd_dimensions,
|
|
363
|
+
vec_type, load_fn, partial_load_fn, store_fn, partial_store_fn, simd_width, \
|
|
364
|
+
sum_value_type, norm_value_type, compute_moments_fn, depth_simd_dimensions, \
|
|
365
|
+
dimensions_per_value) \
|
|
384
366
|
NK_PUBLIC void nk_##api_name##_pack_##input_type_name##_##isa_suffix( \
|
|
385
367
|
nk_##input_value_type##_t const *b, nk_size_t column_count, nk_size_t depth, nk_size_t b_stride_in_bytes, \
|
|
386
368
|
void *b_packed) { \
|
|
387
369
|
nk_size_t depth_dimensions_padded = nk_size_round_up_to_multiple_(depth, depth_simd_dimensions); \
|
|
388
370
|
nk_size_t depth_values_padded = nk_size_divide_round_up_(depth_dimensions_padded, dimensions_per_value); \
|
|
389
371
|
nk_size_t const stride_bytes = depth_values_padded * sizeof(nk_##packed_value_type##_t); \
|
|
390
|
-
if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0)
|
|
372
|
+
if ((stride_bytes & (stride_bytes - 1)) == 0 && stride_bytes > 0) \
|
|
391
373
|
depth_values_padded += nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
|
|
392
|
-
} \
|
|
393
374
|
nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
|
|
375
|
+
\
|
|
394
376
|
nk_cross_packed_buffer_header_t *header = (nk_cross_packed_buffer_header_t *)b_packed; \
|
|
395
377
|
header->column_count = (nk_u32_t)column_count; \
|
|
396
378
|
header->depth_dimensions = (nk_u32_t)depth; \
|
|
397
379
|
header->depth_padded_values = (nk_u32_t)depth_values_padded; \
|
|
380
|
+
\
|
|
398
381
|
nk_##packed_value_type##_t *packed = (nk_##packed_value_type##_t *)((char *)b_packed + \
|
|
399
382
|
sizeof(nk_cross_packed_buffer_header_t)); \
|
|
400
|
-
nk_size_t const
|
|
401
|
-
|
|
383
|
+
nk_size_t const full_chunks = depth_in_values / (simd_width); \
|
|
384
|
+
nk_size_t const remainder = depth_in_values % (simd_width); \
|
|
385
|
+
\
|
|
402
386
|
for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
|
|
403
|
-
nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
|
|
404
387
|
nk_##input_value_type##_t const *source_row = \
|
|
405
388
|
(nk_##input_value_type##_t const *)((char const *)b + column_index * b_stride_in_bytes); \
|
|
406
|
-
|
|
407
|
-
|
|
389
|
+
nk_##packed_value_type##_t *destination_row = packed + column_index * depth_values_padded; \
|
|
390
|
+
for (nk_size_t chunk = 0; chunk < full_chunks; ++chunk) { \
|
|
391
|
+
vec_type vec; \
|
|
392
|
+
load_fn(source_row + chunk * (simd_width), &vec); \
|
|
393
|
+
store_fn(&vec, destination_row + chunk * (simd_width)); \
|
|
394
|
+
} \
|
|
395
|
+
if (remainder > 0) { \
|
|
396
|
+
vec_type vec; \
|
|
397
|
+
partial_load_fn(source_row + full_chunks * (simd_width), &vec, remainder); \
|
|
398
|
+
partial_store_fn(&vec, destination_row + full_chunks * (simd_width), remainder); \
|
|
408
399
|
} \
|
|
400
|
+
for (nk_size_t pad = depth_in_values; pad < depth_values_padded; ++pad) destination_row[pad] = 0; \
|
|
409
401
|
} \
|
|
410
|
-
|
|
402
|
+
\
|
|
403
|
+
nk_size_t const total_values = column_count * depth_values_padded; \
|
|
411
404
|
nk_##norm_value_type##_t *norms = (nk_##norm_value_type##_t *)(packed + total_values); \
|
|
412
405
|
nk_##sum_value_type##_t *col_sums = (nk_##sum_value_type##_t *)(norms + column_count); \
|
|
413
406
|
for (nk_size_t column_index = 0; column_index < column_count; ++column_index) { \
|
|
@@ -1246,9 +1239,9 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
1246
1239
|
nk_##packed_value_type##_t const *bp5 = packed_data + (tc + 5) * depth_padded; \
|
|
1247
1240
|
nk_##packed_value_type##_t const *bp6 = packed_data + (tc + 6) * depth_padded; \
|
|
1248
1241
|
nk_##packed_value_type##_t const *bp7 = packed_data + (tc + 7) * depth_padded; \
|
|
1249
|
-
result_vec_type
|
|
1250
|
-
load_sum_fn(b_sums + tc, &
|
|
1251
|
-
load_sum_fn(b_sums + tc + 4, &
|
|
1242
|
+
result_vec_type b_sum_low, b_sum_high; \
|
|
1243
|
+
load_sum_fn(b_sums + tc, &b_sum_low); \
|
|
1244
|
+
load_sum_fn(b_sums + tc + 4, &b_sum_high); \
|
|
1252
1245
|
for (nk_size_t ri = rb2; ri < re2; ++ri) { \
|
|
1253
1246
|
state_type s0, s1, s2, s3, s4, s5, s6, s7; \
|
|
1254
1247
|
init_accumulator_fn(&s0), init_accumulator_fn(&s1), init_accumulator_fn(&s2), \
|
|
@@ -1277,9 +1270,9 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
1277
1270
|
result_vec_type rv; \
|
|
1278
1271
|
nk_##result_value_type##_t *c_row = (nk_##result_value_type##_t *)((char *)c_matrix + \
|
|
1279
1272
|
ri * c_stride_in_bytes); \
|
|
1280
|
-
compensated_finalize_fn(&s0, &s1, &s2, &s3, depth, a_sum_val,
|
|
1273
|
+
compensated_finalize_fn(&s0, &s1, &s2, &s3, depth, a_sum_val, b_sum_low, &rv); \
|
|
1281
1274
|
store_fn(&rv, c_row + tc); \
|
|
1282
|
-
compensated_finalize_fn(&s4, &s5, &s6, &s7, depth, a_sum_val,
|
|
1275
|
+
compensated_finalize_fn(&s4, &s5, &s6, &s7, depth, a_sum_val, b_sum_high, &rv); \
|
|
1283
1276
|
store_fn(&rv, c_row + tc + 4); \
|
|
1284
1277
|
} \
|
|
1285
1278
|
} \
|
|
@@ -1893,8 +1886,9 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
1893
1886
|
} \
|
|
1894
1887
|
} \
|
|
1895
1888
|
NK_PUBLIC void nk_##api_name##_symmetric_##input_type_name##_##isa_suffix( \
|
|
1896
|
-
nk_##input_value_type##_t const *vectors, nk_size_t
|
|
1897
|
-
nk_##result_value_type##_t *result, nk_size_t
|
|
1889
|
+
nk_##input_value_type##_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, \
|
|
1890
|
+
nk_##result_value_type##_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, \
|
|
1891
|
+
nk_size_t row_count) { \
|
|
1898
1892
|
nk_size_t const macro_tile_size = 32; \
|
|
1899
1893
|
nk_size_t const row_block_size = 128; /* L2 cache blocking */ \
|
|
1900
1894
|
nk_size_t const column_block_size = 2048; /* L3 cache blocking */ \
|
|
@@ -1904,13 +1898,13 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
1904
1898
|
nk_size_t const remainder_depth = depth_in_values - aligned_depth; \
|
|
1905
1899
|
nk_size_t const remainder_dimensions = depth - depth_dimensions_aligned; \
|
|
1906
1900
|
nk_size_t const depth_step = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
|
|
1907
|
-
nk_size_t const result_stride_values =
|
|
1908
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
1901
|
+
nk_size_t const result_stride_values = result_stride_in_bytes / sizeof(nk_##result_value_type##_t); \
|
|
1902
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count; \
|
|
1909
1903
|
\
|
|
1910
1904
|
/* Process upper triangle with L3/L2/L1 blocking (column blocks → row blocks → 32×32 macro-tiles) */ \
|
|
1911
|
-
for (nk_size_t j_block = 0; j_block <
|
|
1912
|
-
nk_size_t j_block_end = (j_block + column_block_size <
|
|
1913
|
-
|
|
1905
|
+
for (nk_size_t j_block = 0; j_block < vectors_count; j_block += column_block_size) { \
|
|
1906
|
+
nk_size_t j_block_end = (j_block + column_block_size < vectors_count) ? j_block + column_block_size \
|
|
1907
|
+
: vectors_count; \
|
|
1914
1908
|
\
|
|
1915
1909
|
for (nk_size_t i_block = row_start; i_block < row_end; i_block += row_block_size) { \
|
|
1916
1910
|
nk_size_t i_block_end = (i_block + row_block_size < row_end) ? i_block + row_block_size : row_end; \
|
|
@@ -1933,7 +1927,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
1933
1927
|
nk_##input_value_type##_t const *vec_ptrs_j[32]; \
|
|
1934
1928
|
for (nk_size_t k = 0; k < macro_i_size; k++) \
|
|
1935
1929
|
vec_ptrs_i[k] = (nk_##input_value_type##_t const *)((char const *)vectors + \
|
|
1936
|
-
(i_macro + k) *
|
|
1930
|
+
(i_macro + k) * stride_in_bytes); \
|
|
1937
1931
|
for (nk_size_t k = macro_i_size; k < 32; k++) vec_ptrs_i[k] = vec_ptrs_i[0]; \
|
|
1938
1932
|
\
|
|
1939
1933
|
if (i_macro == j_macro && macro_i_size == macro_j_size) { \
|
|
@@ -1947,7 +1941,7 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
1947
1941
|
/* Off-diagonal macro-tile */ \
|
|
1948
1942
|
for (nk_size_t k = 0; k < macro_j_size; k++) \
|
|
1949
1943
|
vec_ptrs_j[k] = (nk_##input_value_type##_t const *)((char const *)vectors + \
|
|
1950
|
-
(j_macro + k) *
|
|
1944
|
+
(j_macro + k) * stride_in_bytes); \
|
|
1951
1945
|
for (nk_size_t k = macro_j_size; k < 32; k++) vec_ptrs_j[k] = vec_ptrs_j[0]; \
|
|
1952
1946
|
nk_##api_name##_symmetric_offdiagonal_##input_type_name##_##isa_suffix##_( \
|
|
1953
1947
|
vec_ptrs_i, vec_ptrs_j, i_macro, j_macro, macro_i_size, macro_j_size, aligned_depth, \
|
|
@@ -2365,28 +2359,29 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
2365
2359
|
} \
|
|
2366
2360
|
} \
|
|
2367
2361
|
NK_PUBLIC void nk_##api_name##_symmetric_##input_type_name##_##isa_suffix( \
|
|
2368
|
-
nk_##input_value_type##_t const *vectors, nk_size_t
|
|
2369
|
-
nk_##result_value_type##_t *result, nk_size_t
|
|
2362
|
+
nk_##input_value_type##_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, \
|
|
2363
|
+
nk_##result_value_type##_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, \
|
|
2364
|
+
nk_size_t row_count) { \
|
|
2370
2365
|
nk_size_t const macro_tile_size = 32; \
|
|
2371
2366
|
nk_size_t const finalizer_batch_size = 4; \
|
|
2372
2367
|
nk_size_t const row_block_size = 128; /* L2 cache blocking */ \
|
|
2373
2368
|
nk_size_t const column_block_size = 2048; /* L3 cache blocking */ \
|
|
2374
2369
|
\
|
|
2375
2370
|
/* Stride and depth calculations */ \
|
|
2376
|
-
nk_size_t const vectors_stride_values =
|
|
2377
|
-
nk_size_t const result_stride_values =
|
|
2371
|
+
nk_size_t const vectors_stride_values = stride_in_bytes / sizeof(nk_##input_value_type##_t); \
|
|
2372
|
+
nk_size_t const result_stride_values = result_stride_in_bytes / sizeof(nk_##result_value_type##_t); \
|
|
2378
2373
|
nk_size_t const depth_dimensions_aligned = (depth / depth_simd_dimensions) * depth_simd_dimensions; \
|
|
2379
2374
|
nk_size_t const aligned_depth = nk_size_divide_round_up_(depth_dimensions_aligned, dimensions_per_value); \
|
|
2380
2375
|
nk_size_t const depth_in_values = nk_size_divide_round_up_(depth, dimensions_per_value); \
|
|
2381
2376
|
nk_size_t const remainder_depth = depth_in_values - aligned_depth; \
|
|
2382
2377
|
nk_size_t const remainder_dimensions = depth - depth_dimensions_aligned; \
|
|
2383
2378
|
nk_size_t const depth_step_values = nk_size_divide_round_up_(depth_simd_dimensions, dimensions_per_value); \
|
|
2384
|
-
nk_size_t const row_end = (row_start + row_count <
|
|
2379
|
+
nk_size_t const row_end = (row_start + row_count < vectors_count) ? (row_start + row_count) : vectors_count; \
|
|
2385
2380
|
\
|
|
2386
2381
|
/* Process upper triangle with L3/L2/L1 blocking (column blocks → row blocks → 32×32 macro-tiles) */ \
|
|
2387
|
-
for (nk_size_t j_block = 0; j_block <
|
|
2388
|
-
nk_size_t j_block_end = (j_block + column_block_size <
|
|
2389
|
-
|
|
2382
|
+
for (nk_size_t j_block = 0; j_block < vectors_count; j_block += column_block_size) { \
|
|
2383
|
+
nk_size_t j_block_end = (j_block + column_block_size < vectors_count) ? j_block + column_block_size \
|
|
2384
|
+
: vectors_count; \
|
|
2390
2385
|
\
|
|
2391
2386
|
for (nk_size_t i_block = row_start; i_block < row_end; i_block += row_block_size) { \
|
|
2392
2387
|
nk_size_t i_block_end = (i_block + row_block_size < row_end) ? i_block + row_block_size : row_end; \
|
|
@@ -2451,9 +2446,9 @@ NK_INTERNAL nk_i32_t nk_dots_reduce_sum_i4_(nk_i4x2_t const *data, nk_size_t cou
|
|
|
2451
2446
|
/* F64 GEMM: depth_simd_dimensions=2 (2 f64s = 16 bytes) */
|
|
2452
2447
|
nk_define_cross_pack_size_(dots, f64, serial, f64, f64, /*norm_value_type=*/f64, /*depth_simd_dimensions=*/2,
|
|
2453
2448
|
/*dimensions_per_value=*/1)
|
|
2454
|
-
nk_define_cross_pack_(dots, f64, serial, f64, f64,
|
|
2455
|
-
|
|
2456
|
-
/*depth_simd_dimensions=*/2, /*dimensions_per_value=*/1)
|
|
2449
|
+
nk_define_cross_pack_(dots, f64, serial, f64, f64, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b64x2_serial_,
|
|
2450
|
+
nk_store_b128_serial_, nk_partial_store_b64x2_serial_, /*simd_width=*/2, /*norm_value_type=*/f64,
|
|
2451
|
+
nk_dots_reduce_sumsq_f64_, /*depth_simd_dimensions=*/2, /*dimensions_per_value=*/1)
|
|
2457
2452
|
nk_define_cross_symmetric_(dots, f64, serial, f64, f64, nk_b128_vec_t, nk_dot_f64x2_state_serial_t, nk_b256_vec_t,
|
|
2458
2453
|
nk_dot_f64x2_init_serial, nk_load_b128_serial_, nk_partial_load_b64x2_serial_,
|
|
2459
2454
|
nk_dot_f64x2_update_serial, nk_dot_f64x2_finalize_serial, nk_store_b256_serial_,
|
|
@@ -2468,9 +2463,9 @@ nk_define_cross_packed_(dots, f64, serial, f64, f64, f64, nk_b128_vec_t, nk_dot_
|
|
|
2468
2463
|
/* F32 GEMM: depth_simd_dimensions=4 (4 f32s = 16 bytes) */
|
|
2469
2464
|
nk_define_cross_pack_size_(dots, f32, serial, f32, f32, /*norm_value_type=*/f64, /*depth_simd_dimensions=*/4,
|
|
2470
2465
|
/*dimensions_per_value=*/1)
|
|
2471
|
-
nk_define_cross_pack_(dots, f32, serial, f32, f32,
|
|
2472
|
-
|
|
2473
|
-
/*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
|
|
2466
|
+
nk_define_cross_pack_(dots, f32, serial, f32, f32, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
|
|
2467
|
+
nk_store_b128_serial_, nk_partial_store_b32x4_serial_, /*simd_width=*/4, /*norm_value_type=*/f64,
|
|
2468
|
+
nk_dots_reduce_sumsq_f32_, /*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
|
|
2474
2469
|
nk_define_cross_symmetric_(dots, f32, serial, f32, f64, nk_b128_vec_t, nk_dot_f32x4_state_serial_t, nk_b256_vec_t,
|
|
2475
2470
|
nk_dot_f32x4_init_serial, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
|
|
2476
2471
|
nk_dot_f32x4_update_serial, nk_dot_f32x4_finalize_serial, nk_store_b256_serial_,
|
|
@@ -2482,28 +2477,31 @@ nk_define_cross_packed_(dots, f32, serial, f32, f32, f64, nk_b128_vec_t, nk_dot_
|
|
|
2482
2477
|
nk_dot_f32x4_finalize_serial, nk_store_b256_serial_, nk_partial_store_b64x4_serial_,
|
|
2483
2478
|
/*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
|
|
2484
2479
|
|
|
2485
|
-
/* F16 GEMM:
|
|
2486
|
-
nk_define_cross_pack_size_(dots, f16, serial, f16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/
|
|
2480
|
+
/* F16 packed GEMM: pre-upcast B to f32 and process 4 logical dimensions per 128-bit step. */
|
|
2481
|
+
nk_define_cross_pack_size_(dots, f16, serial, f16, f32, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/4,
|
|
2487
2482
|
/*dimensions_per_value=*/1)
|
|
2488
|
-
nk_define_cross_pack_(dots, f16, serial, f16,
|
|
2489
|
-
|
|
2490
|
-
/*
|
|
2483
|
+
nk_define_cross_pack_(dots, f16, serial, f16, f32, nk_b128_vec_t, nk_load_f16x4_to_f32x4_serial_,
|
|
2484
|
+
nk_partial_load_f16x4_to_f32x4_serial_, nk_store_b128_serial_, nk_partial_store_b32x4_serial_,
|
|
2485
|
+
/*simd_width=*/4, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_f16_,
|
|
2486
|
+
/*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
|
|
2491
2487
|
nk_define_cross_symmetric_(dots, f16, serial, f16, f32, nk_b128_vec_t, nk_dot_f16x8_state_serial_t, nk_b128_vec_t,
|
|
2492
2488
|
nk_dot_f16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
|
|
2493
2489
|
nk_dot_f16x8_update_serial, nk_dot_f16x8_finalize_serial, nk_store_b128_serial_,
|
|
2494
2490
|
nk_partial_store_b32x4_serial_,
|
|
2495
2491
|
/*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
|
|
2496
|
-
nk_define_cross_packed_(dots, f16, serial, f16,
|
|
2497
|
-
|
|
2498
|
-
|
|
2499
|
-
|
|
2500
|
-
|
|
2492
|
+
nk_define_cross_packed_(dots, f16, serial, f16, f32, f32, nk_b128_vec_t, nk_dot_through_f32x4_state_serial_t,
|
|
2493
|
+
nk_b128_vec_t, nk_dot_through_f32x4_init_serial, nk_load_f16x4_to_f32x4_serial_,
|
|
2494
|
+
nk_partial_load_f16x4_to_f32x4_serial_, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
|
|
2495
|
+
nk_dot_through_f32x4_update_serial, nk_dot_through_f32x4_finalize_serial, nk_store_b128_serial_,
|
|
2496
|
+
nk_partial_store_b32x4_serial_,
|
|
2497
|
+
/*depth_simd_dimensions=*/4, /*dimensions_per_value=*/1)
|
|
2501
2498
|
|
|
2502
2499
|
/* BF16 GEMM: depth_simd_dimensions=8 (8 bf16s = 16 bytes), F32 accumulator */
|
|
2503
|
-
nk_define_cross_pack_size_(dots, bf16, serial, bf16,
|
|
2500
|
+
nk_define_cross_pack_size_(dots, bf16, serial, bf16, bf16, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/8,
|
|
2504
2501
|
/*dimensions_per_value=*/1)
|
|
2505
|
-
nk_define_cross_pack_(dots, bf16, serial, bf16, bf16,
|
|
2506
|
-
|
|
2502
|
+
nk_define_cross_pack_(dots, bf16, serial, bf16, bf16, nk_b128_vec_t, nk_load_b128_serial_,
|
|
2503
|
+
nk_partial_load_b16x8_serial_, nk_store_b128_serial_, nk_partial_store_b16x8_serial_,
|
|
2504
|
+
/*simd_width=*/8, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_bf16_,
|
|
2507
2505
|
/*depth_simd_dimensions=*/8, /*dimensions_per_value=*/1)
|
|
2508
2506
|
nk_define_cross_symmetric_(dots, bf16, serial, bf16, f32, nk_b128_vec_t, nk_dot_bf16x8_state_serial_t, nk_b128_vec_t,
|
|
2509
2507
|
nk_dot_bf16x8_init_serial, nk_load_b128_serial_, nk_partial_load_b16x8_serial_,
|
|
@@ -2519,8 +2517,10 @@ nk_define_cross_packed_(dots, bf16, serial, bf16, bf16, f32, nk_b128_vec_t, nk_d
|
|
|
2519
2517
|
/* I8 GEMM: depth_simd_dimensions=16 (16 i8s = 16 bytes), I32 accumulator */
|
|
2520
2518
|
nk_define_cross_pack_size_(dots, i8, serial, i8, i8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
|
|
2521
2519
|
/*dimensions_per_value=*/1)
|
|
2522
|
-
nk_define_cross_pack_(dots, i8, serial, i8, i8,
|
|
2523
|
-
/*
|
|
2520
|
+
nk_define_cross_pack_(dots, i8, serial, i8, i8, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
2521
|
+
nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
|
|
2522
|
+
/*norm_value_type=*/u32, nk_dots_reduce_sumsq_i8_, /*depth_simd_dimensions=*/16,
|
|
2523
|
+
/*dimensions_per_value=*/1)
|
|
2524
2524
|
nk_define_cross_symmetric_(dots, i8, serial, i8, i32, nk_b128_vec_t, nk_dot_i8x16_state_serial_t, nk_b128_vec_t,
|
|
2525
2525
|
nk_dot_i8x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
2526
2526
|
nk_dot_i8x16_update_serial, nk_dot_i8x16_finalize_serial, nk_store_b128_serial_,
|
|
@@ -2535,8 +2535,10 @@ nk_define_cross_packed_(dots, i8, serial, i8, i8, i32, nk_b128_vec_t, nk_dot_i8x
|
|
|
2535
2535
|
/* U8 GEMM: depth_simd_dimensions=16 (16 u8s = 16 bytes), U32 accumulator */
|
|
2536
2536
|
nk_define_cross_pack_size_(dots, u8, serial, u8, u8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
|
|
2537
2537
|
/*dimensions_per_value=*/1)
|
|
2538
|
-
nk_define_cross_pack_(dots, u8, serial, u8, u8,
|
|
2539
|
-
/*
|
|
2538
|
+
nk_define_cross_pack_(dots, u8, serial, u8, u8, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
2539
|
+
nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
|
|
2540
|
+
/*norm_value_type=*/u32, nk_dots_reduce_sumsq_u8_, /*depth_simd_dimensions=*/16,
|
|
2541
|
+
/*dimensions_per_value=*/1)
|
|
2540
2542
|
nk_define_cross_symmetric_(dots, u8, serial, u8, u32, nk_b128_vec_t, nk_dot_u8x16_state_serial_t, nk_b128_vec_t,
|
|
2541
2543
|
nk_dot_u8x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
2542
2544
|
nk_dot_u8x16_update_serial, nk_dot_u8x16_finalize_serial, nk_store_b128_serial_,
|
|
@@ -2551,8 +2553,9 @@ nk_define_cross_packed_(dots, u8, serial, u8, u8, u32, nk_b128_vec_t, nk_dot_u8x
|
|
|
2551
2553
|
/* E4M3 GEMM: depth_simd_dimensions=16 (16 e4m3s = 16 bytes), F32 accumulator */
|
|
2552
2554
|
nk_define_cross_pack_size_(dots, e4m3, serial, e4m3, e4m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
|
|
2553
2555
|
/*dimensions_per_value=*/1)
|
|
2554
|
-
nk_define_cross_pack_(dots, e4m3, serial, e4m3, e4m3,
|
|
2555
|
-
|
|
2556
|
+
nk_define_cross_pack_(dots, e4m3, serial, e4m3, e4m3, nk_b128_vec_t, nk_load_b128_serial_,
|
|
2557
|
+
nk_partial_load_b8x16_serial_, nk_store_b128_serial_, nk_partial_store_b8x16_serial_,
|
|
2558
|
+
/*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e4m3_,
|
|
2556
2559
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
2557
2560
|
nk_define_cross_symmetric_(dots, e4m3, serial, e4m3, f32, nk_b128_vec_t, nk_dot_e4m3x16_state_serial_t, nk_b128_vec_t,
|
|
2558
2561
|
nk_dot_e4m3x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
@@ -2568,8 +2571,9 @@ nk_define_cross_packed_(dots, e4m3, serial, e4m3, e4m3, f32, nk_b128_vec_t, nk_d
|
|
|
2568
2571
|
/* E5M2 GEMM: depth_simd_dimensions=16 (16 e5m2s = 16 bytes), F32 accumulator */
|
|
2569
2572
|
nk_define_cross_pack_size_(dots, e5m2, serial, e5m2, e5m2, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
|
|
2570
2573
|
/*dimensions_per_value=*/1)
|
|
2571
|
-
nk_define_cross_pack_(dots, e5m2, serial, e5m2, e5m2,
|
|
2572
|
-
|
|
2574
|
+
nk_define_cross_pack_(dots, e5m2, serial, e5m2, e5m2, nk_b128_vec_t, nk_load_b128_serial_,
|
|
2575
|
+
nk_partial_load_b8x16_serial_, nk_store_b128_serial_, nk_partial_store_b8x16_serial_,
|
|
2576
|
+
/*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e5m2_,
|
|
2573
2577
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
2574
2578
|
nk_define_cross_symmetric_(dots, e5m2, serial, e5m2, f32, nk_b128_vec_t, nk_dot_e5m2x16_state_serial_t, nk_b128_vec_t,
|
|
2575
2579
|
nk_dot_e5m2x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
@@ -2585,8 +2589,9 @@ nk_define_cross_packed_(dots, e5m2, serial, e5m2, e5m2, f32, nk_b128_vec_t, nk_d
|
|
|
2585
2589
|
/* E2M3 GEMM: depth_simd_dimensions=16 (16 e2m3s = 16 bytes), F32 accumulator */
|
|
2586
2590
|
nk_define_cross_pack_size_(dots, e2m3, serial, e2m3, e2m3, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
|
|
2587
2591
|
/*dimensions_per_value=*/1)
|
|
2588
|
-
nk_define_cross_pack_(dots, e2m3, serial, e2m3, e2m3,
|
|
2589
|
-
|
|
2592
|
+
nk_define_cross_pack_(dots, e2m3, serial, e2m3, e2m3, nk_b128_vec_t, nk_load_b128_serial_,
|
|
2593
|
+
nk_partial_load_b8x16_serial_, nk_store_b128_serial_, nk_partial_store_b8x16_serial_,
|
|
2594
|
+
/*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e2m3_,
|
|
2590
2595
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
2591
2596
|
nk_define_cross_symmetric_(dots, e2m3, serial, e2m3, f32, nk_b128_vec_t, nk_dot_e2m3x16_state_serial_t, nk_b128_vec_t,
|
|
2592
2597
|
nk_dot_e2m3x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
@@ -2602,8 +2607,9 @@ nk_define_cross_packed_(dots, e2m3, serial, e2m3, e2m3, f32, nk_b128_vec_t, nk_d
|
|
|
2602
2607
|
/* E3M2 GEMM: depth_simd_dimensions=16 (16 e3m2s = 16 bytes), F32 accumulator */
|
|
2603
2608
|
nk_define_cross_pack_size_(dots, e3m2, serial, e3m2, e3m2, /*norm_value_type=*/f32, /*depth_simd_dimensions=*/16,
|
|
2604
2609
|
/*dimensions_per_value=*/1)
|
|
2605
|
-
nk_define_cross_pack_(dots, e3m2, serial, e3m2, e3m2,
|
|
2606
|
-
|
|
2610
|
+
nk_define_cross_pack_(dots, e3m2, serial, e3m2, e3m2, nk_b128_vec_t, nk_load_b128_serial_,
|
|
2611
|
+
nk_partial_load_b8x16_serial_, nk_store_b128_serial_, nk_partial_store_b8x16_serial_,
|
|
2612
|
+
/*simd_width=*/16, /*norm_value_type=*/f32, nk_dots_reduce_sumsq_e3m2_,
|
|
2607
2613
|
/*depth_simd_dimensions=*/16, /*dimensions_per_value=*/1)
|
|
2608
2614
|
nk_define_cross_symmetric_(dots, e3m2, serial, e3m2, f32, nk_b128_vec_t, nk_dot_e3m2x16_state_serial_t, nk_b128_vec_t,
|
|
2609
2615
|
nk_dot_e3m2x16_init_serial, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
@@ -2619,9 +2625,10 @@ nk_define_cross_packed_(dots, e3m2, serial, e3m2, e3m2, f32, nk_b128_vec_t, nk_d
|
|
|
2619
2625
|
/* U4 GEMM: u4x2 for both A and B */
|
|
2620
2626
|
nk_define_cross_pack_size_(dots, u4, serial, u4x2, u4x2, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
|
|
2621
2627
|
/*dimensions_per_value=*/2)
|
|
2622
|
-
nk_define_cross_pack_(dots, u4, serial, u4x2, u4x2,
|
|
2623
|
-
|
|
2624
|
-
/*
|
|
2628
|
+
nk_define_cross_pack_(dots, u4, serial, u4x2, u4x2, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
2629
|
+
nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
|
|
2630
|
+
/*norm_value_type=*/u32, nk_dots_reduce_sumsq_u4_, /*depth_simd_dimensions=*/16,
|
|
2631
|
+
/*dimensions_per_value=*/2)
|
|
2625
2632
|
nk_define_cross_symmetric_(dots, u4, serial, u4x2, u32, nk_b64_vec_t, nk_dot_u4x16_state_serial_t, nk_b128_vec_t,
|
|
2626
2633
|
nk_dot_u4x16_init_serial, nk_load_b64_serial_, nk_partial_load_b4x16_serial_,
|
|
2627
2634
|
nk_dot_u4x16_update_serial, nk_dot_u4x16_finalize_serial, nk_store_b128_serial_,
|
|
@@ -2636,9 +2643,10 @@ nk_define_cross_packed_(dots, u4, serial, u4x2, u4x2, u32, nk_b64_vec_t, nk_dot_
|
|
|
2636
2643
|
/* I4 GEMM: i4x2 for both A and B */
|
|
2637
2644
|
nk_define_cross_pack_size_(dots, i4, serial, i4x2, i4x2, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/16,
|
|
2638
2645
|
/*dimensions_per_value=*/2)
|
|
2639
|
-
nk_define_cross_pack_(dots, i4, serial, i4x2, i4x2,
|
|
2640
|
-
|
|
2641
|
-
/*
|
|
2646
|
+
nk_define_cross_pack_(dots, i4, serial, i4x2, i4x2, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
2647
|
+
nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
|
|
2648
|
+
/*norm_value_type=*/u32, nk_dots_reduce_sumsq_i4_, /*depth_simd_dimensions=*/16,
|
|
2649
|
+
/*dimensions_per_value=*/2)
|
|
2642
2650
|
nk_define_cross_symmetric_(dots, i4, serial, i4x2, i32, nk_b64_vec_t, nk_dot_i4x16_state_serial_t, nk_b128_vec_t,
|
|
2643
2651
|
nk_dot_i4x16_init_serial, nk_load_b64_serial_, nk_partial_load_b4x16_serial_,
|
|
2644
2652
|
nk_dot_i4x16_update_serial, nk_dot_i4x16_finalize_serial, nk_store_b128_serial_,
|
|
@@ -2653,8 +2661,10 @@ nk_define_cross_packed_(dots, i4, serial, i4x2, i4x2, i32, nk_b64_vec_t, nk_dot_
|
|
|
2653
2661
|
/* U1 GEMM: u1x8 for both A and B */
|
|
2654
2662
|
nk_define_cross_pack_size_(dots, u1, serial, u1x8, u1x8, /*norm_value_type=*/u32, /*depth_simd_dimensions=*/128,
|
|
2655
2663
|
/*dimensions_per_value=*/8)
|
|
2656
|
-
nk_define_cross_pack_(dots, u1, serial, u1x8, u1x8,
|
|
2657
|
-
|
|
2664
|
+
nk_define_cross_pack_(dots, u1, serial, u1x8, u1x8, nk_b128_vec_t, nk_load_b128_serial_, nk_partial_load_b8x16_serial_,
|
|
2665
|
+
nk_store_b128_serial_, nk_partial_store_b8x16_serial_, /*simd_width=*/16,
|
|
2666
|
+
/*norm_value_type=*/u32, nk_dots_reduce_sum_u1_, /*depth_simd_dimensions=*/128,
|
|
2667
|
+
/*dimensions_per_value=*/8)
|
|
2658
2668
|
nk_define_cross_symmetric_(dots, u1, serial, u1x8, u32, nk_b128_vec_t, nk_dot_u1x128_state_serial_t, nk_b128_vec_t,
|
|
2659
2669
|
nk_dot_u1x128_init_serial, nk_load_b128_serial_, nk_partial_load_b1x128_serial_,
|
|
2660
2670
|
nk_dot_u1x128_update_serial, nk_dot_u1x128_finalize_serial, nk_store_b128_serial_,
|
|
@@ -2673,7 +2683,7 @@ nk_define_cross_packed_(dots, u1, serial, u1x8, u1x8, u32, nk_b128_vec_t, nk_dot
|
|
|
2673
2683
|
#endif
|
|
2674
2684
|
|
|
2675
2685
|
/* BF16 compact: truncate F32 → BF16 in-place.
|
|
2676
|
-
* Reads F32 matrix with c_stride_in_bytes, writes BF16 tightly packed (
|
|
2686
|
+
* Reads F32 matrix with c_stride_in_bytes, writes BF16 tightly packed (stride_in_bytes = column_count × sizeof(bf16)).
|
|
2677
2687
|
*/
|
|
2678
2688
|
NK_PUBLIC void nk_dots_compact_bf16_serial(void *c, nk_size_t row_count, nk_size_t column_count,
|
|
2679
2689
|
nk_size_t c_stride_in_bytes) {
|
|
@@ -2767,78 +2777,84 @@ NK_PUBLIC void nk_dots_compact_i8_serial(void *c, nk_size_t row_count, nk_size_t
|
|
|
2767
2777
|
} \
|
|
2768
2778
|
}
|
|
2769
2779
|
|
|
2770
|
-
#define nk_define_cross_normalized_symmetric_(metric_name, input_type_name, isa_suffix, input_value_type,
|
|
2771
|
-
dot_result_type, norm_value_type, final_result_type, vec_type,
|
|
2772
|
-
dots_symmetric_fn, from_dot_fn, compute_norm_fn, load_fn,
|
|
2773
|
-
partial_load_fn, store_fn, partial_store_fn, dimensions_per_value)
|
|
2774
|
-
NK_PUBLIC void nk_##metric_name##s_symmetric_##input_type_name##_##isa_suffix(
|
|
2775
|
-
nk_##input_value_type##_t const *vectors, nk_size_t
|
|
2776
|
-
nk_##final_result_type##_t *result, nk_size_t
|
|
2777
|
-
|
|
2778
|
-
|
|
2779
|
-
|
|
2780
|
-
|
|
2781
|
-
|
|
2782
|
-
|
|
2783
|
-
|
|
2784
|
-
|
|
2785
|
-
|
|
2786
|
-
|
|
2787
|
-
|
|
2788
|
-
|
|
2789
|
-
|
|
2790
|
-
|
|
2791
|
-
|
|
2792
|
-
|
|
2793
|
-
|
|
2794
|
-
|
|
2795
|
-
|
|
2796
|
-
|
|
2797
|
-
|
|
2798
|
-
|
|
2799
|
-
|
|
2800
|
-
|
|
2801
|
-
|
|
2802
|
-
|
|
2803
|
-
|
|
2804
|
-
|
|
2805
|
-
|
|
2806
|
-
|
|
2807
|
-
|
|
2808
|
-
|
|
2809
|
-
nk_##
|
|
2810
|
-
|
|
2811
|
-
|
|
2812
|
-
|
|
2813
|
-
|
|
2814
|
-
|
|
2815
|
-
|
|
2816
|
-
vec_type
|
|
2817
|
-
load_fn(
|
|
2818
|
-
|
|
2819
|
-
|
|
2820
|
-
|
|
2821
|
-
|
|
2822
|
-
|
|
2823
|
-
|
|
2824
|
-
|
|
2825
|
-
|
|
2826
|
-
|
|
2827
|
-
|
|
2828
|
-
|
|
2829
|
-
|
|
2830
|
-
|
|
2831
|
-
|
|
2832
|
-
|
|
2833
|
-
|
|
2834
|
-
|
|
2835
|
-
|
|
2836
|
-
r_out
|
|
2837
|
-
|
|
2780
|
+
#define nk_define_cross_normalized_symmetric_(metric_name, input_type_name, isa_suffix, input_value_type, \
|
|
2781
|
+
dot_result_type, norm_value_type, final_result_type, vec_type, \
|
|
2782
|
+
dots_symmetric_fn, from_dot_fn, compute_norm_fn, load_fn, \
|
|
2783
|
+
partial_load_fn, store_fn, partial_store_fn, dimensions_per_value) \
|
|
2784
|
+
NK_PUBLIC void nk_##metric_name##s_symmetric_##input_type_name##_##isa_suffix( \
|
|
2785
|
+
nk_##input_value_type##_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, \
|
|
2786
|
+
nk_##final_result_type##_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, \
|
|
2787
|
+
nk_size_t row_count) { \
|
|
2788
|
+
\
|
|
2789
|
+
dots_symmetric_fn(vectors, vectors_count, depth, stride_in_bytes, (nk_##dot_result_type##_t *)result, \
|
|
2790
|
+
result_stride_in_bytes, row_start, row_count); \
|
|
2791
|
+
\
|
|
2792
|
+
/* Phase 1 — cache row norms in the result diagonal (O(row_count) calls) */ \
|
|
2793
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
|
|
2794
|
+
nk_##input_value_type##_t const *row_vector = \
|
|
2795
|
+
(nk_##input_value_type##_t const *)((char const *)vectors + row_index * stride_in_bytes); \
|
|
2796
|
+
nk_##norm_value_type##_t *row_diag = (nk_##norm_value_type##_t *)((char *)result + \
|
|
2797
|
+
row_index * result_stride_in_bytes); \
|
|
2798
|
+
row_diag[row_index] = compute_norm_fn(row_vector, depth); \
|
|
2799
|
+
} \
|
|
2800
|
+
\
|
|
2801
|
+
/* Phase 2 — column-first post-processing with 256-element norm cache */ \
|
|
2802
|
+
nk_##norm_value_type##_t column_norms[256]; \
|
|
2803
|
+
for (nk_size_t column_chunk_start = 0; column_chunk_start < vectors_count; column_chunk_start += 256) { \
|
|
2804
|
+
nk_size_t column_chunk_end = column_chunk_start + 256 < vectors_count ? column_chunk_start + 256 \
|
|
2805
|
+
: vectors_count; \
|
|
2806
|
+
\
|
|
2807
|
+
/* Pre-compute norms for this column chunk — each column visited exactly once */ \
|
|
2808
|
+
for (nk_size_t col = column_chunk_start; col < column_chunk_end; ++col) { \
|
|
2809
|
+
nk_##input_value_type##_t const *column_vector = \
|
|
2810
|
+
(nk_##input_value_type##_t const *)((char const *)vectors + col * stride_in_bytes); \
|
|
2811
|
+
column_norms[col - column_chunk_start] = compute_norm_fn(column_vector, depth); \
|
|
2812
|
+
} \
|
|
2813
|
+
\
|
|
2814
|
+
/* Sweep assigned rows against this column chunk */ \
|
|
2815
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
|
|
2816
|
+
nk_size_t j_start = row_index + 1 > column_chunk_start ? row_index + 1 : column_chunk_start; \
|
|
2817
|
+
if (j_start >= column_chunk_end) continue; \
|
|
2818
|
+
char *row_ptr = (char *)result + row_index * result_stride_in_bytes; \
|
|
2819
|
+
nk_##norm_value_type##_t sumsq_i = ((nk_##norm_value_type##_t *)row_ptr)[row_index]; \
|
|
2820
|
+
nk_##dot_result_type##_t *r_dots = (nk_##dot_result_type##_t *)row_ptr; \
|
|
2821
|
+
nk_##final_result_type##_t *r_out = (nk_##final_result_type##_t *)row_ptr; \
|
|
2822
|
+
\
|
|
2823
|
+
/* 4-wide vectorized loop */ \
|
|
2824
|
+
nk_size_t j = j_start; \
|
|
2825
|
+
for (; j + 4 <= column_chunk_end; j += 4) { \
|
|
2826
|
+
vec_type target_norms_vec; \
|
|
2827
|
+
load_fn(&column_norms[j - column_chunk_start], &target_norms_vec); \
|
|
2828
|
+
vec_type dots_vec, results_vec; \
|
|
2829
|
+
load_fn(r_dots + j, &dots_vec); \
|
|
2830
|
+
from_dot_fn(dots_vec, sumsq_i, target_norms_vec, &results_vec); \
|
|
2831
|
+
store_fn(&results_vec, r_out + j); \
|
|
2832
|
+
} \
|
|
2833
|
+
/* Remainder */ \
|
|
2834
|
+
if (j < column_chunk_end) { \
|
|
2835
|
+
vec_type dots_vec = {0}, norms_vec = {0}, results_vec; \
|
|
2836
|
+
partial_load_fn(r_dots + j, &dots_vec, column_chunk_end - j); \
|
|
2837
|
+
partial_load_fn(&column_norms[j - column_chunk_start], &norms_vec, column_chunk_end - j); \
|
|
2838
|
+
from_dot_fn(dots_vec, sumsq_i, norms_vec, &results_vec); \
|
|
2839
|
+
partial_store_fn(&results_vec, r_out + j, column_chunk_end - j); \
|
|
2840
|
+
} \
|
|
2841
|
+
} \
|
|
2842
|
+
} \
|
|
2843
|
+
\
|
|
2844
|
+
/* Phase 3 — zero diagonals */ \
|
|
2845
|
+
for (nk_size_t row_index = row_start; row_index < row_start + row_count; ++row_index) { \
|
|
2846
|
+
nk_##final_result_type##_t *r_out = (nk_##final_result_type##_t *)((char *)result + \
|
|
2847
|
+
row_index * result_stride_in_bytes); \
|
|
2848
|
+
r_out[row_index] = 0; \
|
|
2849
|
+
} \
|
|
2838
2850
|
}
|
|
2839
2851
|
|
|
2840
2852
|
#if defined(__cplusplus)
|
|
2841
2853
|
} // extern "C"
|
|
2842
2854
|
#endif
|
|
2843
2855
|
|
|
2856
|
+
#if defined(__GNUC__) && !defined(__clang__)
|
|
2857
|
+
#pragma GCC diagnostic pop
|
|
2858
|
+
#endif
|
|
2859
|
+
|
|
2844
2860
|
#endif // NK_DOTS_SERIAL_H
|