numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief AVX-512 BF16 implementations for the redesigned reduction API (moments).
|
|
3
|
+
* @file include/numkong/reduce/genoa.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 12, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/reduce.h
|
|
8
|
+
*
|
|
9
|
+
* @section bf16_moments BF16 Moments
|
|
10
|
+
*
|
|
11
|
+
* `_mm512_dpbf16_ps(acc, a, b)` (VDPBF16PS) computes paired bf16→f32 dot products:
|
|
12
|
+
* `acc[i] += bf16_to_f32(a[2i]) * bf16_to_f32(b[2i]) + bf16_to_f32(a[2i+1]) * bf16_to_f32(b[2i+1])`
|
|
13
|
+
* Processing 32 bf16 values into 16 f32 accumulators per instruction.
|
|
14
|
+
*
|
|
15
|
+
* For sum: use ones vector (bf16 1.0 = 0x3F80).
|
|
16
|
+
* For sumsq: dot product of data with itself.
|
|
17
|
+
*/
|
|
18
|
+
#ifndef NK_REDUCE_GENOA_H
|
|
19
|
+
#define NK_REDUCE_GENOA_H
|
|
20
|
+
|
|
21
|
+
#if NK_TARGET_X86_
|
|
22
|
+
#if NK_TARGET_GENOA
|
|
23
|
+
|
|
24
|
+
#include "numkong/reduce/serial.h"
|
|
25
|
+
#include "numkong/cast/icelake.h" // `nk_e4m3x32_to_bf16x32_icelake_` etc.
|
|
26
|
+
|
|
27
|
+
#if defined(__cplusplus)
|
|
28
|
+
extern "C" {
|
|
29
|
+
#endif
|
|
30
|
+
|
|
31
|
+
#if defined(__clang__)
|
|
32
|
+
#pragma clang attribute push( \
|
|
33
|
+
__attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512bf16,f16c,fma,bmi,bmi2"))), \
|
|
34
|
+
apply_to = function)
|
|
35
|
+
#elif defined(__GNUC__)
|
|
36
|
+
#pragma GCC push_options
|
|
37
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512bf16", "f16c", "fma", "bmi", "bmi2")
|
|
38
|
+
#endif
|
|
39
|
+
|
|
40
|
+
NK_INTERNAL void nk_reduce_moments_bf16_genoa_contiguous_( //
|
|
41
|
+
nk_bf16_t const *data_ptr, nk_size_t count, //
|
|
42
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
43
|
+
|
|
44
|
+
// bf16(1.0) = 0x3F80. Pack 32 of them as __m512bh.
|
|
45
|
+
__m512bh ones_bf16x32 = nk_m512bh_from_m512i_(_mm512_set1_epi16(0x3F80));
|
|
46
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
47
|
+
__m512 sumsq_f32x16 = _mm512_setzero_ps();
|
|
48
|
+
nk_size_t idx = 0;
|
|
49
|
+
|
|
50
|
+
for (; idx + 32 <= count; idx += 32) {
|
|
51
|
+
__m512bh data_bf16x32 = nk_m512bh_from_m512i_(_mm512_loadu_si512(data_ptr + idx));
|
|
52
|
+
sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, data_bf16x32, ones_bf16x32);
|
|
53
|
+
sumsq_f32x16 = _mm512_dpbf16_ps(sumsq_f32x16, data_bf16x32, data_bf16x32);
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
// Tail: masked load for remaining elements (< 32)
|
|
57
|
+
nk_size_t remaining = count - idx;
|
|
58
|
+
if (remaining > 0) {
|
|
59
|
+
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
|
|
60
|
+
__m512bh data_bf16x32 = nk_m512bh_from_m512i_(_mm512_maskz_loadu_epi16(tail_mask, data_ptr + idx));
|
|
61
|
+
sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, data_bf16x32, ones_bf16x32);
|
|
62
|
+
sumsq_f32x16 = _mm512_dpbf16_ps(sumsq_f32x16, data_bf16x32, data_bf16x32);
|
|
63
|
+
}
|
|
64
|
+
|
|
65
|
+
*sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
66
|
+
*sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
NK_PUBLIC void nk_reduce_moments_bf16_genoa( //
|
|
70
|
+
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
71
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
72
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
|
|
73
|
+
int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
|
|
74
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
75
|
+
else if (!aligned) nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
76
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
77
|
+
nk_size_t left_count = count / 2;
|
|
78
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
79
|
+
nk_reduce_moments_bf16_genoa(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
80
|
+
nk_reduce_moments_bf16_genoa(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
81
|
+
&right_sum, &right_sumsq);
|
|
82
|
+
*sum_ptr = left_sum + right_sum;
|
|
83
|
+
*sumsq_ptr = left_sumsq + right_sumsq;
|
|
84
|
+
}
|
|
85
|
+
else if (stride_elements == 1) nk_reduce_moments_bf16_genoa_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
86
|
+
else nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
NK_INTERNAL void nk_reduce_moments_e4m3_genoa_contiguous_( //
|
|
90
|
+
nk_e4m3_t const *data_ptr, nk_size_t count, //
|
|
91
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
92
|
+
|
|
93
|
+
__m512bh ones_bf16x32 = nk_m512bh_from_m512i_(_mm512_set1_epi16(0x3F80)); // bf16(1.0)
|
|
94
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
95
|
+
__m512 sumsq_f32x16 = _mm512_setzero_ps();
|
|
96
|
+
nk_size_t idx = 0;
|
|
97
|
+
|
|
98
|
+
for (; idx + 32 <= count; idx += 32) {
|
|
99
|
+
__m256i raw_u8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
|
|
100
|
+
__m512bh data_bf16x32 = nk_m512bh_from_m512i_(nk_e4m3x32_to_bf16x32_icelake_(raw_u8x32));
|
|
101
|
+
sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, data_bf16x32, ones_bf16x32);
|
|
102
|
+
sumsq_f32x16 = _mm512_dpbf16_ps(sumsq_f32x16, data_bf16x32, data_bf16x32);
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
// Tail: masked load for remaining elements (< 32)
|
|
106
|
+
nk_size_t remaining = count - idx;
|
|
107
|
+
if (remaining > 0) {
|
|
108
|
+
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
|
|
109
|
+
__m256i raw_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, data_ptr + idx);
|
|
110
|
+
__m512bh data_bf16x32 = nk_m512bh_from_m512i_(nk_e4m3x32_to_bf16x32_icelake_(raw_u8x32));
|
|
111
|
+
sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, data_bf16x32, ones_bf16x32);
|
|
112
|
+
sumsq_f32x16 = _mm512_dpbf16_ps(sumsq_f32x16, data_bf16x32, data_bf16x32);
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
*sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
116
|
+
*sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
NK_PUBLIC void nk_reduce_moments_e4m3_genoa( //
|
|
120
|
+
nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
121
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
122
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
|
|
123
|
+
int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
|
|
124
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
125
|
+
else if (!aligned) nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
126
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
127
|
+
nk_size_t left_count = count / 2;
|
|
128
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
129
|
+
nk_reduce_moments_e4m3_genoa(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
130
|
+
nk_reduce_moments_e4m3_genoa(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
131
|
+
&right_sum, &right_sumsq);
|
|
132
|
+
*sum_ptr = left_sum + right_sum;
|
|
133
|
+
*sumsq_ptr = left_sumsq + right_sumsq;
|
|
134
|
+
}
|
|
135
|
+
else if (stride_elements == 1) nk_reduce_moments_e4m3_genoa_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
136
|
+
else nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
NK_INTERNAL void nk_reduce_moments_e5m2_genoa_contiguous_( //
|
|
140
|
+
nk_e5m2_t const *data_ptr, nk_size_t count, //
|
|
141
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
142
|
+
|
|
143
|
+
__m512bh ones_bf16x32 = nk_m512bh_from_m512i_(_mm512_set1_epi16(0x3F80)); // bf16(1.0)
|
|
144
|
+
__m512 sum_f32x16 = _mm512_setzero_ps();
|
|
145
|
+
__m512 sumsq_f32x16 = _mm512_setzero_ps();
|
|
146
|
+
nk_size_t idx = 0;
|
|
147
|
+
|
|
148
|
+
for (; idx + 32 <= count; idx += 32) {
|
|
149
|
+
__m256i raw_u8x32 = _mm256_loadu_si256((__m256i const *)(data_ptr + idx));
|
|
150
|
+
__m512bh data_bf16x32 = nk_m512bh_from_m512i_(nk_e5m2x32_to_bf16x32_icelake_(raw_u8x32));
|
|
151
|
+
sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, data_bf16x32, ones_bf16x32);
|
|
152
|
+
sumsq_f32x16 = _mm512_dpbf16_ps(sumsq_f32x16, data_bf16x32, data_bf16x32);
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
// Tail: masked load for remaining elements (< 32)
|
|
156
|
+
nk_size_t remaining = count - idx;
|
|
157
|
+
if (remaining > 0) {
|
|
158
|
+
__mmask32 tail_mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)remaining);
|
|
159
|
+
__m256i raw_u8x32 = _mm256_maskz_loadu_epi8(tail_mask, data_ptr + idx);
|
|
160
|
+
__m512bh data_bf16x32 = nk_m512bh_from_m512i_(nk_e5m2x32_to_bf16x32_icelake_(raw_u8x32));
|
|
161
|
+
sum_f32x16 = _mm512_dpbf16_ps(sum_f32x16, data_bf16x32, ones_bf16x32);
|
|
162
|
+
sumsq_f32x16 = _mm512_dpbf16_ps(sumsq_f32x16, data_bf16x32, data_bf16x32);
|
|
163
|
+
}
|
|
164
|
+
|
|
165
|
+
*sum_ptr = nk_reduce_add_f32x16_skylake_(sum_f32x16);
|
|
166
|
+
*sumsq_ptr = nk_reduce_add_f32x16_skylake_(sumsq_f32x16);
|
|
167
|
+
}
|
|
168
|
+
|
|
169
|
+
NK_PUBLIC void nk_reduce_moments_e5m2_genoa( //
|
|
170
|
+
nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
171
|
+
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
172
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
|
|
173
|
+
int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
|
|
174
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
175
|
+
else if (!aligned) nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
176
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 32) {
|
|
177
|
+
nk_size_t left_count = count / 2;
|
|
178
|
+
nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
|
|
179
|
+
nk_reduce_moments_e5m2_genoa(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
|
|
180
|
+
nk_reduce_moments_e5m2_genoa(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
181
|
+
&right_sum, &right_sumsq);
|
|
182
|
+
*sum_ptr = left_sum + right_sum;
|
|
183
|
+
*sumsq_ptr = left_sumsq + right_sumsq;
|
|
184
|
+
}
|
|
185
|
+
else if (stride_elements == 1) nk_reduce_moments_e5m2_genoa_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
186
|
+
else nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
#if defined(__clang__)
|
|
190
|
+
#pragma clang attribute pop
|
|
191
|
+
#elif defined(__GNUC__)
|
|
192
|
+
#pragma GCC pop_options
|
|
193
|
+
#endif
|
|
194
|
+
|
|
195
|
+
#if defined(__cplusplus)
|
|
196
|
+
} // extern "C"
|
|
197
|
+
#endif
|
|
198
|
+
|
|
199
|
+
#endif // NK_TARGET_GENOA
|
|
200
|
+
#endif // NK_TARGET_X86_
|
|
201
|
+
#endif // NK_REDUCE_GENOA_H
|