numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief ARMv8.6-BF16 implementations for the redesigned reduction API (moments + minmax).
|
|
3
|
+
* @file include/numkong/reduce/neonbfdot.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date February 13, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/reduce.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_REDUCE_NEONBFDOT_H
|
|
10
|
+
#define NK_REDUCE_NEONBFDOT_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_ARM_
|
|
13
|
+
#if NK_TARGET_NEONBFDOT
|
|
14
|
+
|
|
15
|
+
#include "numkong/types.h" // `nk_bf16_t`
|
|
16
|
+
#include "numkong/cast/neon.h" // `nk_e4m3x8_to_f16x8_neon_`
|
|
17
|
+
#include "numkong/cast/serial.h" // `nk_f32_to_bf16_serial`
|
|
18
|
+
#include "numkong/reduce/serial.h" // `nk_reduce_moments_bf16_serial`
|
|
19
|
+
|
|
20
|
+
#if defined(__cplusplus)
|
|
21
|
+
extern "C" {
|
|
22
|
+
#endif
|
|
23
|
+
|
|
24
|
+
#if defined(__clang__)
|
|
25
|
+
#pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function)
|
|
26
|
+
#elif defined(__GNUC__)
|
|
27
|
+
#pragma GCC push_options
|
|
28
|
+
#pragma GCC target("arch=armv8.6-a+simd+bf16")
|
|
29
|
+
#endif
|
|
30
|
+
|
|
31
|
+
/**
 * @brief Sum and sum of squares of a contiguous bf16 array, accumulated in f32 via BFDOT.
 *
 * `vbfdotq_f32(acc, a, b)` adds `a[2i] * b[2i] + a[2i + 1] * b[2i + 1]` into f32 lane `i`. Dotting each block with a
 * vector of bf16 ones yields the running sum; dotting it with itself yields the running sum of squares.
 *
 * @param data_ptr  First element of the array.
 * @param count     Number of elements.
 * @param sum_ptr   Output: sum of all elements.
 * @param sumsq_ptr Output: sum of squares of all elements.
 */
NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_contiguous_( //
    nk_bf16_t const *data_ptr, nk_size_t count,                //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    // 0x3F80 is 1.0 in bf16: the upper half of the f32 bit pattern 0x3F800000.
    bfloat16x8_t const unit_bf16x8 = vreinterpretq_bf16_u16(vdupq_n_u16(0x3F80));
    float32x4_t linear_f32x4 = vdupq_n_f32(0);
    float32x4_t quadratic_f32x4 = vdupq_n_f32(0);
    nk_bf16_t const *cursor = data_ptr;
    nk_size_t left = count;

    while (left >= 8) {
        bfloat16x8_t block_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)cursor);
        linear_f32x4 = vbfdotq_f32(linear_f32x4, block_bf16x8, unit_bf16x8);
        quadratic_f32x4 = vbfdotq_f32(quadratic_f32x4, block_bf16x8, block_bf16x8);
        cursor += 8;
        left -= 8;
    }

    if (left != 0) {
        // NOTE(review): relies on `nk_partial_load_b16x8_serial_` zero-filling lanes [left, 8), so that they add
        // nothing to either accumulator — confirm against the helper.
        nk_b128_vec_t tail_vec;
        nk_partial_load_b16x8_serial_(cursor, &tail_vec, left);
        bfloat16x8_t block_bf16x8 = vreinterpretq_bf16_u16(tail_vec.u16x8);
        linear_f32x4 = vbfdotq_f32(linear_f32x4, block_bf16x8, unit_bf16x8);
        quadratic_f32x4 = vbfdotq_f32(quadratic_f32x4, block_bf16x8, block_bf16x8);
    }

    *sum_ptr = vaddvq_f32(linear_f32x4);
    *sumsq_ptr = vaddvq_f32(quadratic_f32x4);
}
|
|
59
|
+
|
|
60
|
+
/**
 * @brief Sum and sum of squares of a strided bf16 sequence, accumulated in f32 via BFDOT.
 *
 * Strides of 2, 3 and 4 elements use de-interleaving `vld{2,3,4}q_u16` loads and keep only the first
 * de-interleaved register. Whatever those loops leave behind is gathered element by element into
 * zero-initialized 8-lane chunks. The same applies to the whole input for any other stride, including 0.
 * Zero lanes contribute nothing to either sum, and the 8-lane scratch vector is never overrun.
 *
 * @param data_ptr        First logical element.
 * @param count           Number of logical elements.
 * @param stride_elements Distance between consecutive logical elements, in `nk_bf16_t` units.
 * @param sum_ptr         Output: sum of all elements.
 * @param sumsq_ptr       Output: sum of squares of all elements.
 */
NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_strided_(                   //
    nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    // bf16 representation of 1.0 is 0x3F80 (same as upper 16 bits of f32 1.0)
    bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(vdupq_n_u16(0x3F80));
    float32x4_t sum_f32x4 = vdupq_n_f32(0);
    float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
    nk_size_t idx = 0;

    // A `vld{N}q_u16` at logical index `idx` touches elements up to `(idx + 8) * N - 1`. That is `N - 1` elements
    // past the last logical element it covers, `(idx + 7) * N`. Running only while `idx + 8 < count` keeps every
    // load inside the minimal `(count - 1) * N + 1`-element footprint of the strided view. The final block is
    // always handled by the gather loop below.
    if (stride_elements == 2) {
        for (; idx + 8 < count; idx += 8) {
            uint16x8x2_t loaded_u16x8x2 = vld2q_u16((uint16_t const *)(data_ptr + idx * 2));
            bfloat16x8_t data_bf16x8 = vreinterpretq_bf16_u16(loaded_u16x8x2.val[0]);
            sum_f32x4 = vbfdotq_f32(sum_f32x4, data_bf16x8, ones_bf16x8);
            sumsq_f32x4 = vbfdotq_f32(sumsq_f32x4, data_bf16x8, data_bf16x8);
        }
    }
    else if (stride_elements == 3) {
        for (; idx + 8 < count; idx += 8) {
            uint16x8x3_t loaded_u16x8x3 = vld3q_u16((uint16_t const *)(data_ptr + idx * 3));
            bfloat16x8_t data_bf16x8 = vreinterpretq_bf16_u16(loaded_u16x8x3.val[0]);
            sum_f32x4 = vbfdotq_f32(sum_f32x4, data_bf16x8, ones_bf16x8);
            sumsq_f32x4 = vbfdotq_f32(sumsq_f32x4, data_bf16x8, data_bf16x8);
        }
    }
    else if (stride_elements == 4) {
        for (; idx + 8 < count; idx += 8) {
            uint16x8x4_t loaded_u16x8x4 = vld4q_u16((uint16_t const *)(data_ptr + idx * 4));
            bfloat16x8_t data_bf16x8 = vreinterpretq_bf16_u16(loaded_u16x8x4.val[0]);
            sum_f32x4 = vbfdotq_f32(sum_f32x4, data_bf16x8, ones_bf16x8);
            sumsq_f32x4 = vbfdotq_f32(sumsq_f32x4, data_bf16x8, data_bf16x8);
        }
    }

    // Gather the rest in chunks of at most 8 elements so `tail_vec.u16s` is never indexed past lane 7,
    // whatever the stride (the vector loops above leave at most 8 for strides 2-4, but other strides leave all).
    while (idx < count) {
        nk_b128_vec_t tail_vec = {0};
        nk_size_t chunk = count - idx < 8 ? count - idx : 8;
        for (nk_size_t k = 0; k < chunk; ++k)
            tail_vec.u16s[k] = *(nk_u16_t const *)(data_ptr + (idx + k) * stride_elements);
        bfloat16x8_t data_bf16x8 = vreinterpretq_bf16_u16(tail_vec.u16x8);
        sum_f32x4 = vbfdotq_f32(sum_f32x4, data_bf16x8, ones_bf16x8);
        sumsq_f32x4 = vbfdotq_f32(sumsq_f32x4, data_bf16x8, data_bf16x8);
        idx += chunk;
    }

    *sum_ptr = vaddvq_f32(sum_f32x4);
    *sumsq_ptr = vaddvq_f32(sumsq_f32x4);
}
|
|
108
|
+
|
|
109
|
+
/**
 * @brief Sum and sum of squares of `count` bf16 values spaced `stride_bytes` apart, using ARMv8.6 BF16 dot products.
 *
 * Dispatch:
 * - `count == 0` → both outputs are 0.
 * - stride not a multiple of `sizeof(nk_bf16_t)` → serial fallback.
 * - more than `65536 * 8` elements → split into two halves recursively and add the partial results, which forms a
 *   pairwise summation tree over large inputs.
 * - unit stride → contiguous BFDOT kernel.
 * - stride of 2, 3 or 4 elements → de-interleaving BFDOT kernel.
 * - any other stride, including 0 (every element aliasing the first) → serial fallback. Only strides with a
 *   dedicated vector path are routed to the strided kernel.
 *
 * @param data_ptr     First element.
 * @param count        Number of logical elements.
 * @param stride_bytes Distance between consecutive elements in bytes.
 * @param sum_ptr      Output: sum of all elements.
 * @param sumsq_ptr    Output: sum of squares of all elements.
 */
NK_PUBLIC void nk_reduce_moments_bf16_neonbfdot(                            //
    nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
    nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
    int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
    if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
    else if (!aligned) nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
    else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
        nk_size_t left_count = count / 2;
        nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
        nk_reduce_moments_bf16_neonbfdot(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
        nk_reduce_moments_bf16_neonbfdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
                                         &right_sum_value, &right_sumsq_value);
        *sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
    }
    else if (stride_elements == 1) nk_reduce_moments_bf16_neonbfdot_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
    else if (stride_elements >= 2 && stride_elements <= 4)
        nk_reduce_moments_bf16_neonbfdot_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
    else nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
}
|
|
129
|
+
|
|
130
|
+
/** @brief Map 8 raw bf16 bit patterns (sign-magnitude) to i16 keys whose signed order matches the float order.
 *
 *  Non-negative inputs (sign bit clear) pass through unchanged, because as i16 they already sort by magnitude.
 *  Negative inputs keep their sign bit and have the 15 magnitude bits inverted (XOR 0x7FFF), so a larger magnitude
 *  becomes a smaller i16. The resulting order is -NaN < -inf < ... < -0 < +0 < ... < +inf < +NaN. That is, -0.0
 *  (key -1) sorts just below +0.0 (key 0), and NaNs are not treated specially.
 *
 *  The XOR mask is derived without a select. An arithmetic shift right by 15 smears the sign bit into 0xFFFF or
 *  0x0000, and a logical shift right by 1 turns that into 0x7FFF or 0x0000. */
NK_INTERNAL int16x8_t nk_bf16x8_to_comparable_i16x8_neon_(uint16x8_t raw_u16x8) {
    int16x8_t bits_i16x8 = vreinterpretq_s16_u16(raw_u16x8);
    uint16x8_t sign_smear_u16x8 = vreinterpretq_u16_s16(vshrq_n_s16(bits_i16x8, 15));
    int16x8_t magnitude_mask_i16x8 = vreinterpretq_s16_u16(vshrq_n_u16(sign_smear_u16x8, 1));
    return veorq_s16(bits_i16x8, magnitude_mask_i16x8);
}
|
|
140
|
+
|
|
141
|
+
/** @brief Convert a comparable i16 value back to raw bf16 u16 bits.
|
|
142
|
+
* Reverses the transformation from nk_bf16x8_to_comparable_i16x8_neon_. */
|
|
143
|
+
NK_INTERNAL nk_u16_t nk_comparable_i16_to_bf16_raw_(nk_i16_t comparable) {
|
|
144
|
+
if (comparable < 0) return (nk_u16_t)(comparable ^ 0x7FFF);
|
|
145
|
+
return (nk_u16_t)comparable;
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
/**
 * @brief Minimum and maximum, with first-occurrence indices, over a contiguous bf16 array.
 *
 * Values are mapped to order-preserving i16 keys (`nk_bf16x8_to_comparable_i16x8_neon_`), so all comparisons are
 * plain signed-integer compares. NaNs are not skipped under this mapping: negative NaNs order below -inf and
 * positive NaNs above +inf.
 *
 * Each of the 8 lanes tracks its own running extremum and the 8-element block ("cycle") in which it was found.
 * Comparisons are strict, so each lane keeps its earliest occurrence. The final reduction takes the extremal key,
 * then the earliest cycle holding it, then the lowest lane in that cycle. This gives the smallest index
 * `cycle * 8 + lane`. The cycle counter is 16-bit, so `count` must not exceed `65536 * 8`.
 *
 * @param data_ptr      First element.
 * @param count         Number of elements.
 * @param min_value_ptr Output: smallest value (raw bf16 bits).
 * @param min_index_ptr Output: index of its first occurrence.
 * @param max_value_ptr Output: largest value (raw bf16 bits).
 * @param max_index_ptr Output: index of its first occurrence.
 */
NK_INTERNAL void nk_reduce_minmax_bf16_neonbfdot_contiguous_( //
    nk_bf16_t const *data_ptr, nk_size_t count,               //
    nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr,       //
    nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    uint16x8_t const lanes_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
                                                vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
    uint16x8_t const step_u16x8 = vdupq_n_u16(1);
    uint16x8_t const all_ones_u16x8 = vdupq_n_u16(0xFFFF);
    int16x8_t lowest_i16x8 = vdupq_n_s16(NK_I16_MAX);
    int16x8_t highest_i16x8 = vdupq_n_s16(NK_I16_MIN);
    uint16x8_t lowest_cycle_u16x8 = vdupq_n_u16(0);
    uint16x8_t highest_cycle_u16x8 = vdupq_n_u16(0);
    uint16x8_t cycle_u16x8 = vdupq_n_u16(0);
    nk_size_t const full_count = count & ~(nk_size_t)7; // largest multiple of 8 not above `count`

    for (nk_size_t offset = 0; offset < full_count; offset += 8) {
        int16x8_t keys_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(vld1q_u16((uint16_t const *)(data_ptr + offset)));
        uint16x8_t is_lower_u16x8 = vcltq_s16(keys_i16x8, lowest_i16x8);
        uint16x8_t is_higher_u16x8 = vcgtq_s16(keys_i16x8, highest_i16x8);
        lowest_i16x8 = vbslq_s16(is_lower_u16x8, keys_i16x8, lowest_i16x8);
        highest_i16x8 = vbslq_s16(is_higher_u16x8, keys_i16x8, highest_i16x8);
        lowest_cycle_u16x8 = vbslq_u16(is_lower_u16x8, cycle_u16x8, lowest_cycle_u16x8);
        highest_cycle_u16x8 = vbslq_u16(is_higher_u16x8, cycle_u16x8, highest_cycle_u16x8);
        cycle_u16x8 = vaddq_u16(cycle_u16x8, step_u16x8);
    }

    nk_size_t const tail_count = count - full_count;
    if (tail_count != 0) {
        nk_b128_vec_t tail_vec;
        nk_partial_load_b16x8_serial_(data_ptr + full_count, &tail_vec, tail_count);
        int16x8_t keys_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(tail_vec.u16x8);
        // Lanes past the end receive each comparison's identity, so under a strict compare they can never win.
        uint16x8_t in_range_u16x8 = vcltq_u16(lanes_u16x8, vdupq_n_u16((uint16_t)tail_count));
        int16x8_t min_keys_i16x8 = vbslq_s16(in_range_u16x8, keys_i16x8, vdupq_n_s16(NK_I16_MAX));
        int16x8_t max_keys_i16x8 = vbslq_s16(in_range_u16x8, keys_i16x8, vdupq_n_s16(NK_I16_MIN));
        uint16x8_t is_lower_u16x8 = vcltq_s16(min_keys_i16x8, lowest_i16x8);
        uint16x8_t is_higher_u16x8 = vcgtq_s16(max_keys_i16x8, highest_i16x8);
        lowest_i16x8 = vbslq_s16(is_lower_u16x8, min_keys_i16x8, lowest_i16x8);
        highest_i16x8 = vbslq_s16(is_higher_u16x8, max_keys_i16x8, highest_i16x8);
        lowest_cycle_u16x8 = vbslq_u16(is_lower_u16x8, cycle_u16x8, lowest_cycle_u16x8);
        highest_cycle_u16x8 = vbslq_u16(is_higher_u16x8, cycle_u16x8, highest_cycle_u16x8);
    }

    nk_i16_t const lowest_key = vminvq_s16(lowest_i16x8);
    nk_i16_t const highest_key = vmaxvq_s16(highest_i16x8);

    // Both accumulators still at their sentinels: no element moved either one. Any single element moves at least
    // one of them, so this only happens when nothing was processed. Report the sentinels with "no index".
    if (lowest_key == NK_I16_MAX && highest_key == NK_I16_MIN) {
        *(nk_u16_t *)min_value_ptr = nk_comparable_i16_to_bf16_raw_(lowest_key), *min_index_ptr = NK_SIZE_MAX;
        *(nk_u16_t *)max_value_ptr = nk_comparable_i16_to_bf16_raw_(highest_key), *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Minimum: earliest cycle among the lanes holding the key, then the lowest such lane.
    uint16x8_t lowest_hit_u16x8 = vceqq_s16(lowest_i16x8, vdupq_n_s16(lowest_key));
    nk_u16_t const lowest_cycle = vminvq_u16(vbslq_u16(lowest_hit_u16x8, lowest_cycle_u16x8, all_ones_u16x8));
    uint16x8_t lowest_lane_hit_u16x8 =
        vandq_u16(lowest_hit_u16x8, vceqq_u16(lowest_cycle_u16x8, vdupq_n_u16(lowest_cycle)));
    nk_u16_t const lowest_lane = vminvq_u16(vbslq_u16(lowest_lane_hit_u16x8, lanes_u16x8, all_ones_u16x8));

    // Maximum: same tie-breaking.
    uint16x8_t highest_hit_u16x8 = vceqq_s16(highest_i16x8, vdupq_n_s16(highest_key));
    nk_u16_t const highest_cycle = vminvq_u16(vbslq_u16(highest_hit_u16x8, highest_cycle_u16x8, all_ones_u16x8));
    uint16x8_t highest_lane_hit_u16x8 =
        vandq_u16(highest_hit_u16x8, vceqq_u16(highest_cycle_u16x8, vdupq_n_u16(highest_cycle)));
    nk_u16_t const highest_lane = vminvq_u16(vbslq_u16(highest_lane_hit_u16x8, lanes_u16x8, all_ones_u16x8));

    *(nk_u16_t *)min_value_ptr = nk_comparable_i16_to_bf16_raw_(lowest_key);
    *min_index_ptr = (nk_size_t)lowest_cycle * 8 + (nk_size_t)lowest_lane;
    *(nk_u16_t *)max_value_ptr = nk_comparable_i16_to_bf16_raw_(highest_key);
    *max_index_ptr = (nk_size_t)highest_cycle * 8 + (nk_size_t)highest_lane;
}
|
|
217
|
+
|
|
218
|
+
/**
 * @brief Minimum and maximum, with first-occurrence indices, over a strided bf16 sequence.
 *
 * Uses the same order-preserving i16 key mapping and per-lane (value, cycle) tracking as the contiguous kernel.
 * Strides of 2, 3 and 4 elements use de-interleaving `vld{2,3,4}q_u16` loads. Any other stride, and the final
 * block for every stride, is gathered with `nk_strided_load_b16x8_serial_` in chunks of at most 8 elements,
 * with unused lanes masked to each comparison's identity.
 *
 * A `vld{N}q_u16` at logical index `idx` reads `N - 1` elements beyond the last logical element it covers.
 * The vector paths therefore run only while more than 8 elements remain, and the last block always goes
 * through the gather path. Assuming the gather helper reads only the requested elements, nothing past
 * `data_ptr[(count - 1) * stride_elements]` is touched.
 *
 * Every cycle covers 8 logical elements except possibly the last one, so an extremum found in lane `l` during
 * cycle `c` has index `c * 8 + l`. The cycle counter is 16-bit, so `count` must not exceed `65536 * 8`.
 *
 * @param data_ptr        First logical element.
 * @param count           Number of logical elements.
 * @param stride_elements Distance between consecutive logical elements, in `nk_bf16_t` units.
 * @param min_value_ptr   Output: smallest value (raw bf16 bits).
 * @param min_index_ptr   Output: logical index of its first occurrence.
 * @param max_value_ptr   Output: largest value (raw bf16 bits).
 * @param max_index_ptr   Output: logical index of its first occurrence.
 */
NK_INTERNAL void nk_reduce_minmax_bf16_neonbfdot_strided_(                     //
    nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr,                    //
    nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
    int16x8_t min_i16x8 = vdupq_n_s16(NK_I16_MAX), max_i16x8 = vdupq_n_s16(NK_I16_MIN);
    uint16x8_t min_iter_u16x8 = vdupq_n_u16(0), max_iter_u16x8 = vdupq_n_u16(0);
    uint16x8_t iter_u16x8 = vdupq_n_u16(0), one_u16x8 = vdupq_n_u16(1);
    uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
                                                 vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
    nk_size_t idx = 0;

    while (idx < count) {
        nk_size_t remaining = count - idx;
        int16x8_t data_for_min_i16x8, data_for_max_i16x8;

        // `remaining > 8` keeps the multi-register load from reading past the last logical element.
        if (stride_elements == 2 && remaining > 8) {
            uint16x8x2_t loaded = vld2q_u16((uint16_t const *)(data_ptr + idx * 2));
            data_for_min_i16x8 = data_for_max_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(loaded.val[0]);
            idx += 8;
        }
        else if (stride_elements == 3 && remaining > 8) {
            uint16x8x3_t loaded = vld3q_u16((uint16_t const *)(data_ptr + idx * 3));
            data_for_min_i16x8 = data_for_max_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(loaded.val[0]);
            idx += 8;
        }
        else if (stride_elements == 4 && remaining > 8) {
            uint16x8x4_t loaded = vld4q_u16((uint16_t const *)(data_ptr + idx * 4));
            data_for_min_i16x8 = data_for_max_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(loaded.val[0]);
            idx += 8;
        }
        else {
            // Gather at most 8 elements per cycle so any stride is covered without overrunning the 8-lane vector.
            nk_size_t chunk = remaining < 8 ? remaining : 8;
            nk_b128_vec_t tail_vec;
            nk_strided_load_b16x8_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, chunk);
            int16x8_t comparable_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(tail_vec.u16x8);
            // Lanes beyond `chunk` get each comparison's identity, so under a strict compare they never win.
            uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((uint16_t)chunk));
            data_for_min_i16x8 = vbslq_s16(valid_u16x8, comparable_i16x8, vdupq_n_s16(NK_I16_MAX));
            data_for_max_i16x8 = vbslq_s16(valid_u16x8, comparable_i16x8, vdupq_n_s16(NK_I16_MIN));
            idx += chunk;
        }

        // Strict comparisons keep the earliest occurrence per lane.
        uint16x8_t less_u16x8 = vcltq_s16(data_for_min_i16x8, min_i16x8);
        uint16x8_t greater_u16x8 = vcgtq_s16(data_for_max_i16x8, max_i16x8);
        min_i16x8 = vbslq_s16(less_u16x8, data_for_min_i16x8, min_i16x8);
        max_i16x8 = vbslq_s16(greater_u16x8, data_for_max_i16x8, max_i16x8);
        min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
        max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
        iter_u16x8 = vaddq_u16(iter_u16x8, one_u16x8);
    }

    nk_i16_t min_comparable = vminvq_s16(min_i16x8), max_comparable = vmaxvq_s16(max_i16x8);
    // Both accumulators still at their sentinels means no element was processed (any single element moves at
    // least one of them). Report the sentinels with "no index".
    if (min_comparable == NK_I16_MAX && max_comparable == NK_I16_MIN) {
        *(nk_u16_t *)min_value_ptr = nk_comparable_i16_to_bf16_raw_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
        *(nk_u16_t *)max_value_ptr = nk_comparable_i16_to_bf16_raw_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
        return;
    }

    // Resolve each extremum to the earliest cycle holding it, then the lowest lane within that cycle.
    uint16x8_t min_value_match_u16x8 = vceqq_s16(min_i16x8, vdupq_n_s16(min_comparable));
    uint16x8_t masked_min_iter_u16x8 = vbslq_u16(min_value_match_u16x8, min_iter_u16x8, vdupq_n_u16(0xFFFF));
    nk_u16_t earliest_min_cycle = vminvq_u16(masked_min_iter_u16x8);
    uint16x8_t max_value_match_u16x8 = vceqq_s16(max_i16x8, vdupq_n_s16(max_comparable));
    uint16x8_t masked_max_iter_u16x8 = vbslq_u16(max_value_match_u16x8, max_iter_u16x8, vdupq_n_u16(0xFFFF));
    nk_u16_t earliest_max_cycle = vminvq_u16(masked_max_iter_u16x8);

    uint16x8_t min_cycle_match_u16x8 = vceqq_u16(min_iter_u16x8, vdupq_n_u16(earliest_min_cycle));
    uint16x8_t min_both_match_u16x8 = vandq_u16(min_value_match_u16x8, min_cycle_match_u16x8);
    uint16x8_t min_masked_lanes_u16x8 = vbslq_u16(min_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
    nk_u16_t min_lane_offset = vminvq_u16(min_masked_lanes_u16x8);
    nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 8 + (nk_size_t)min_lane_offset;

    uint16x8_t max_cycle_match_u16x8 = vceqq_u16(max_iter_u16x8, vdupq_n_u16(earliest_max_cycle));
    uint16x8_t max_both_match_u16x8 = vandq_u16(max_value_match_u16x8, max_cycle_match_u16x8);
    uint16x8_t max_masked_lanes_u16x8 = vbslq_u16(max_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
    nk_u16_t max_lane_offset = vminvq_u16(max_masked_lanes_u16x8);
    nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 8 + (nk_size_t)max_lane_offset;

    *(nk_u16_t *)min_value_ptr = nk_comparable_i16_to_bf16_raw_(min_comparable), *min_index_ptr = min_idx;
    *(nk_u16_t *)max_value_ptr = nk_comparable_i16_to_bf16_raw_(max_comparable), *max_index_ptr = max_idx;
}
|
|
301
|
+
|
|
302
|
+
NK_PUBLIC void nk_reduce_minmax_bf16_neonbfdot( //
|
|
303
|
+
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
304
|
+
nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
|
|
305
|
+
nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
|
|
306
|
+
nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
|
|
307
|
+
int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
|
|
308
|
+
if (count == 0) {
|
|
309
|
+
*(nk_u16_t *)min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX;
|
|
310
|
+
*(nk_u16_t *)max_value_ptr = NK_BF16_MIN, *max_index_ptr = NK_SIZE_MAX;
|
|
311
|
+
}
|
|
312
|
+
else if (!aligned)
|
|
313
|
+
nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
314
|
+
max_index_ptr);
|
|
315
|
+
else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
|
|
316
|
+
nk_size_t left_count = count / 2;
|
|
317
|
+
nk_bf16_t left_min_value, right_min_value, left_max_value, right_max_value;
|
|
318
|
+
nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
|
|
319
|
+
nk_reduce_minmax_bf16_neonbfdot(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
|
|
320
|
+
&left_max_value, &left_max_index);
|
|
321
|
+
nk_reduce_minmax_bf16_neonbfdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
|
|
322
|
+
&right_min_value, &right_min_index, &right_max_value, &right_max_index);
|
|
323
|
+
if (nk_bf16_order_serial(right_min_value, left_min_value) < 0)
|
|
324
|
+
*min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
|
|
325
|
+
else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
|
|
326
|
+
if (nk_bf16_order_serial(right_max_value, left_max_value) > 0)
|
|
327
|
+
*max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
|
|
328
|
+
else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
|
|
329
|
+
}
|
|
330
|
+
else if (stride_elements == 1)
|
|
331
|
+
nk_reduce_minmax_bf16_neonbfdot_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
332
|
+
max_index_ptr);
|
|
333
|
+
else if (stride_elements <= 4)
|
|
334
|
+
nk_reduce_minmax_bf16_neonbfdot_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
|
|
335
|
+
max_value_ptr, max_index_ptr);
|
|
336
|
+
else
|
|
337
|
+
nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
|
|
338
|
+
max_index_ptr);
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
#if defined(__clang__)
|
|
342
|
+
#pragma clang attribute pop
|
|
343
|
+
#elif defined(__GNUC__)
|
|
344
|
+
#pragma GCC pop_options
|
|
345
|
+
#endif
|
|
346
|
+
|
|
347
|
+
#if defined(__cplusplus)
|
|
348
|
+
} // extern "C"
|
|
349
|
+
#endif
|
|
350
|
+
|
|
351
|
+
#endif // NK_TARGET_NEONBFDOT
|
|
352
|
+
#endif // NK_TARGET_ARM_
|
|
353
|
+
#endif // NK_REDUCE_NEONBFDOT_H
|