numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,3841 @@
+ /**
+  * @brief Base NEON (ARMv8-A) implementations for the redesigned reduction API (moments + minmax).
+  * @file include/numkong/reduce/neon.h
+  * @author Ash Vardanian
+  * @date February 13, 2026
+  *
+  * @sa include/numkong/reduce.h
+  */
+ #ifndef NK_REDUCE_NEON_H
+ #define NK_REDUCE_NEON_H
+
+ #if NK_TARGET_ARM_
+ #if NK_TARGET_NEON
+
+ #include "numkong/types.h"       // `nk_size_t`
+ #include "numkong/cast/neon.h"   // `nk_e4m3x16_to_f16x8x2_neon_`
+ #include "numkong/cast/serial.h" // `nk_e4m3_to_f16_serial`
+
+ #if defined(__cplusplus)
+ extern "C" {
+ #endif
+
+ #if defined(__clang__)
+ #pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
+ #elif defined(__GNUC__)
+ #pragma GCC push_options
+ #pragma GCC target("arch=armv8-a+simd")
+ #endif
+
+ NK_INTERNAL nk_u64_t nk_reduce_sadd_u64x2_neon_(uint64x2_t v) {
+     uint64x2_t swapped_u64x2 = vextq_u64(v, v, 1);
+     return vgetq_lane_u64(vqaddq_u64(v, swapped_u64x2), 0);
+ }
+
+ /** @brief Saturating square of each i64 lane → u64. If |a| >= 2^32, a² overflows u64 → saturate. */
+ NK_INTERNAL uint64x2_t nk_i64_smul_sq_i64x2_neon_(int64x2_t val) {
+     uint64x2_t absolute_u64x2 = vreinterpretq_u64_s64(vabsq_s64(val));
+     uint32x2_t low_halves_u32x2 = vmovn_u64(absolute_u64x2);
+     uint64x2_t high_bits_u64x2 = vshrq_n_u64(absolute_u64x2, 32);
+     uint64x2_t low_squared_u64x2 = vmull_u32(low_halves_u32x2, low_halves_u32x2);
+     uint64x2_t is_small_u64x2 = vceqq_u64(high_bits_u64x2, vdupq_n_u64(0));
+     return vbslq_u64(is_small_u64x2, low_squared_u64x2, vdupq_n_u64(NK_U64_MAX));
+ }
+
+ /** @brief Saturating square of each u64 lane → u64. If a >= 2^32, a² overflows u64 → saturate. */
+ NK_INTERNAL uint64x2_t nk_u64_smul_sq_u64x2_neon_(uint64x2_t val) {
+     uint32x2_t low_halves_u32x2 = vmovn_u64(val);
+     uint64x2_t high_bits_u64x2 = vshrq_n_u64(val, 32);
+     uint64x2_t low_squared_u64x2 = vmull_u32(low_halves_u32x2, low_halves_u32x2);
+     uint64x2_t is_small_u64x2 = vceqq_u64(high_bits_u64x2, vdupq_n_u64(0));
+     return vbslq_u64(is_small_u64x2, low_squared_u64x2, vdupq_n_u64(NK_U64_MAX));
+ }
+
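The lane-wise saturating square reads more plainly in scalar form. A minimal C sketch of the same trick (ours, not part of the diff):

    #include <stdint.h>

    // If the high 32 bits of |a| are zero, then |a| < 2^32 and a * a fits a
    // u64 exactly via a 32x32 -> 64-bit multiply; otherwise a * a >= 2^64,
    // so the result saturates, mirroring the vbslq_u64 select above.
    static uint64_t saturating_square_u64(uint64_t a) {
        uint32_t low = (uint32_t)a; /* exact when a < 2^32 */
        return (a >> 32) == 0 ? (uint64_t)low * low : UINT64_MAX;
    }
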
+ NK_INTERNAL void nk_reduce_moments_f32_neon_contiguous_( //
+     nk_f32_t const *data_ptr, nk_size_t count, //
+     nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
+     float64x2_t sum_f64x2 = vdupq_n_f64(0), sumsq_f64x2 = vdupq_n_f64(0);
+     nk_size_t idx = 0;
+     for (; idx + 4 <= count; idx += 4) {
+         float32x4_t data_f32x4 = vld1q_f32(data_ptr + idx);
+         float64x2_t data_low_f64x2 = vcvt_f64_f32(vget_low_f32(data_f32x4));
+         float64x2_t data_high_f64x2 = vcvt_f64_f32(vget_high_f32(data_f32x4));
+         sum_f64x2 = vaddq_f64(sum_f64x2, data_low_f64x2);
+         sum_f64x2 = vaddq_f64(sum_f64x2, data_high_f64x2);
+         sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_low_f64x2, data_low_f64x2);
+         sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_high_f64x2, data_high_f64x2);
+     }
+     nk_f64_t sum = vaddvq_f64(sum_f64x2), sumsq = vaddvq_f64(sumsq_f64x2);
+     for (; idx < count; ++idx) {
+         nk_f64_t value = (nk_f64_t)data_ptr[idx];
+         sum += value, sumsq += value * value;
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_INTERNAL void nk_reduce_moments_f32_neon_strided_( //
+     nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
+     float64x2_t sum_f64x2 = vdupq_n_f64(0), sumsq_f64x2 = vdupq_n_f64(0);
+     nk_size_t idx = 0;
+     if (stride_elements == 2) {
+         for (; idx + 4 <= count; idx += 4) {
+             float32x4x2_t loaded_f32x4x2 = vld2q_f32(data_ptr + idx * 2);
+             float64x2_t data_low_f64x2 = vcvt_f64_f32(vget_low_f32(loaded_f32x4x2.val[0]));
+             float64x2_t data_high_f64x2 = vcvt_f64_f32(vget_high_f32(loaded_f32x4x2.val[0]));
+             sum_f64x2 = vaddq_f64(sum_f64x2, data_low_f64x2);
+             sum_f64x2 = vaddq_f64(sum_f64x2, data_high_f64x2);
+             sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_low_f64x2, data_low_f64x2);
+             sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_high_f64x2, data_high_f64x2);
+         }
+     }
+     else if (stride_elements == 3) {
+         for (; idx + 4 <= count; idx += 4) {
+             float32x4x3_t loaded_f32x4x3 = vld3q_f32(data_ptr + idx * 3);
+             float64x2_t data_low_f64x2 = vcvt_f64_f32(vget_low_f32(loaded_f32x4x3.val[0]));
+             float64x2_t data_high_f64x2 = vcvt_f64_f32(vget_high_f32(loaded_f32x4x3.val[0]));
+             sum_f64x2 = vaddq_f64(sum_f64x2, data_low_f64x2);
+             sum_f64x2 = vaddq_f64(sum_f64x2, data_high_f64x2);
+             sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_low_f64x2, data_low_f64x2);
+             sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_high_f64x2, data_high_f64x2);
+         }
+     }
+     else {
+         for (; idx + 4 <= count; idx += 4) {
+             float32x4x4_t loaded_f32x4x4 = vld4q_f32(data_ptr + idx * 4);
+             float64x2_t data_low_f64x2 = vcvt_f64_f32(vget_low_f32(loaded_f32x4x4.val[0]));
+             float64x2_t data_high_f64x2 = vcvt_f64_f32(vget_high_f32(loaded_f32x4x4.val[0]));
+             sum_f64x2 = vaddq_f64(sum_f64x2, data_low_f64x2);
+             sum_f64x2 = vaddq_f64(sum_f64x2, data_high_f64x2);
+             sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_low_f64x2, data_low_f64x2);
+             sumsq_f64x2 = vfmaq_f64(sumsq_f64x2, data_high_f64x2, data_high_f64x2);
+         }
+     }
+     nk_f64_t sum = vaddvq_f64(sum_f64x2), sumsq = vaddvq_f64(sumsq_f64x2);
+     nk_f32_t const *current_ptr = data_ptr + idx * stride_elements;
+     for (; idx < count; ++idx, current_ptr += stride_elements) {
+         nk_f64_t value = (nk_f64_t)(*current_ptr);
+         sum += value, sumsq += value * value;
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_f32_neon( //
+     nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
+     int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
+     if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+     else if (!aligned) nk_reduce_moments_f32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+     else if (count > (nk_size_t)(NK_U16_MAX + 1) * 4) {
+         nk_size_t left_count = count / 2;
+         nk_f64_t left_sum, left_sumsq, right_sum, right_sumsq;
+         nk_reduce_moments_f32_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+         nk_reduce_moments_f32_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+                                    &right_sum, &right_sumsq);
+         *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
+     }
+     else if (stride_elements == 1) nk_reduce_moments_f32_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_moments_f32_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+     else nk_reduce_moments_f32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
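A hypothetical caller (ours, not part of the diff) can turn the two raw moments into mean and population variance; only the signature above is assumed:

    nk_f64_t sum, sumsq;
    nk_reduce_moments_f32_neon(data, n, sizeof(nk_f32_t), &sum, &sumsq); // contiguous: stride = element size
    nk_f64_t mean = sum / (nk_f64_t)n;                                   // E[x]
    nk_f64_t variance = sumsq / (nk_f64_t)n - mean * mean;               // E[x^2] - E[x]^2
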
+ NK_INTERNAL void nk_reduce_minmax_f32_neon_contiguous_( //
+     nk_f32_t const *data_ptr, nk_size_t count, //
+     nk_f32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_f32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     float32x4_t min_f32x4 = vdupq_n_f32(NK_F32_MAX), max_f32x4 = vdupq_n_f32(NK_F32_MIN);
+     uint32x4_t min_iter_u32x4 = vdupq_n_u32(0), max_iter_u32x4 = vdupq_n_u32(0);
+     uint32x4_t iter_u32x4 = vdupq_n_u32(0), one_u32x4 = vdupq_n_u32(1);
+     nk_size_t idx = 0;
+     for (; idx + 4 <= count; idx += 4) {
+         float32x4_t data_f32x4 = vld1q_f32(data_ptr + idx);
+         uint32x4_t less_u32x4 = vcltq_f32(data_f32x4, min_f32x4);
+         uint32x4_t greater_u32x4 = vcgtq_f32(data_f32x4, max_f32x4);
+         min_f32x4 = vbslq_f32(less_u32x4, data_f32x4, min_f32x4);
+         max_f32x4 = vbslq_f32(greater_u32x4, data_f32x4, max_f32x4);
+         min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+         max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+         iter_u32x4 = vaddq_u32(iter_u32x4, one_u32x4);
+     }
+     nk_size_t remaining = count - idx;
+     if (remaining > 0) {
+         nk_b128_vec_t tail_vec;
+         nk_partial_load_b32x4_serial_(data_ptr + idx, &tail_vec, remaining);
+         uint32x4_t lane_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                              vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+         uint32x4_t valid_u32x4 = vcltq_u32(lane_u32x4, vdupq_n_u32((uint32_t)remaining));
+         float32x4_t data_min_f32x4 = vbslq_f32(valid_u32x4, tail_vec.f32x4, min_f32x4);
+         float32x4_t data_max_f32x4 = vbslq_f32(valid_u32x4, tail_vec.f32x4, max_f32x4);
+         uint32x4_t less_u32x4 = vcltq_f32(data_min_f32x4, min_f32x4);
+         uint32x4_t greater_u32x4 = vcgtq_f32(data_max_f32x4, max_f32x4);
+         min_f32x4 = vbslq_f32(less_u32x4, data_min_f32x4, min_f32x4);
+         max_f32x4 = vbslq_f32(greater_u32x4, data_max_f32x4, max_f32x4);
+         min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+         max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+     }
+     nk_f32_t min_value = vminvq_f32(min_f32x4), max_value = vmaxvq_f32(max_f32x4);
+
+     // All-NaN / sentinel check: sentinels remain unchanged when all data is NaN.
+     if (min_value == NK_F32_MAX && max_value == NK_F32_MIN) {
+         *min_value_ptr = NK_F32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F32_MIN,
+         *max_index_ptr = NK_SIZE_MAX;
+         return;
+     }
+
+     uint32x4_t min_value_match_u32x4 = vceqq_f32(min_f32x4, vdupq_n_f32(min_value));
+     uint32x4_t masked_min_iter_u32x4 = vbslq_u32(min_value_match_u32x4, min_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t earliest_min_cycle = vminvq_u32(masked_min_iter_u32x4);
+     uint32x4_t max_value_match_u32x4 = vceqq_f32(max_f32x4, vdupq_n_f32(max_value));
+     uint32x4_t masked_max_iter_u32x4 = vbslq_u32(max_value_match_u32x4, max_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t earliest_max_cycle = vminvq_u32(masked_max_iter_u32x4);
+     uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                                  vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+     uint32x4_t min_cycle_match_u32x4 = vceqq_u32(min_iter_u32x4, vdupq_n_u32(earliest_min_cycle));
+     uint32x4_t min_both_match_u32x4 = vandq_u32(min_value_match_u32x4, min_cycle_match_u32x4);
+     uint32x4_t min_masked_lanes_u32x4 = vbslq_u32(min_both_match_u32x4, lane_indices_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t min_lane_offset = vminvq_u32(min_masked_lanes_u32x4);
+     nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 4 + (nk_size_t)min_lane_offset;
+     uint32x4_t max_cycle_match_u32x4 = vceqq_u32(max_iter_u32x4, vdupq_n_u32(earliest_max_cycle));
+     uint32x4_t max_both_match_u32x4 = vandq_u32(max_value_match_u32x4, max_cycle_match_u32x4);
+     uint32x4_t max_masked_lanes_u32x4 = vbslq_u32(max_both_match_u32x4, lane_indices_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t max_lane_offset = vminvq_u32(max_masked_lanes_u32x4);
+     nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 4 + (nk_size_t)max_lane_offset;
+     *min_value_ptr = min_value, *min_index_ptr = min_idx;
+     *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ }
+
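The index bookkeeping above rewards a worked example (illustrative, ours). Each of the four f32 lanes remembers the loop cycle in which it last improved its running extremum. If the global minimum ends up in lane 2 with recorded cycle 7, the element index is 7 * 4 + 2 = 30. Masking non-matching lanes to NK_U32_MAX before the vminvq_u32 reductions selects the earliest cycle first and the lowest lane second, so ties resolve to the first occurrence in memory.
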
+ NK_INTERNAL void nk_reduce_minmax_f32_neon_strided_( //
+     nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_f32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_f32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     float32x4_t min_f32x4 = vdupq_n_f32(NK_F32_MAX), max_f32x4 = vdupq_n_f32(NK_F32_MIN);
+     uint32x4_t min_iter_u32x4 = vdupq_n_u32(0), max_iter_u32x4 = vdupq_n_u32(0);
+     uint32x4_t iter_u32x4 = vdupq_n_u32(0), one_u32x4 = vdupq_n_u32(1);
+     uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                                  vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+     nk_size_t idx = 0;
+     float32x4_t data_for_min_f32x4, data_for_max_f32x4;
+
+ nk_reduce_minmax_f32_neon_cycle:
+     if (stride_elements == 2 && idx + 4 <= count) {
+         float32x4x2_t loaded = vld2q_f32(data_ptr + idx * 2);
+         data_for_min_f32x4 = loaded.val[0];
+         data_for_max_f32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (stride_elements == 3 && idx + 4 <= count) {
+         float32x4x3_t loaded = vld3q_f32(data_ptr + idx * 3);
+         data_for_min_f32x4 = loaded.val[0];
+         data_for_max_f32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (stride_elements == 4 && idx + 4 <= count) {
+         float32x4x4_t loaded = vld4q_f32(data_ptr + idx * 4);
+         data_for_min_f32x4 = loaded.val[0];
+         data_for_max_f32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (idx < count) {
+         nk_b128_vec_t tail_vec;
+         nk_strided_load_b32x4_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+         uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((uint32_t)(count - idx)));
+         data_for_min_f32x4 = vbslq_f32(valid_u32x4, tail_vec.f32x4, min_f32x4);
+         data_for_max_f32x4 = vbslq_f32(valid_u32x4, tail_vec.f32x4, max_f32x4);
+         idx = count;
+     }
+     else {
+         nk_f32_t min_value = vminvq_f32(min_f32x4), max_value = vmaxvq_f32(max_f32x4);
+         if (min_value == NK_F32_MAX && max_value == NK_F32_MIN) {
+             *min_value_ptr = NK_F32_MAX, *min_index_ptr = NK_SIZE_MAX;
+             *max_value_ptr = NK_F32_MIN, *max_index_ptr = NK_SIZE_MAX;
+             return;
+         }
+         uint32x4_t min_value_match_u32x4 = vceqq_f32(min_f32x4, vdupq_n_f32(min_value));
+         uint32x4_t masked_min_iter_u32x4 = vbslq_u32(min_value_match_u32x4, min_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t earliest_min_cycle = vminvq_u32(masked_min_iter_u32x4);
+         uint32x4_t max_value_match_u32x4 = vceqq_f32(max_f32x4, vdupq_n_f32(max_value));
+         uint32x4_t masked_max_iter_u32x4 = vbslq_u32(max_value_match_u32x4, max_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t earliest_max_cycle = vminvq_u32(masked_max_iter_u32x4);
+         uint32x4_t min_cycle_match_u32x4 = vceqq_u32(min_iter_u32x4, vdupq_n_u32(earliest_min_cycle));
+         uint32x4_t min_both_match_u32x4 = vandq_u32(min_value_match_u32x4, min_cycle_match_u32x4);
+         uint32x4_t min_masked_lanes_u32x4 = vbslq_u32(min_both_match_u32x4, lane_indices_u32x4,
+                                                       vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t min_lane_offset = vminvq_u32(min_masked_lanes_u32x4);
+         nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 4 + (nk_size_t)min_lane_offset;
+         uint32x4_t max_cycle_match_u32x4 = vceqq_u32(max_iter_u32x4, vdupq_n_u32(earliest_max_cycle));
+         uint32x4_t max_both_match_u32x4 = vandq_u32(max_value_match_u32x4, max_cycle_match_u32x4);
+         uint32x4_t max_masked_lanes_u32x4 = vbslq_u32(max_both_match_u32x4, lane_indices_u32x4,
+                                                       vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t max_lane_offset = vminvq_u32(max_masked_lanes_u32x4);
+         nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 4 + (nk_size_t)max_lane_offset;
+         *min_value_ptr = min_value, *min_index_ptr = min_idx;
+         *max_value_ptr = max_value, *max_index_ptr = max_idx;
+         return;
+     }
+
+     // Shared update body
+     uint32x4_t less_u32x4 = vcltq_f32(data_for_min_f32x4, min_f32x4);
+     uint32x4_t greater_u32x4 = vcgtq_f32(data_for_max_f32x4, max_f32x4);
+     min_f32x4 = vbslq_f32(less_u32x4, data_for_min_f32x4, min_f32x4);
+     max_f32x4 = vbslq_f32(greater_u32x4, data_for_max_f32x4, max_f32x4);
+     min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+     max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+     iter_u32x4 = vaddq_u32(iter_u32x4, one_u32x4);
+     goto nk_reduce_minmax_f32_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_f32_neon( //
+     nk_f32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_f32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_f32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_f32_t);
+     int aligned = (stride_bytes % sizeof(nk_f32_t) == 0);
+     if (count == 0)
+         *min_value_ptr = NK_F32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F32_MIN,
+         *max_index_ptr = NK_SIZE_MAX;
+     else if (!aligned)
+         nk_reduce_minmax_f32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+     else if (count > (nk_size_t)NK_U32_MAX * 4) {
+         nk_size_t left_count = count / 2;
+         nk_f32_t left_min_value, right_min_value, left_max_value, right_max_value;
+         nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+         nk_reduce_minmax_f32_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index, &left_max_value,
+                                   &left_max_index);
+         nk_reduce_minmax_f32_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+                                   &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+         if (right_min_value < left_min_value)
+             *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+         else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+         if (right_max_value > left_max_value)
+             *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+         else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+     }
+     else if (stride_elements == 1)
+         nk_reduce_minmax_f32_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+                                               max_index_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_minmax_f32_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+                                            max_value_ptr, max_index_ptr);
+     else
+         nk_reduce_minmax_f32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+ }
+
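Why the recursion threshold above sits at NK_U32_MAX * 4 (our reading of the code): indices are reconstructed as cycle * 4 + lane, and the per-lane cycle counter is a u32, so a single pass can address at most 2^32 cycles of 4 elements, about 1.7e10 floats. Halving larger inputs keeps every leaf below that bound, and the `left_count +` rebase stitches the two index spaces back together.
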
+ NK_INTERNAL void nk_reduce_moments_f64_neon_contiguous_( //
+     nk_f64_t const *data_ptr, nk_size_t count, //
+     nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
+     float64x2_t sum_f64x2 = vdupq_n_f64(0), sum_compensation_f64x2 = vdupq_n_f64(0);
+     float64x2_t sumsq_f64x2 = vdupq_n_f64(0), sumsq_compensation_f64x2 = vdupq_n_f64(0);
+     nk_size_t idx = 0;
+     for (; idx + 2 <= count; idx += 2) {
+         float64x2_t data_f64x2 = vld1q_f64(data_ptr + idx);
+         float64x2_t temp_sum_f64x2 = vaddq_f64(sum_f64x2, data_f64x2);
+         float64x2_t residual_f64x2 = vsubq_f64(temp_sum_f64x2, sum_f64x2);
+         sum_compensation_f64x2 = vaddq_f64(sum_compensation_f64x2,
+                                            vaddq_f64(vsubq_f64(sum_f64x2, vsubq_f64(temp_sum_f64x2, residual_f64x2)),
+                                                      vsubq_f64(data_f64x2, residual_f64x2)));
+         sum_f64x2 = temp_sum_f64x2;
+         float64x2_t data_squared_f64x2 = vmulq_f64(data_f64x2, data_f64x2);
+         float64x2_t temp_sumsq_f64x2 = vaddq_f64(sumsq_f64x2, data_squared_f64x2);
+         float64x2_t residual_sumsq_f64x2 = vsubq_f64(temp_sumsq_f64x2, sumsq_f64x2);
+         sumsq_compensation_f64x2 = vaddq_f64(
+             sumsq_compensation_f64x2,
+             vaddq_f64(vsubq_f64(sumsq_f64x2, vsubq_f64(temp_sumsq_f64x2, residual_sumsq_f64x2)),
+                       vsubq_f64(data_squared_f64x2, residual_sumsq_f64x2)));
+         sumsq_f64x2 = temp_sumsq_f64x2;
+     }
+     nk_size_t remaining = count - idx;
+     if (remaining > 0) {
+         nk_b128_vec_t tail_vec;
+         nk_partial_load_b64x2_serial_(data_ptr + idx, &tail_vec, remaining);
+         float64x2_t data_f64x2 = tail_vec.f64x2;
+         float64x2_t temp_sum_f64x2 = vaddq_f64(sum_f64x2, data_f64x2);
+         float64x2_t residual_f64x2 = vsubq_f64(temp_sum_f64x2, sum_f64x2);
+         sum_compensation_f64x2 = vaddq_f64(sum_compensation_f64x2,
+                                            vaddq_f64(vsubq_f64(sum_f64x2, vsubq_f64(temp_sum_f64x2, residual_f64x2)),
+                                                      vsubq_f64(data_f64x2, residual_f64x2)));
+         sum_f64x2 = temp_sum_f64x2;
+         float64x2_t data_squared_f64x2 = vmulq_f64(data_f64x2, data_f64x2);
+         float64x2_t temp_sumsq_f64x2 = vaddq_f64(sumsq_f64x2, data_squared_f64x2);
+         float64x2_t residual_sumsq_f64x2 = vsubq_f64(temp_sumsq_f64x2, sumsq_f64x2);
+         sumsq_compensation_f64x2 = vaddq_f64(
+             sumsq_compensation_f64x2,
+             vaddq_f64(vsubq_f64(sumsq_f64x2, vsubq_f64(temp_sumsq_f64x2, residual_sumsq_f64x2)),
+                       vsubq_f64(data_squared_f64x2, residual_sumsq_f64x2)));
+         sumsq_f64x2 = temp_sumsq_f64x2;
+     }
+     *sum_ptr = vaddvq_f64(vaddq_f64(sum_f64x2, sum_compensation_f64x2));
+     *sumsq_ptr = vaddvq_f64(vaddq_f64(sumsq_f64x2, sumsq_compensation_f64x2));
+ }
+
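The compensation pattern above is Knuth's two-sum, vectorized two lanes at a time. A scalar model (ours; assumes round-to-nearest f64) maps the intrinsics back to the textbook form:

    // One compensated accumulation step: afterwards, sum + compensation
    // carries the running total plus the exact rounding error of every add.
    static void two_sum_accumulate(double *sum, double *compensation, double x) {
        double t = *sum + x;
        double r = t - *sum;                         /* portion of x absorbed */
        *compensation += (*sum - (t - r)) + (x - r); /* exact rounding error  */
        *sum = t;
    }

The closing vaddvq_f64 over sum + compensation in the function is the same final correction.
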
+ NK_PUBLIC void nk_reduce_moments_f64_neon( //
+     nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_f64_t *sum_ptr, nk_f64_t *sumsq_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
+     int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
+     if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+     else if (!aligned) nk_reduce_moments_f64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+     else if (count > (nk_size_t)(NK_U16_MAX + 1) * 2) {
+         nk_size_t left_count = count / 2;
+         nk_f64_t left_sum, left_sumsq, right_sum, right_sumsq;
+         nk_reduce_moments_f64_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+         nk_reduce_moments_f64_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+                                    &right_sum, &right_sumsq);
+         *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
+     }
+     else if (stride_elements == 1) nk_reduce_moments_f64_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+     else nk_reduce_moments_f64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_f64_neon_contiguous_( //
+     nk_f64_t const *data_ptr, nk_size_t count, //
+     nk_f64_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_f64_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     float64x2_t min_f64x2 = vdupq_n_f64(NK_F64_MAX), max_f64x2 = vdupq_n_f64(NK_F64_MIN);
+     uint64x2_t min_iter = vdupq_n_u64(0), max_iter = vdupq_n_u64(0);
+     uint64x2_t iter = vdupq_n_u64(0), one = vdupq_n_u64(1);
+     nk_size_t idx = 0;
+     for (; idx + 2 <= count; idx += 2) {
+         float64x2_t data_f64x2 = vld1q_f64(data_ptr + idx);
+         uint64x2_t less_u64x2 = vcltq_f64(data_f64x2, min_f64x2);
+         uint64x2_t greater_u64x2 = vcgtq_f64(data_f64x2, max_f64x2);
+         min_f64x2 = vbslq_f64(less_u64x2, data_f64x2, min_f64x2);
+         max_f64x2 = vbslq_f64(greater_u64x2, data_f64x2, max_f64x2);
+         min_iter = vbslq_u64(less_u64x2, iter, min_iter);
+         max_iter = vbslq_u64(greater_u64x2, iter, max_iter);
+         iter = vaddq_u64(iter, one);
+     }
+     nk_b128_vec_t min_values_vec, max_values_vec, min_indices_vec, max_indices_vec;
+     min_values_vec.f64x2 = min_f64x2;
+     min_indices_vec.u64x2 = min_iter;
+     max_values_vec.f64x2 = max_f64x2;
+     max_indices_vec.u64x2 = max_iter;
+     nk_f64_t min_value, max_value;
+     nk_size_t min_index, max_index;
+     if (min_values_vec.f64s[0] <= min_values_vec.f64s[1])
+         min_value = min_values_vec.f64s[0], min_index = (nk_size_t)min_indices_vec.u64s[0] * 2;
+     else min_value = min_values_vec.f64s[1], min_index = (nk_size_t)min_indices_vec.u64s[1] * 2 + 1;
+     if (max_values_vec.f64s[0] >= max_values_vec.f64s[1])
+         max_value = max_values_vec.f64s[0], max_index = (nk_size_t)max_indices_vec.u64s[0] * 2;
+     else max_value = max_values_vec.f64s[1], max_index = (nk_size_t)max_indices_vec.u64s[1] * 2 + 1;
+     for (; idx < count; ++idx) {
+         nk_f64_t val = data_ptr[idx];
+         if (val < min_value) min_value = val, min_index = idx;
+         if (val > max_value) max_value = val, max_index = idx;
+     }
+     // All-NaN / sentinel check: sentinels remain unchanged when all data is NaN.
+     if (min_value == NK_F64_MAX && max_value == NK_F64_MIN) {
+         *min_value_ptr = NK_F64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F64_MIN,
+         *max_index_ptr = NK_SIZE_MAX;
+         return;
+     }
+     *min_value_ptr = min_value, *min_index_ptr = min_index;
+     *max_value_ptr = max_value, *max_index_ptr = max_index;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_f64_neon( //
+     nk_f64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_f64_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_f64_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_f64_t);
+     int aligned = (stride_bytes % sizeof(nk_f64_t) == 0);
+     if (count == 0)
+         *min_value_ptr = NK_F64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F64_MIN,
+         *max_index_ptr = NK_SIZE_MAX;
+     else if (!aligned)
+         nk_reduce_minmax_f64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+     else if (stride_elements == 1)
+         nk_reduce_minmax_f64_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+                                               max_index_ptr);
+     else
+         nk_reduce_minmax_f64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_moments_i8_neon_contiguous_( //
+     nk_i8_t const *data_ptr, nk_size_t count, //
+     nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     int32x4_t sum_i32x4 = vdupq_n_s32(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     for (; idx + 16 <= count; idx += 16) {
+         int8x16_t data_i8x16 = vld1q_s8(data_ptr + idx);
+         int16x8_t pairwise_i16x8 = vpaddlq_s8(data_i8x16);
+         sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
+         int16x8_t squares_lo_i16x8 = vmull_s8(vget_low_s8(data_i8x16), vget_low_s8(data_i8x16));
+         int16x8_t squares_hi_i16x8 = vmull_high_s8(data_i8x16, data_i8x16);
+         sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_lo_i16x8))));
+         sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_hi_i16x8))));
+     }
+     nk_i64_t sum = vaddlvq_s32(sum_i32x4);
+     nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+     for (; idx < count; ++idx) {
+         nk_i64_t value_i64 = (nk_i64_t)data_ptr[idx];
+         sum += value_i64, sumsq += (nk_u64_t)(value_i64 * value_i64);
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
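The accumulator widths admit a quick overflow budget (our arithmetic, matching the dispatcher below): vpaddlq_s8 followed by vpaddlq_s16 funnels 4 input bytes into each i32 lane per iteration, at most 4 * 127 = 508 in magnitude. The dispatcher splits inputs longer than (NK_U16_MAX + 1) * 16 = 1,048,576 elements, i.e. 65,536 iterations, so a lane peaks near 65,536 * 508 ≈ 3.3e7, well inside i32 range; the squares widen straight into u64 lanes.
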
+ NK_INTERNAL void nk_reduce_moments_i8_neon_strided_( //
+     nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     int32x4_t sum_i32x4 = vdupq_n_s32(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     if (stride_elements == 2) {
+         for (; idx + 16 <= count; idx += 16) {
+             int8x16x2_t loaded_i8x16x2 = vld2q_s8(data_ptr + idx * 2);
+             int8x16_t data_i8x16 = loaded_i8x16x2.val[0];
+             int16x8_t pairwise_i16x8 = vpaddlq_s8(data_i8x16);
+             sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
+             int16x8_t squares_lo_i16x8 = vmull_s8(vget_low_s8(data_i8x16), vget_low_s8(data_i8x16));
+             int16x8_t squares_hi_i16x8 = vmull_high_s8(data_i8x16, data_i8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_lo_i16x8))));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_hi_i16x8))));
+         }
+     }
+     else if (stride_elements == 3) {
+         for (; idx + 16 <= count; idx += 16) {
+             int8x16x3_t loaded_i8x16x3 = vld3q_s8(data_ptr + idx * 3);
+             int8x16_t data_i8x16 = loaded_i8x16x3.val[0];
+             int16x8_t pairwise_i16x8 = vpaddlq_s8(data_i8x16);
+             sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
+             int16x8_t squares_lo_i16x8 = vmull_s8(vget_low_s8(data_i8x16), vget_low_s8(data_i8x16));
+             int16x8_t squares_hi_i16x8 = vmull_high_s8(data_i8x16, data_i8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_lo_i16x8))));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_hi_i16x8))));
+         }
+     }
+     else {
+         for (; idx + 16 <= count; idx += 16) {
+             int8x16x4_t loaded_i8x16x4 = vld4q_s8(data_ptr + idx * 4);
+             int8x16_t data_i8x16 = loaded_i8x16x4.val[0];
+             int16x8_t pairwise_i16x8 = vpaddlq_s8(data_i8x16);
+             sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
+             int16x8_t squares_lo_i16x8 = vmull_s8(vget_low_s8(data_i8x16), vget_low_s8(data_i8x16));
+             int16x8_t squares_hi_i16x8 = vmull_high_s8(data_i8x16, data_i8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_lo_i16x8))));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_hi_i16x8))));
+         }
+     }
+     nk_i64_t sum = vaddlvq_s32(sum_i32x4);
+     nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+     for (; idx < count; ++idx) {
+         nk_i64_t value_i64 = (nk_i64_t)data_ptr[idx * stride_elements];
+         sum += value_i64, sumsq += (nk_u64_t)(value_i64 * value_i64);
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_i8_neon( //
+     nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
+     int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
+     if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+     else if (!aligned) nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+     else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
+         nk_size_t left_count = count / 2;
+         nk_i64_t left_sum, right_sum;
+         nk_u64_t left_sumsq, right_sumsq;
+         nk_reduce_moments_i8_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+         nk_reduce_moments_i8_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
+                                   &right_sumsq);
+         *sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
+         *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
+     }
+     else if (stride_elements == 1) nk_reduce_moments_i8_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_moments_i8_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+     else nk_reduce_moments_i8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_i8_neon_contiguous_( //
+     nk_i8_t const *data_ptr, nk_size_t count, //
+     nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     int8x16_t min_i8x16 = vdupq_n_s8(NK_I8_MAX), max_i8x16 = vdupq_n_s8(NK_I8_MIN);
+     uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+     uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+     nk_size_t idx = 0;
+     for (; idx + 16 <= count; idx += 16) {
+         int8x16_t data_i8x16 = vld1q_s8(data_ptr + idx);
+         uint8x16_t less_u8x16 = vcltq_s8(data_i8x16, min_i8x16);
+         uint8x16_t greater_u8x16 = vcgtq_s8(data_i8x16, max_i8x16);
+         min_i8x16 = vbslq_s8(less_u8x16, data_i8x16, min_i8x16);
+         max_i8x16 = vbslq_s8(greater_u8x16, data_i8x16, max_i8x16);
+         min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+         max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+         iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+     }
+     nk_size_t remaining = count - idx;
+     if (remaining > 0) {
+         nk_b128_vec_t tail_vec;
+         nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
+         uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+                                                     vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+         uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
+         int8x16_t data_for_min_i8x16 = vbslq_s8(valid_u8x16, tail_vec.i8x16, vdupq_n_s8(NK_I8_MAX));
+         int8x16_t data_for_max_i8x16 = vbslq_s8(valid_u8x16, tail_vec.i8x16, vdupq_n_s8(NK_I8_MIN));
+         uint8x16_t less_u8x16 = vcltq_s8(data_for_min_i8x16, min_i8x16);
+         uint8x16_t greater_u8x16 = vcgtq_s8(data_for_max_i8x16, max_i8x16);
+         min_i8x16 = vbslq_s8(less_u8x16, data_for_min_i8x16, min_i8x16);
+         max_i8x16 = vbslq_s8(greater_u8x16, data_for_max_i8x16, max_i8x16);
+         min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+         max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+     }
+     nk_i8_t min_value = vminvq_s8(min_i8x16), max_value = vmaxvq_s8(max_i8x16);
+     uint8x16_t min_value_match_u8x16 = vceqq_s8(min_i8x16, vdupq_n_s8(min_value));
+     uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+     nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+     uint8x16_t max_value_match_u8x16 = vceqq_s8(max_i8x16, vdupq_n_s8(max_value));
+     uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+     nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+     uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+                                                 vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+     uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+     uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+     uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+     nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+     nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+     uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+     uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+     uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+     nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+     nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+     *min_value_ptr = min_value, *min_index_ptr = min_idx;
+     *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_i8_neon_strided_( //
+     nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     int8x16_t min_i8x16 = vdupq_n_s8(NK_I8_MAX), max_i8x16 = vdupq_n_s8(NK_I8_MIN);
+     uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+     uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+     uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+                                                 vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+     nk_size_t idx = 0;
+     int8x16_t data_for_min_i8x16, data_for_max_i8x16;
+
+ nk_reduce_minmax_i8_neon_cycle:
+     if (stride_elements == 2 && idx + 16 <= count) {
+         int8x16x2_t loaded = vld2q_s8(data_ptr + idx * 2);
+         data_for_min_i8x16 = loaded.val[0];
+         data_for_max_i8x16 = loaded.val[0];
+         idx += 16;
+     }
+     else if (stride_elements == 3 && idx + 16 <= count) {
+         int8x16x3_t loaded = vld3q_s8(data_ptr + idx * 3);
+         data_for_min_i8x16 = loaded.val[0];
+         data_for_max_i8x16 = loaded.val[0];
+         idx += 16;
+     }
+     else if (stride_elements == 4 && idx + 16 <= count) {
+         int8x16x4_t loaded = vld4q_s8(data_ptr + idx * 4);
+         data_for_min_i8x16 = loaded.val[0];
+         data_for_max_i8x16 = loaded.val[0];
+         idx += 16;
+     }
+     else if (idx < count) {
+         nk_b128_vec_t tail_vec;
+         nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+         uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
+         data_for_min_i8x16 = vbslq_s8(valid_u8x16, tail_vec.i8x16, min_i8x16);
+         data_for_max_i8x16 = vbslq_s8(valid_u8x16, tail_vec.i8x16, max_i8x16);
+         idx = count;
+     }
+     else {
+         nk_i8_t min_value = vminvq_s8(min_i8x16), max_value = vmaxvq_s8(max_i8x16);
+         uint8x16_t min_value_match_u8x16 = vceqq_s8(min_i8x16, vdupq_n_s8(min_value));
+         uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+         nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+         uint8x16_t max_value_match_u8x16 = vceqq_s8(max_i8x16, vdupq_n_s8(max_value));
+         uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+         nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+         uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+         uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+         uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+         nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+         nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+         uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+         uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+         uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+         nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+         nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+         *min_value_ptr = min_value, *min_index_ptr = min_idx;
+         *max_value_ptr = max_value, *max_index_ptr = max_idx;
+         return;
+     }
+
+     // Shared update body
+     uint8x16_t less_u8x16 = vcltq_s8(data_for_min_i8x16, min_i8x16);
+     uint8x16_t greater_u8x16 = vcgtq_s8(data_for_max_i8x16, max_i8x16);
+     min_i8x16 = vbslq_s8(less_u8x16, data_for_min_i8x16, min_i8x16);
+     max_i8x16 = vbslq_s8(greater_u8x16, data_for_max_i8x16, max_i8x16);
+     min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+     max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+     iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+     goto nk_reduce_minmax_i8_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_i8_neon( //
+     nk_i8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_i8_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_i8_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_i8_t);
+     int aligned = (stride_bytes % sizeof(nk_i8_t) == 0);
+     if (count == 0)
+         *min_value_ptr = NK_I8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I8_MIN,
+         *max_index_ptr = NK_SIZE_MAX;
+     else if (!aligned)
+         nk_reduce_minmax_i8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                    max_index_ptr);
+     else if (count > (nk_size_t)(NK_U8_MAX + 1) * 16) {
+         nk_size_t left_count = count / 2;
+         nk_i8_t left_min_value, right_min_value, left_max_value, right_max_value;
+         nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+         nk_reduce_minmax_i8_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index, &left_max_value,
+                                  &left_max_index);
+         nk_reduce_minmax_i8_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+                                  &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+         if (right_min_value < left_min_value)
+             *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+         else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+         if (right_max_value > left_max_value)
+             *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+         else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+     }
+     else if (stride_elements == 1)
+         nk_reduce_minmax_i8_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+                                              max_index_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_minmax_i8_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr, max_value_ptr,
+                                           max_index_ptr);
+     else
+         nk_reduce_minmax_i8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                    max_index_ptr);
+ }
+
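Same cycle-counter logic, narrower budget (our reading): the i8 min/max kernels track cycles in u8 lanes, so one pass distinguishes at most NK_U8_MAX + 1 = 256 cycles of 16 bytes, i.e. 4,096 elements, which is exactly the dispatcher's split threshold of (NK_U8_MAX + 1) * 16. Larger inputs recurse in halves and rebase the right half's indices with `left_count`, as in the f32 path.
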
+ NK_INTERNAL void nk_reduce_moments_u8_neon_contiguous_( //
+     nk_u8_t const *data_ptr, nk_size_t count, //
+     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     uint32x4_t sum_u32x4 = vdupq_n_u32(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     for (; idx + 16 <= count; idx += 16) {
+         uint8x16_t data_u8x16 = vld1q_u8(data_ptr + idx);
+         uint16x8_t sum16 = vpaddlq_u8(data_u8x16);
+         sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(sum16));
+         uint16x8_t sq_lo = vmull_u8(vget_low_u8(data_u8x16), vget_low_u8(data_u8x16));
+         uint16x8_t sq_hi = vmull_high_u8(data_u8x16, data_u8x16);
+         uint32x4_t sq32_lo = vpaddlq_u16(sq_lo);
+         uint32x4_t sq32_hi = vpaddlq_u16(sq_hi);
+         sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq32_lo));
+         sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq32_hi));
+     }
+     nk_u64_t sum = vaddlvq_u32(sum_u32x4);
+     nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+     for (; idx < count; ++idx) {
+         nk_u64_t value = (nk_u64_t)data_ptr[idx];
+         sum += value, sumsq += value * value;
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_INTERNAL void nk_reduce_moments_u8_neon_strided_( //
+     nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     uint32x4_t sum_u32x4 = vdupq_n_u32(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     if (stride_elements == 2) {
+         for (; idx + 16 <= count; idx += 16) {
+             uint8x16x2_t loaded_u8x16x2 = vld2q_u8(data_ptr + idx * 2);
+             uint8x16_t data_u8x16 = loaded_u8x16x2.val[0];
+             uint16x8_t pairwise_u16x8 = vpaddlq_u8(data_u8x16);
+             sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(pairwise_u16x8));
+             uint16x8_t squares_lo_u16x8 = vmull_u8(vget_low_u8(data_u8x16), vget_low_u8(data_u8x16));
+             uint16x8_t squares_hi_u16x8 = vmull_high_u8(data_u8x16, data_u8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_lo_u16x8)));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_hi_u16x8)));
+         }
+     }
+     else if (stride_elements == 3) {
+         for (; idx + 16 <= count; idx += 16) {
+             uint8x16x3_t loaded_u8x16x3 = vld3q_u8(data_ptr + idx * 3);
+             uint8x16_t data_u8x16 = loaded_u8x16x3.val[0];
+             uint16x8_t pairwise_u16x8 = vpaddlq_u8(data_u8x16);
+             sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(pairwise_u16x8));
+             uint16x8_t squares_lo_u16x8 = vmull_u8(vget_low_u8(data_u8x16), vget_low_u8(data_u8x16));
+             uint16x8_t squares_hi_u16x8 = vmull_high_u8(data_u8x16, data_u8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_lo_u16x8)));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_hi_u16x8)));
+         }
+     }
+     else {
+         for (; idx + 16 <= count; idx += 16) {
+             uint8x16x4_t loaded_u8x16x4 = vld4q_u8(data_ptr + idx * 4);
+             uint8x16_t data_u8x16 = loaded_u8x16x4.val[0];
+             uint16x8_t pairwise_u16x8 = vpaddlq_u8(data_u8x16);
+             sum_u32x4 = vaddq_u32(sum_u32x4, vpaddlq_u16(pairwise_u16x8));
+             uint16x8_t squares_lo_u16x8 = vmull_u8(vget_low_u8(data_u8x16), vget_low_u8(data_u8x16));
+             uint16x8_t squares_hi_u16x8 = vmull_high_u8(data_u8x16, data_u8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_lo_u16x8)));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(squares_hi_u16x8)));
+         }
+     }
+     nk_u64_t sum = vaddlvq_u32(sum_u32x4);
+     nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
795
+ for (; idx < count; ++idx) {
796
+ nk_u64_t value = (nk_u64_t)data_ptr[idx * stride_elements];
797
+ sum += value, sumsq += value * value;
798
+ }
799
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
800
+ }
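The stride-2/3/4 paths above lean on NEON's de-interleaving loads: vld2q_u8, vld3q_u8, and vld4q_u8 read 32, 48, or 64 consecutive bytes and split them across registers, so .val[0] already holds the strided elements without any gather. A scalar picture of what the stride-3 load leaves in .val[0] (editorial sketch; p is a hypothetical byte pointer with at least 48 readable bytes):

    uint8_t lane[16];
    for (int i = 0; i < 16; ++i)
        lane[i] = p[3 * i]; /* vld3q_u8(p).val[0]: every third byte, 16 of them */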
+
+ NK_PUBLIC void nk_reduce_moments_u8_neon( //
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
+ int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+ else if (!aligned) nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
+ nk_size_t left_count = count / 2;
+ nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
+ nk_reduce_moments_u8_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+ nk_reduce_moments_u8_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes, &right_sum,
+ &right_sumsq);
+ *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
+ }
+ else if (stride_elements == 1) nk_reduce_moments_u8_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_moments_u8_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+ else nk_reduce_moments_u8_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
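Callers typically turn the two accumulators into the first two moments. A minimal usage sketch, assuming a hypothetical contiguous buffer pixels of length n > 0:

    nk_u64_t sum, sumsq;
    nk_reduce_moments_u8_neon(pixels, n, sizeof(nk_u8_t), &sum, &sumsq);
    double mean = (double)sum / (double)n;
    double variance = (double)sumsq / (double)n - mean * mean; /* E[x^2] - E[x]^2 */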
+
+ NK_INTERNAL void nk_reduce_minmax_u8_neon_contiguous_( //
+ nk_u8_t const *data_ptr, nk_size_t count, //
+ nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint8x16_t min_u8x16 = vdupq_n_u8(NK_U8_MAX), max_u8x16 = vdupq_n_u8(0);
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+ nk_size_t idx = 0;
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16_t data_u8x16 = vld1q_u8(data_ptr + idx);
+ uint8x16_t less_u8x16 = vcltq_u8(data_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ }
+ nk_size_t remaining = count - idx;
+ if (remaining > 0) {
+ nk_b128_vec_t tail_vec;
+ nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
+ uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, tail_vec.u8x16, vdupq_n_u8(NK_U8_MAX));
+ uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, tail_vec.u8x16, vdupq_n_u8(0));
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ }
+ nk_u8_t min_value = vminvq_u8(min_u8x16), max_value = vmaxvq_u8(max_u8x16);
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_value));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_value));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = min_value, *min_index_ptr = min_idx;
+ *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ }
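The epilogue above recovers the first occurrence in two reductions: the earliest cycle among lanes that hold the winning value, then the lowest lane among those matching both value and cycle. A scalar equivalent for the minimum side (editorial sketch; min_vec and min_iter stand for the per-lane values and cycle tags):

    uint8_t best_cycle = 0xFF, best_lane = 0xFF;
    for (int l = 0; l < 16; ++l) /* earliest cycle holding the winning value */
        if (min_vec[l] == min_value && min_iter[l] < best_cycle) best_cycle = min_iter[l];
    for (int l = 0; l < 16; ++l) /* lowest lane within that cycle */
        if (min_vec[l] == min_value && min_iter[l] == best_cycle) { best_lane = (uint8_t)l; break; }
    size_t min_idx = (size_t)best_cycle * 16 + best_lane;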
+
+ NK_INTERNAL void nk_reduce_minmax_u8_neon_strided_( //
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint8x16_t min_u8x16 = vdupq_n_u8(NK_U8_MAX), max_u8x16 = vdupq_n_u8(0);
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ nk_size_t idx = 0;
+ uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
+
+ nk_reduce_minmax_u8_neon_cycle:
+ if (stride_elements == 2 && idx + 16 <= count) {
+ uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)data_ptr + idx * 2);
+ data_for_min_u8x16 = loaded.val[0];
+ data_for_max_u8x16 = loaded.val[0];
+ idx += 16;
+ }
+ else if (stride_elements == 3 && idx + 16 <= count) {
+ uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)data_ptr + idx * 3);
+ data_for_min_u8x16 = loaded.val[0];
+ data_for_max_u8x16 = loaded.val[0];
+ idx += 16;
+ }
+ else if (stride_elements == 4 && idx + 16 <= count) {
+ uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)data_ptr + idx * 4);
+ data_for_min_u8x16 = loaded.val[0];
+ data_for_max_u8x16 = loaded.val[0];
+ idx += 16;
+ }
+ else if (idx < count) {
+ nk_b128_vec_t tail_vec;
+ nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
+ data_for_min_u8x16 = vbslq_u8(valid_u8x16, tail_vec.u8x16, min_u8x16);
+ data_for_max_u8x16 = vbslq_u8(valid_u8x16, tail_vec.u8x16, max_u8x16);
+ idx = count;
+ }
+ else {
+ nk_u8_t min_value = vminvq_u8(min_u8x16), max_value = vmaxvq_u8(max_u8x16);
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_value));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_value));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = min_value, *min_index_ptr = min_idx;
+ *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ return;
+ }
+
+ // Shared update body
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ goto nk_reduce_minmax_u8_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_u8_neon( //
+ nk_u8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_u8_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_u8_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u8_t);
+ int aligned = (stride_bytes % sizeof(nk_u8_t) == 0);
+ if (count == 0)
+ *min_value_ptr = NK_U8_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
+ else if (!aligned)
+ nk_reduce_minmax_u8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (count > (nk_size_t)(NK_U8_MAX + 1) * 16) {
+ nk_size_t left_count = count / 2;
+ nk_u8_t left_min_value, right_min_value, left_max_value, right_max_value;
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+ nk_reduce_minmax_u8_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index, &left_max_value,
+ &left_max_index);
+ nk_reduce_minmax_u8_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+ if (right_min_value < left_min_value)
+ *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+ else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+ if (right_max_value > left_max_value)
+ *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+ else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+ }
+ else if (stride_elements == 1)
+ nk_reduce_minmax_u8_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_minmax_u8_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else
+ nk_reduce_minmax_u8_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ }
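Note that the stride is passed in bytes and converted to elements inside the dispatcher, which makes channel extraction from interleaved data a natural fit for the vld4q path. A usage sketch, assuming a hypothetical rgba buffer of pixel_count 4-byte pixels whose red channel is scanned:

    nk_u8_t lo, hi;
    nk_size_t lo_at, hi_at;
    nk_reduce_minmax_u8_neon(rgba, pixel_count, 4 * sizeof(nk_u8_t), /* stride: one pixel */
                             &lo, &lo_at, &hi, &hi_at);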
+
+ NK_INTERNAL void nk_reduce_moments_i16_neon_contiguous_( //
+ nk_i16_t const *data_ptr, nk_size_t count, //
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ int64x2_t sum_i64x2 = vdupq_n_s64(0);
+ uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+ nk_size_t idx = 0;
+ for (; idx + 8 <= count; idx += 8) {
+ int16x8_t data_i16x8 = vld1q_s16(data_ptr + idx);
+ int32x4_t sum32 = vpaddlq_s16(data_i16x8);
+ sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(sum32));
+ // sumsq: widening multiply i16*i16 -> i32, then widen to u64
+ int32x4_t sq_lo = vmull_s16(vget_low_s16(data_i16x8), vget_low_s16(data_i16x8));
+ int32x4_t sq_hi = vmull_high_s16(data_i16x8, data_i16x8);
+ // i16*i16 squares are always non-negative, safe to reinterpret as u32
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(sq_lo)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(sq_hi)));
+ }
+ nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+ for (; idx < count; ++idx) {
+ nk_i64_t value_i64 = (nk_i64_t)data_ptr[idx];
+ sum += value_i64;
+ sumsq += (nk_u64_t)(value_i64 * value_i64);
+ }
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
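The reinterpret in the loop above is justified by a worst-case bound: |x| <= 2^15 for any nk_i16_t, so x * x <= 2^30 < 2^31, meaning the sign bit of the 32-bit square is always clear and vreinterpretq_u32_s32 is value-preserving. An editorial spot check of that bound:

    #include <assert.h>
    #include <stdint.h>

    int main(void) {
        int32_t worst = (int32_t)INT16_MIN * (int32_t)INT16_MIN;
        assert(worst == 1073741824 && worst > 0); /* 2^30: top bit clear */
        return 0;
    }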
+
+ NK_INTERNAL void nk_reduce_moments_i16_neon_strided_( //
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ int64x2_t sum_i64x2 = vdupq_n_s64(0);
+ uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+ nk_size_t idx = 0;
+ if (stride_elements == 2) {
+ for (; idx + 8 <= count; idx += 8) {
+ int16x8x2_t loaded_i16x8x2 = vld2q_s16(data_ptr + idx * 2);
+ int16x8_t data_i16x8 = loaded_i16x8x2.val[0];
+ int32x4_t pairwise_i32x4 = vpaddlq_s16(data_i16x8);
+ sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(pairwise_i32x4));
+ int32x4_t squares_lo_i32x4 = vmull_s16(vget_low_s16(data_i16x8), vget_low_s16(data_i16x8));
+ int32x4_t squares_hi_i32x4 = vmull_high_s16(data_i16x8, data_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_i32x4)));
+ }
+ }
+ else if (stride_elements == 3) {
+ for (; idx + 8 <= count; idx += 8) {
+ int16x8x3_t loaded_i16x8x3 = vld3q_s16(data_ptr + idx * 3);
+ int16x8_t data_i16x8 = loaded_i16x8x3.val[0];
+ int32x4_t pairwise_i32x4 = vpaddlq_s16(data_i16x8);
+ sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(pairwise_i32x4));
+ int32x4_t squares_lo_i32x4 = vmull_s16(vget_low_s16(data_i16x8), vget_low_s16(data_i16x8));
+ int32x4_t squares_hi_i32x4 = vmull_high_s16(data_i16x8, data_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_i32x4)));
+ }
+ }
+ else {
+ for (; idx + 8 <= count; idx += 8) {
+ int16x8x4_t loaded_i16x8x4 = vld4q_s16(data_ptr + idx * 4);
+ int16x8_t data_i16x8 = loaded_i16x8x4.val[0];
+ int32x4_t pairwise_i32x4 = vpaddlq_s16(data_i16x8);
+ sum_i64x2 = vaddq_s64(sum_i64x2, vpaddlq_s32(pairwise_i32x4));
+ int32x4_t squares_lo_i32x4 = vmull_s16(vget_low_s16(data_i16x8), vget_low_s16(data_i16x8));
+ int32x4_t squares_hi_i32x4 = vmull_high_s16(data_i16x8, data_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_i32x4)));
+ }
+ }
+ nk_i64_t sum = vgetq_lane_s64(sum_i64x2, 0) + vgetq_lane_s64(sum_i64x2, 1);
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+ for (; idx < count; ++idx) {
+ nk_i64_t value_i64 = (nk_i64_t)data_ptr[idx * stride_elements];
+ sum += value_i64, sumsq += (nk_u64_t)(value_i64 * value_i64);
+ }
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_i16_neon( //
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
+ int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+ else if (!aligned) nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
+ nk_size_t left_count = count / 2;
+ nk_i64_t left_sum, right_sum;
+ nk_u64_t left_sumsq, right_sumsq;
+ nk_reduce_moments_i16_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+ nk_reduce_moments_i16_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_sum, &right_sumsq);
+ *sum_ptr = nk_i64_saturating_add_serial(left_sum, right_sum);
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
+ }
+ else if (stride_elements == 1) nk_reduce_moments_i16_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_moments_i16_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+ else nk_reduce_moments_i16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_i16_neon_contiguous_( //
+ nk_i16_t const *data_ptr, nk_size_t count, //
+ nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ int16x8_t min_i16x8 = vdupq_n_s16(NK_I16_MAX), max_i16x8 = vdupq_n_s16(NK_I16_MIN);
+ uint16x8_t min_iter_u16x8 = vdupq_n_u16(0), max_iter_u16x8 = vdupq_n_u16(0);
+ uint16x8_t iter_u16x8 = vdupq_n_u16(0), one_u16x8 = vdupq_n_u16(1);
+ nk_size_t idx = 0;
+ for (; idx + 8 <= count; idx += 8) {
+ int16x8_t data_i16x8 = vld1q_s16(data_ptr + idx);
+ uint16x8_t less_u16x8 = vcltq_s16(data_i16x8, min_i16x8);
+ uint16x8_t greater_u16x8 = vcgtq_s16(data_i16x8, max_i16x8);
+ min_i16x8 = vbslq_s16(less_u16x8, data_i16x8, min_i16x8);
+ max_i16x8 = vbslq_s16(greater_u16x8, data_i16x8, max_i16x8);
+ min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
+ max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
+ iter_u16x8 = vaddq_u16(iter_u16x8, one_u16x8);
+ }
+ nk_size_t remaining = count - idx;
+ if (remaining > 0) {
+ nk_b128_vec_t tail_vec;
+ nk_partial_load_b16x8_serial_(data_ptr + idx, &tail_vec, remaining);
+ uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
+ vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
+ uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((uint16_t)remaining));
+ int16x8_t data_for_min_i16x8 = vbslq_s16(valid_u16x8, tail_vec.i16x8, vdupq_n_s16(NK_I16_MAX));
+ int16x8_t data_for_max_i16x8 = vbslq_s16(valid_u16x8, tail_vec.i16x8, vdupq_n_s16(NK_I16_MIN));
+ uint16x8_t less_u16x8 = vcltq_s16(data_for_min_i16x8, min_i16x8);
+ uint16x8_t greater_u16x8 = vcgtq_s16(data_for_max_i16x8, max_i16x8);
+ min_i16x8 = vbslq_s16(less_u16x8, data_for_min_i16x8, min_i16x8);
+ max_i16x8 = vbslq_s16(greater_u16x8, data_for_max_i16x8, max_i16x8);
+ min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
+ max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
+ }
+ nk_i16_t min_value = vminvq_s16(min_i16x8), max_value = vmaxvq_s16(max_i16x8);
+ uint16x8_t min_value_match_u16x8 = vceqq_s16(min_i16x8, vdupq_n_s16(min_value));
+ uint16x8_t masked_min_iter_u16x8 = vbslq_u16(min_value_match_u16x8, min_iter_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t earliest_min_cycle = vminvq_u16(masked_min_iter_u16x8);
+ uint16x8_t max_value_match_u16x8 = vceqq_s16(max_i16x8, vdupq_n_s16(max_value));
+ uint16x8_t masked_max_iter_u16x8 = vbslq_u16(max_value_match_u16x8, max_iter_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t earliest_max_cycle = vminvq_u16(masked_max_iter_u16x8);
+ uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
+ vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
+ uint16x8_t min_cycle_match_u16x8 = vceqq_u16(min_iter_u16x8, vdupq_n_u16(earliest_min_cycle));
+ uint16x8_t min_both_match_u16x8 = vandq_u16(min_value_match_u16x8, min_cycle_match_u16x8);
+ uint16x8_t min_masked_lanes_u16x8 = vbslq_u16(min_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t min_lane_offset = vminvq_u16(min_masked_lanes_u16x8);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 8 + (nk_size_t)min_lane_offset;
+ uint16x8_t max_cycle_match_u16x8 = vceqq_u16(max_iter_u16x8, vdupq_n_u16(earliest_max_cycle));
+ uint16x8_t max_both_match_u16x8 = vandq_u16(max_value_match_u16x8, max_cycle_match_u16x8);
+ uint16x8_t max_masked_lanes_u16x8 = vbslq_u16(max_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t max_lane_offset = vminvq_u16(max_masked_lanes_u16x8);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 8 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = min_value, *min_index_ptr = min_idx;
+ *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_i16_neon_strided_( //
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ int16x8_t min_i16x8 = vdupq_n_s16(NK_I16_MAX), max_i16x8 = vdupq_n_s16(NK_I16_MIN);
+ uint16x8_t min_iter_u16x8 = vdupq_n_u16(0), max_iter_u16x8 = vdupq_n_u16(0);
+ uint16x8_t iter_u16x8 = vdupq_n_u16(0), one_u16x8 = vdupq_n_u16(1);
+ uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
+ vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
+ nk_size_t idx = 0;
+ int16x8_t data_for_min_i16x8, data_for_max_i16x8;
+
+ nk_reduce_minmax_i16_neon_cycle:
+ if (stride_elements == 2 && idx + 8 <= count) {
+ int16x8x2_t loaded = vld2q_s16(data_ptr + idx * 2);
+ data_for_min_i16x8 = loaded.val[0];
+ data_for_max_i16x8 = loaded.val[0];
+ idx += 8;
+ }
+ else if (stride_elements == 3 && idx + 8 <= count) {
+ int16x8x3_t loaded = vld3q_s16(data_ptr + idx * 3);
+ data_for_min_i16x8 = loaded.val[0];
+ data_for_max_i16x8 = loaded.val[0];
+ idx += 8;
+ }
+ else if (stride_elements == 4 && idx + 8 <= count) {
+ int16x8x4_t loaded = vld4q_s16(data_ptr + idx * 4);
+ data_for_min_i16x8 = loaded.val[0];
+ data_for_max_i16x8 = loaded.val[0];
+ idx += 8;
+ }
+ else if (idx < count) {
+ nk_b128_vec_t tail_vec;
+ nk_strided_load_b16x8_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+ uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((uint16_t)(count - idx)));
+ data_for_min_i16x8 = vbslq_s16(valid_u16x8, tail_vec.i16x8, min_i16x8);
+ data_for_max_i16x8 = vbslq_s16(valid_u16x8, tail_vec.i16x8, max_i16x8);
+ idx = count;
+ }
+ else {
+ nk_i16_t min_value = vminvq_s16(min_i16x8), max_value = vmaxvq_s16(max_i16x8);
+ uint16x8_t min_value_match_u16x8 = vceqq_s16(min_i16x8, vdupq_n_s16(min_value));
+ uint16x8_t masked_min_iter_u16x8 = vbslq_u16(min_value_match_u16x8, min_iter_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t earliest_min_cycle = vminvq_u16(masked_min_iter_u16x8);
+ uint16x8_t max_value_match_u16x8 = vceqq_s16(max_i16x8, vdupq_n_s16(max_value));
+ uint16x8_t masked_max_iter_u16x8 = vbslq_u16(max_value_match_u16x8, max_iter_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t earliest_max_cycle = vminvq_u16(masked_max_iter_u16x8);
+ uint16x8_t min_cycle_match_u16x8 = vceqq_u16(min_iter_u16x8, vdupq_n_u16(earliest_min_cycle));
+ uint16x8_t min_both_match_u16x8 = vandq_u16(min_value_match_u16x8, min_cycle_match_u16x8);
+ uint16x8_t min_masked_lanes_u16x8 = vbslq_u16(min_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t min_lane_offset = vminvq_u16(min_masked_lanes_u16x8);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 8 + (nk_size_t)min_lane_offset;
+ uint16x8_t max_cycle_match_u16x8 = vceqq_u16(max_iter_u16x8, vdupq_n_u16(earliest_max_cycle));
+ uint16x8_t max_both_match_u16x8 = vandq_u16(max_value_match_u16x8, max_cycle_match_u16x8);
+ uint16x8_t max_masked_lanes_u16x8 = vbslq_u16(max_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t max_lane_offset = vminvq_u16(max_masked_lanes_u16x8);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 8 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = min_value, *min_index_ptr = min_idx;
+ *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ return;
+ }
+
+ // Shared update body
+ uint16x8_t less_u16x8 = vcltq_s16(data_for_min_i16x8, min_i16x8);
+ uint16x8_t greater_u16x8 = vcgtq_s16(data_for_max_i16x8, max_i16x8);
+ min_i16x8 = vbslq_s16(less_u16x8, data_for_min_i16x8, min_i16x8);
+ max_i16x8 = vbslq_s16(greater_u16x8, data_for_max_i16x8, max_i16x8);
+ min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
+ max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
+ iter_u16x8 = vaddq_u16(iter_u16x8, one_u16x8);
+ goto nk_reduce_minmax_i16_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_i16_neon( //
+ nk_i16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_i16_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_i16_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i16_t);
+ int aligned = (stride_bytes % sizeof(nk_i16_t) == 0);
+ if (count == 0)
+ *min_value_ptr = NK_I16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I16_MIN,
+ *max_index_ptr = NK_SIZE_MAX;
+ else if (!aligned)
+ nk_reduce_minmax_i16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
+ nk_size_t left_count = count / 2;
+ nk_i16_t left_min_value, right_min_value, left_max_value, right_max_value;
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+ nk_reduce_minmax_i16_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index, &left_max_value,
+ &left_max_index);
+ nk_reduce_minmax_i16_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+ if (right_min_value < left_min_value)
+ *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+ else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+ if (right_max_value > left_max_value)
+ *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+ else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+ }
+ else if (stride_elements == 1)
+ nk_reduce_minmax_i16_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_minmax_i16_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+ max_value_ptr, max_index_ptr);
+ else
+ nk_reduce_minmax_i16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_moments_u16_neon_contiguous_( //
+ nk_u16_t const *data_ptr, nk_size_t count, //
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ uint64x2_t sum_u64x2 = vdupq_n_u64(0);
+ uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+ nk_size_t idx = 0;
+ for (; idx + 8 <= count; idx += 8) {
+ uint16x8_t data_u16x8 = vld1q_u16(data_ptr + idx);
+ uint32x4_t sum32 = vpaddlq_u16(data_u16x8);
+ sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(sum32));
+ uint32x4_t sq_lo = vmull_u16(vget_low_u16(data_u16x8), vget_low_u16(data_u16x8));
+ uint32x4_t sq_hi = vmull_high_u16(data_u16x8, data_u16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_lo));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_hi));
+ }
+ nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+ for (; idx < count; ++idx) {
+ nk_u64_t value = (nk_u64_t)data_ptr[idx];
+ sum += value, sumsq += value * value;
+ }
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_INTERNAL void nk_reduce_moments_u16_neon_strided_( //
+ nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ uint64x2_t sum_u64x2 = vdupq_n_u64(0);
+ uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+ nk_size_t idx = 0;
+
+ if (stride_elements == 2) {
+ for (; idx + 8 <= count; idx += 8) {
+ uint16x8x2_t loaded_u16x8x2 = vld2q_u16(data_ptr + idx * 2);
+ uint16x8_t data_u16x8 = loaded_u16x8x2.val[0];
+ uint32x4_t widened_sum_u32x4 = vpaddlq_u16(data_u16x8);
+ sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(widened_sum_u32x4));
+ uint32x4_t sq_lo_u32x4 = vmull_u16(vget_low_u16(data_u16x8), vget_low_u16(data_u16x8));
+ uint32x4_t sq_hi_u32x4 = vmull_high_u16(data_u16x8, data_u16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_lo_u32x4));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_hi_u32x4));
+ }
+ }
+ else if (stride_elements == 3) {
+ for (; idx + 8 <= count; idx += 8) {
+ uint16x8x3_t loaded_u16x8x3 = vld3q_u16(data_ptr + idx * 3);
+ uint16x8_t data_u16x8 = loaded_u16x8x3.val[0];
+ uint32x4_t widened_sum_u32x4 = vpaddlq_u16(data_u16x8);
+ sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(widened_sum_u32x4));
+ uint32x4_t sq_lo_u32x4 = vmull_u16(vget_low_u16(data_u16x8), vget_low_u16(data_u16x8));
+ uint32x4_t sq_hi_u32x4 = vmull_high_u16(data_u16x8, data_u16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_lo_u32x4));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_hi_u32x4));
+ }
+ }
+ else {
+ for (; idx + 8 <= count; idx += 8) {
+ uint16x8x4_t loaded_u16x8x4 = vld4q_u16(data_ptr + idx * 4);
+ uint16x8_t data_u16x8 = loaded_u16x8x4.val[0];
+ uint32x4_t widened_sum_u32x4 = vpaddlq_u16(data_u16x8);
+ sum_u64x2 = vaddq_u64(sum_u64x2, vpaddlq_u32(widened_sum_u32x4));
+ uint32x4_t sq_lo_u32x4 = vmull_u16(vget_low_u16(data_u16x8), vget_low_u16(data_u16x8));
+ uint32x4_t sq_hi_u32x4 = vmull_high_u16(data_u16x8, data_u16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_lo_u32x4));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(sq_hi_u32x4));
+ }
+ }
+
+ nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+ for (; idx < count; ++idx) {
+ nk_u64_t value = (nk_u64_t)data_ptr[idx * stride_elements];
+ sum += value, sumsq += value * value;
+ }
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_u16_neon( //
+ nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
+ int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+ else if (!aligned) nk_reduce_moments_u16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
+ nk_size_t left_count = count / 2;
+ nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
+ nk_reduce_moments_u16_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+ nk_reduce_moments_u16_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_sum, &right_sumsq);
+ *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
+ *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
+ }
+ else if (stride_elements == 1) nk_reduce_moments_u16_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_moments_u16_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+ else nk_reduce_moments_u16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_u16_neon_contiguous_( //
+ nk_u16_t const *data_ptr, nk_size_t count, //
+ nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint16x8_t min_u16x8 = vdupq_n_u16(NK_U16_MAX), max_u16x8 = vdupq_n_u16(0);
+ uint16x8_t min_iter_u16x8 = vdupq_n_u16(0), max_iter_u16x8 = vdupq_n_u16(0);
+ uint16x8_t iter_u16x8 = vdupq_n_u16(0), one_u16x8 = vdupq_n_u16(1);
+ nk_size_t idx = 0;
+ for (; idx + 8 <= count; idx += 8) {
+ uint16x8_t data_u16x8 = vld1q_u16(data_ptr + idx);
+ uint16x8_t less_u16x8 = vcltq_u16(data_u16x8, min_u16x8);
+ uint16x8_t greater_u16x8 = vcgtq_u16(data_u16x8, max_u16x8);
+ min_u16x8 = vbslq_u16(less_u16x8, data_u16x8, min_u16x8);
+ max_u16x8 = vbslq_u16(greater_u16x8, data_u16x8, max_u16x8);
+ min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
+ max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
+ iter_u16x8 = vaddq_u16(iter_u16x8, one_u16x8);
+ }
+ nk_size_t remaining = count - idx;
+ if (remaining > 0) {
+ nk_b128_vec_t tail_vec;
+ nk_partial_load_b16x8_serial_(data_ptr + idx, &tail_vec, remaining);
+ uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
+ vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
+ uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((uint16_t)remaining));
+ uint16x8_t data_for_min_u16x8 = vbslq_u16(valid_u16x8, tail_vec.u16x8, vdupq_n_u16(NK_U16_MAX));
+ uint16x8_t data_for_max_u16x8 = vbslq_u16(valid_u16x8, tail_vec.u16x8, vdupq_n_u16(0));
+ uint16x8_t less_u16x8 = vcltq_u16(data_for_min_u16x8, min_u16x8);
+ uint16x8_t greater_u16x8 = vcgtq_u16(data_for_max_u16x8, max_u16x8);
+ min_u16x8 = vbslq_u16(less_u16x8, data_for_min_u16x8, min_u16x8);
+ max_u16x8 = vbslq_u16(greater_u16x8, data_for_max_u16x8, max_u16x8);
+ min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
+ max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
+ }
+ nk_u16_t min_value = vminvq_u16(min_u16x8), max_value = vmaxvq_u16(max_u16x8);
+ uint16x8_t min_value_match_u16x8 = vceqq_u16(min_u16x8, vdupq_n_u16(min_value));
+ uint16x8_t masked_min_iter_u16x8 = vbslq_u16(min_value_match_u16x8, min_iter_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t earliest_min_cycle = vminvq_u16(masked_min_iter_u16x8);
+ uint16x8_t max_value_match_u16x8 = vceqq_u16(max_u16x8, vdupq_n_u16(max_value));
+ uint16x8_t masked_max_iter_u16x8 = vbslq_u16(max_value_match_u16x8, max_iter_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t earliest_max_cycle = vminvq_u16(masked_max_iter_u16x8);
+ uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
+ vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
+ uint16x8_t min_cycle_match_u16x8 = vceqq_u16(min_iter_u16x8, vdupq_n_u16(earliest_min_cycle));
+ uint16x8_t min_both_match_u16x8 = vandq_u16(min_value_match_u16x8, min_cycle_match_u16x8);
+ uint16x8_t min_masked_lanes_u16x8 = vbslq_u16(min_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t min_lane_offset = vminvq_u16(min_masked_lanes_u16x8);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 8 + (nk_size_t)min_lane_offset;
+ uint16x8_t max_cycle_match_u16x8 = vceqq_u16(max_iter_u16x8, vdupq_n_u16(earliest_max_cycle));
+ uint16x8_t max_both_match_u16x8 = vandq_u16(max_value_match_u16x8, max_cycle_match_u16x8);
+ uint16x8_t max_masked_lanes_u16x8 = vbslq_u16(max_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t max_lane_offset = vminvq_u16(max_masked_lanes_u16x8);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 8 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = min_value, *min_index_ptr = min_idx;
+ *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_u16_neon_strided_( //
+ nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint16x8_t min_u16x8 = vdupq_n_u16(NK_U16_MAX), max_u16x8 = vdupq_n_u16(0);
+ uint16x8_t min_iter_u16x8 = vdupq_n_u16(0), max_iter_u16x8 = vdupq_n_u16(0);
+ uint16x8_t iter_u16x8 = vdupq_n_u16(0), one_u16x8 = vdupq_n_u16(1);
+ uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
+ vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
+ nk_size_t idx = 0;
+ uint16x8_t data_for_min_u16x8, data_for_max_u16x8;
+
+ nk_reduce_minmax_u16_neon_cycle:
+ if (stride_elements == 2 && idx + 8 <= count) {
+ uint16x8x2_t loaded = vld2q_u16((nk_u16_t const *)data_ptr + idx * 2);
+ data_for_min_u16x8 = loaded.val[0];
+ data_for_max_u16x8 = loaded.val[0];
+ idx += 8;
+ }
+ else if (stride_elements == 3 && idx + 8 <= count) {
+ uint16x8x3_t loaded = vld3q_u16((nk_u16_t const *)data_ptr + idx * 3);
+ data_for_min_u16x8 = loaded.val[0];
+ data_for_max_u16x8 = loaded.val[0];
+ idx += 8;
+ }
+ else if (stride_elements == 4 && idx + 8 <= count) {
+ uint16x8x4_t loaded = vld4q_u16((nk_u16_t const *)data_ptr + idx * 4);
+ data_for_min_u16x8 = loaded.val[0];
+ data_for_max_u16x8 = loaded.val[0];
+ idx += 8;
+ }
+ else if (idx < count) {
+ nk_b128_vec_t tail_vec;
+ nk_strided_load_b16x8_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+ uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((uint16_t)(count - idx)));
+ data_for_min_u16x8 = vbslq_u16(valid_u16x8, tail_vec.u16x8, min_u16x8);
+ data_for_max_u16x8 = vbslq_u16(valid_u16x8, tail_vec.u16x8, max_u16x8);
+ idx = count;
+ }
+ else {
+ nk_u16_t min_value = vminvq_u16(min_u16x8), max_value = vmaxvq_u16(max_u16x8);
+ uint16x8_t min_value_match_u16x8 = vceqq_u16(min_u16x8, vdupq_n_u16(min_value));
+ uint16x8_t masked_min_iter_u16x8 = vbslq_u16(min_value_match_u16x8, min_iter_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t earliest_min_cycle = vminvq_u16(masked_min_iter_u16x8);
+ uint16x8_t max_value_match_u16x8 = vceqq_u16(max_u16x8, vdupq_n_u16(max_value));
+ uint16x8_t masked_max_iter_u16x8 = vbslq_u16(max_value_match_u16x8, max_iter_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t earliest_max_cycle = vminvq_u16(masked_max_iter_u16x8);
+ uint16x8_t min_cycle_match_u16x8 = vceqq_u16(min_iter_u16x8, vdupq_n_u16(earliest_min_cycle));
+ uint16x8_t min_both_match_u16x8 = vandq_u16(min_value_match_u16x8, min_cycle_match_u16x8);
+ uint16x8_t min_masked_lanes_u16x8 = vbslq_u16(min_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t min_lane_offset = vminvq_u16(min_masked_lanes_u16x8);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 8 + (nk_size_t)min_lane_offset;
+ uint16x8_t max_cycle_match_u16x8 = vceqq_u16(max_iter_u16x8, vdupq_n_u16(earliest_max_cycle));
+ uint16x8_t max_both_match_u16x8 = vandq_u16(max_value_match_u16x8, max_cycle_match_u16x8);
+ uint16x8_t max_masked_lanes_u16x8 = vbslq_u16(max_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
+ nk_u16_t max_lane_offset = vminvq_u16(max_masked_lanes_u16x8);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 8 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = min_value, *min_index_ptr = min_idx;
+ *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ return;
+ }
+
+ // Shared update body
+ uint16x8_t less_u16x8 = vcltq_u16(data_for_min_u16x8, min_u16x8);
+ uint16x8_t greater_u16x8 = vcgtq_u16(data_for_max_u16x8, max_u16x8);
+ min_u16x8 = vbslq_u16(less_u16x8, data_for_min_u16x8, min_u16x8);
+ max_u16x8 = vbslq_u16(greater_u16x8, data_for_max_u16x8, max_u16x8);
+ min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
+ max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
+ iter_u16x8 = vaddq_u16(iter_u16x8, one_u16x8);
+ goto nk_reduce_minmax_u16_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_u16_neon( //
+ nk_u16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_u16_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_u16_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_u16_t);
+ int aligned = (stride_bytes % sizeof(nk_u16_t) == 0);
+ if (count == 0)
+ *min_value_ptr = NK_U16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
+ else if (!aligned)
+ nk_reduce_minmax_u16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
+ nk_size_t left_count = count / 2;
+ nk_u16_t left_min_value, right_min_value, left_max_value, right_max_value;
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+ nk_reduce_minmax_u16_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index, &left_max_value,
+ &left_max_index);
+ nk_reduce_minmax_u16_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+ if (right_min_value < left_min_value)
+ *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+ else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+ if (right_max_value > left_max_value)
+ *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+ else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+ }
+ else if (stride_elements == 1)
+ nk_reduce_minmax_u16_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_minmax_u16_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+ max_value_ptr, max_index_ptr);
+ else
+ nk_reduce_minmax_u16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_moments_i32_neon_contiguous_( //
+ nk_i32_t const *data_ptr, nk_size_t count, //
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ // 128-bit accumulation: lower (u64) + upper (i64) per lane
+ uint64x2_t sum_lower_u64x2 = vdupq_n_u64(0);
+ int64x2_t sum_upper_i64x2 = vdupq_n_s64(0);
+ uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+ int sumsq_overflow = 0;
+ // XOR sign-bit trick for unsigned u64 compare on NEON
+ int64x2_t sign_bit_i64x2 = vdupq_n_s64((nk_i64_t)0x8000000000000000ULL);
+ nk_size_t idx = 0;
+ for (; idx + 4 <= count; idx += 4) {
+ int32x4_t data_i32x4 = vld1q_s32(data_ptr + idx);
+ // Sum: widen i32->i64 and accumulate with carry detection
+ int64x2_t data_low_i64x2 = vmovl_s32(vget_low_s32(data_i32x4));
+ uint64x2_t before = sum_lower_u64x2;
+ sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(data_low_i64x2));
+ int64x2_t result_biased = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+ int64x2_t before_biased = veorq_s64(vreinterpretq_s64_u64(before), sign_bit_i64x2);
+ uint64x2_t carry = vcgtq_s64(before_biased, result_biased);
+ sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry));
+ sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, vshrq_n_s64(data_low_i64x2, 63));
+
+ int64x2_t data_high_i64x2 = vmovl_high_s32(data_i32x4);
+ before = sum_lower_u64x2;
+ sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(data_high_i64x2));
+ result_biased = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+ before_biased = veorq_s64(vreinterpretq_s64_u64(before), sign_bit_i64x2);
+ carry = vcgtq_s64(before_biased, result_biased);
+ sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry));
+ sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, vshrq_n_s64(data_high_i64x2, 63));
+
+ // Sumsq: widening multiply i32*i32 -> i64 (always non-negative for squares)
+ int64x2_t sq_lo = vmull_s32(vget_low_s32(data_i32x4), vget_low_s32(data_i32x4));
+ int64x2_t sq_hi = vmull_high_s32(data_i32x4, data_i32x4);
+ uint64x2_t sq_before = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(sq_lo));
+ result_biased = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased = veorq_s64(vreinterpretq_s64_u64(sq_before), sign_bit_i64x2);
+ sumsq_overflow |= (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased, result_biased)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased, result_biased)), 1));
+ sq_before = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(sq_hi));
+ result_biased = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased = veorq_s64(vreinterpretq_s64_u64(sq_before), sign_bit_i64x2);
+ sumsq_overflow |= (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased, result_biased)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased, result_biased)), 1));
+ }
+ // Sumsq horizontal saturating reduction
+ nk_u64_t sumsq;
+ if (sumsq_overflow) sumsq = NK_U64_MAX;
+ else sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
+ // Sum: horizontal 128-bit reduction (2 lanes -> scalar)
+ nk_b128_vec_t lower_vec, upper_vec;
+ lower_vec.u64x2 = sum_lower_u64x2;
+ upper_vec.i64x2 = sum_upper_i64x2;
+ nk_u64_t sum_lower = 0;
+ nk_i64_t sum_upper = 0;
+ nk_u64_t sum_before = sum_lower;
+ sum_lower += lower_vec.u64s[0], sum_upper += (sum_lower < sum_before) + upper_vec.i64s[0];
+ sum_before = sum_lower;
+ sum_lower += lower_vec.u64s[1], sum_upper += (sum_lower < sum_before) + upper_vec.i64s[1];
+ // Scalar tail
+ for (; idx < count; ++idx) {
+ nk_i64_t value_i64 = (nk_i64_t)data_ptr[idx];
+ sum_before = sum_lower;
+ sum_lower += (nk_u64_t)value_i64;
+ if (sum_lower < sum_before) sum_upper++;
+ sum_upper += (value_i64 >> 63);
+ nk_i64_t product = nk_i64_saturating_mul_serial(value_i64, value_i64);
+ nk_u64_t unsigned_product = (nk_u64_t)product;
+ sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
+ }
+ // Clamp 128-bit sum to i64 range
+ nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
+ if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
+ else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
+ else *sum_ptr = NK_I64_MIN;
+ *sumsq_ptr = sumsq;
+ }
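Two scalar identities drive the vector loop above. First, the signed vcgtq_s64 can stand in for an unsigned 64-bit compare once both operands are XOR-ed with the sign bit, because x >u y holds exactly when (x ^ 2^63) >s (y ^ 2^63). Second, an unsigned add carried exactly when the result is below either operand; the upper limb then absorbs the carry (subtracting the all-ones compare mask adds one) plus the operand's sign, replicated by the arithmetic shift right by 63. A scalar model of the carry detection (editorial sketch, not part of the package source):

    #include <stdint.h>

    static inline int u64_add_carried(uint64_t before, uint64_t after) {
        int64_t b = (int64_t)(before ^ 0x8000000000000000ULL); /* bias by sign bit */
        int64_t a = (int64_t)(after ^ 0x8000000000000000ULL);
        return b > a; /* before >u after, i.e. the add wrapped around */
    }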
+
+ NK_INTERNAL void nk_reduce_moments_i32_neon_strided_( //
+ nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+ uint64x2_t sum_lower_u64x2 = vdupq_n_u64(0);
+ int64x2_t sum_upper_i64x2 = vdupq_n_s64(0);
+ uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+ int sumsq_overflow = 0;
+ int64x2_t sign_bit_i64x2 = vdupq_n_s64((nk_i64_t)0x8000000000000000ULL);
+ nk_size_t idx = 0;
+ if (stride_elements == 2) {
+ for (; idx + 4 <= count; idx += 4) {
+ int32x4x2_t loaded_i32x4x2 = vld2q_s32(data_ptr + idx * 2);
+ int32x4_t data_i32x4 = loaded_i32x4x2.val[0];
+ int64x2_t lo_i64x2 = vmovl_s32(vget_low_s32(data_i32x4));
+ uint64x2_t before_u64x2 = sum_lower_u64x2;
+ sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(lo_i64x2));
+ int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+ int64x2_t before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
+ uint64x2_t carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
+ sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, vshrq_n_s64(lo_i64x2, 63));
+ int64x2_t hi_i64x2 = vmovl_high_s32(data_i32x4);
+ before_u64x2 = sum_lower_u64x2;
+ sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(hi_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
+ carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
+ sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, vshrq_n_s64(hi_i64x2, 63));
+ int64x2_t squares_lo_i64x2 = vmull_s32(vget_low_s32(data_i32x4), vget_low_s32(data_i32x4));
+ int64x2_t squares_hi_i64x2 = vmull_high_s32(data_i32x4, data_i32x4);
+ uint64x2_t sq_before_u64x2 = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_lo_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
+ sumsq_overflow |=
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
+ sq_before_u64x2 = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_hi_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
+ sumsq_overflow |=
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
+ }
+ }
+ else if (stride_elements == 3) {
+ for (; idx + 4 <= count; idx += 4) {
+ int32x4x3_t loaded_i32x4x3 = vld3q_s32(data_ptr + idx * 3);
+ int32x4_t data_i32x4 = loaded_i32x4x3.val[0];
+ int64x2_t lo_i64x2 = vmovl_s32(vget_low_s32(data_i32x4));
+ uint64x2_t before_u64x2 = sum_lower_u64x2;
+ sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(lo_i64x2));
+ int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+ int64x2_t before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
+ uint64x2_t carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
+ sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, vshrq_n_s64(lo_i64x2, 63));
+ int64x2_t hi_i64x2 = vmovl_high_s32(data_i32x4);
+ before_u64x2 = sum_lower_u64x2;
+ sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(hi_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
+ carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
+ sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, vshrq_n_s64(hi_i64x2, 63));
+ int64x2_t squares_lo_i64x2 = vmull_s32(vget_low_s32(data_i32x4), vget_low_s32(data_i32x4));
+ int64x2_t squares_hi_i64x2 = vmull_high_s32(data_i32x4, data_i32x4);
+ uint64x2_t sq_before_u64x2 = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_lo_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
+ sumsq_overflow |=
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
+ sq_before_u64x2 = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_hi_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
+ sumsq_overflow |=
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
+ }
+ }
+ else {
+ for (; idx + 4 <= count; idx += 4) {
+ int32x4x4_t loaded_i32x4x4 = vld4q_s32(data_ptr + idx * 4);
+ int32x4_t data_i32x4 = loaded_i32x4x4.val[0];
+ int64x2_t lo_i64x2 = vmovl_s32(vget_low_s32(data_i32x4));
+ uint64x2_t before_u64x2 = sum_lower_u64x2;
+ sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(lo_i64x2));
+ int64x2_t result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+ int64x2_t before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
+ uint64x2_t carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
+ sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, vshrq_n_s64(lo_i64x2, 63));
+ int64x2_t hi_i64x2 = vmovl_high_s32(data_i32x4);
+ before_u64x2 = sum_lower_u64x2;
+ sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(hi_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(before_u64x2), sign_bit_i64x2);
+ carry_u64x2 = vcgtq_s64(before_biased_i64x2, result_biased_i64x2);
+ sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+ sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, vshrq_n_s64(hi_i64x2, 63));
+ int64x2_t squares_lo_i64x2 = vmull_s32(vget_low_s32(data_i32x4), vget_low_s32(data_i32x4));
+ int64x2_t squares_hi_i64x2 = vmull_high_s32(data_i32x4, data_i32x4);
+ uint64x2_t sq_before_u64x2 = sumsq_u64x2;
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_lo_i64x2));
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
1718
+ sumsq_overflow |=
1719
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
1720
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
1721
+ sq_before_u64x2 = sumsq_u64x2;
1722
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vreinterpretq_u64_s64(squares_hi_i64x2));
1723
+ result_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
1724
+ before_biased_i64x2 = veorq_s64(vreinterpretq_s64_u64(sq_before_u64x2), sign_bit_i64x2);
1725
+ sumsq_overflow |=
1726
+ (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 0) |
1727
+ vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased_i64x2, result_biased_i64x2)), 1));
1728
+ }
1729
+ }
1730
+ nk_u64_t sumsq;
1731
+ if (sumsq_overflow) sumsq = NK_U64_MAX;
1732
+ else sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
1733
+ nk_b128_vec_t lower_vec, upper_vec;
1734
+ lower_vec.u64x2 = sum_lower_u64x2;
1735
+ upper_vec.i64x2 = sum_upper_i64x2;
1736
+ nk_u64_t sum_lower = 0;
1737
+ nk_i64_t sum_upper = 0;
1738
+ nk_u64_t sum_before = sum_lower;
1739
+ sum_lower += lower_vec.u64s[0], sum_upper += (sum_lower < sum_before) + upper_vec.i64s[0];
1740
+ sum_before = sum_lower;
1741
+ sum_lower += lower_vec.u64s[1], sum_upper += (sum_lower < sum_before) + upper_vec.i64s[1];
1742
+ for (; idx < count; ++idx) {
1743
+ nk_i64_t val = (nk_i64_t) * (data_ptr + idx * stride_elements);
1744
+ sum_before = sum_lower;
1745
+ sum_lower += (nk_u64_t)val;
1746
+ if (sum_lower < sum_before) sum_upper++;
1747
+ sum_upper += (val >> 63);
1748
+ nk_i64_t product = nk_i64_saturating_mul_serial(val, val);
1749
+ nk_u64_t unsigned_product = (nk_u64_t)product;
1750
+ sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
1751
+ }
1752
+ nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
1753
+ if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
1754
+ else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
1755
+ else *sum_ptr = NK_I64_MIN;
1756
+ *sumsq_ptr = sumsq;
1757
+ }
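The moments kernels above all lean on the same two tricks: a 128-bit accumulator split into an unsigned low word and a signed high word, and low-word carry detection via a sign-bit bias so that a signed lane compare (vcgtq_s64) preserves unsigned order, using the identity that (a ^ 0x8000000000000000) compared as signed equals a compared to b as unsigned. A minimal scalar model of the accumulator, as an editorial sketch (the acc128_t name is hypothetical, not part of the package):

    #include <stdint.h>

    typedef struct { uint64_t lower; int64_t upper; } acc128_t; /* hypothetical type */

    static void acc128_add_i64(acc128_t *acc, int64_t value) {
        uint64_t before = acc->lower;
        acc->lower += (uint64_t)value;             /* wrapping add of the low word */
        if (acc->lower < before) acc->upper += 1;  /* carry out of the low word */
        acc->upper += value >> 63;                 /* sign-extend the addend into the high word */
    }

The vector loops perform exactly these three steps per lane; the final saturation to NK_I64_MAX / NK_I64_MIN happens only once, after the horizontal reduction.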
1758
+
1759
+ NK_PUBLIC void nk_reduce_moments_i32_neon( //
1760
+ nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
1761
+ nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
1762
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
1763
+ int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
1764
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
1765
+ else if (!aligned) nk_reduce_moments_i32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
1766
+ else if (stride_elements == 1) nk_reduce_moments_i32_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
1767
+ else if (stride_elements <= 4)
1768
+ nk_reduce_moments_i32_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
1769
+ else nk_reduce_moments_i32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
1770
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_i32_neon_contiguous_( //
+     nk_i32_t const *data_ptr, nk_size_t count, //
+     nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     int32x4_t min_i32x4 = vdupq_n_s32(NK_I32_MAX), max_i32x4 = vdupq_n_s32(NK_I32_MIN);
+     uint32x4_t min_iter_u32x4 = vdupq_n_u32(0), max_iter_u32x4 = vdupq_n_u32(0);
+     uint32x4_t iter_u32x4 = vdupq_n_u32(0), one_u32x4 = vdupq_n_u32(1);
+     nk_size_t idx = 0;
+     for (; idx + 4 <= count; idx += 4) {
+         int32x4_t data_i32x4 = vld1q_s32(data_ptr + idx);
+         uint32x4_t less_u32x4 = vcltq_s32(data_i32x4, min_i32x4);
+         uint32x4_t greater_u32x4 = vcgtq_s32(data_i32x4, max_i32x4);
+         min_i32x4 = vbslq_s32(less_u32x4, data_i32x4, min_i32x4);
+         max_i32x4 = vbslq_s32(greater_u32x4, data_i32x4, max_i32x4);
+         min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+         max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+         iter_u32x4 = vaddq_u32(iter_u32x4, one_u32x4);
+     }
+     nk_size_t remaining = count - idx;
+     if (remaining > 0) {
+         nk_b128_vec_t tail_vec;
+         nk_partial_load_b32x4_serial_(data_ptr + idx, &tail_vec, remaining);
+         uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                                      vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+         uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((uint32_t)remaining));
+         int32x4_t data_min_i32x4 = vbslq_s32(valid_u32x4, tail_vec.i32x4, vdupq_n_s32(NK_I32_MAX));
+         int32x4_t data_max_i32x4 = vbslq_s32(valid_u32x4, tail_vec.i32x4, vdupq_n_s32(NK_I32_MIN));
+         uint32x4_t less_u32x4 = vcltq_s32(data_min_i32x4, min_i32x4);
+         uint32x4_t greater_u32x4 = vcgtq_s32(data_max_i32x4, max_i32x4);
+         min_i32x4 = vbslq_s32(less_u32x4, data_min_i32x4, min_i32x4);
+         max_i32x4 = vbslq_s32(greater_u32x4, data_max_i32x4, max_i32x4);
+         min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+         max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+     }
+     nk_i32_t min_value = vminvq_s32(min_i32x4), max_value = vmaxvq_s32(max_i32x4);
+     uint32x4_t min_value_match_u32x4 = vceqq_s32(min_i32x4, vdupq_n_s32(min_value));
+     uint32x4_t masked_min_iter_u32x4 = vbslq_u32(min_value_match_u32x4, min_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t earliest_min_cycle = vminvq_u32(masked_min_iter_u32x4);
+     uint32x4_t max_value_match_u32x4 = vceqq_s32(max_i32x4, vdupq_n_s32(max_value));
+     uint32x4_t masked_max_iter_u32x4 = vbslq_u32(max_value_match_u32x4, max_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t earliest_max_cycle = vminvq_u32(masked_max_iter_u32x4);
+     uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                                  vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+     uint32x4_t min_cycle_match_u32x4 = vceqq_u32(min_iter_u32x4, vdupq_n_u32(earliest_min_cycle));
+     uint32x4_t min_both_match_u32x4 = vandq_u32(min_value_match_u32x4, min_cycle_match_u32x4);
+     uint32x4_t min_masked_lanes_u32x4 = vbslq_u32(min_both_match_u32x4, lane_indices_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t min_lane_offset = vminvq_u32(min_masked_lanes_u32x4);
+     nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 4 + (nk_size_t)min_lane_offset;
+     uint32x4_t max_cycle_match_u32x4 = vceqq_u32(max_iter_u32x4, vdupq_n_u32(earliest_max_cycle));
+     uint32x4_t max_both_match_u32x4 = vandq_u32(max_value_match_u32x4, max_cycle_match_u32x4);
+     uint32x4_t max_masked_lanes_u32x4 = vbslq_u32(max_both_match_u32x4, lane_indices_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t max_lane_offset = vminvq_u32(max_masked_lanes_u32x4);
+     nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 4 + (nk_size_t)max_lane_offset;
+     *min_value_ptr = min_value, *min_index_ptr = min_idx;
+     *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ }
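The min/max kernels recover the position of the winning element in two steps: each lane remembers the loop iteration ("cycle") at which its running extremum last changed, and the global index is cycle * 4 + lane, with ties broken toward the earliest cycle and then the lowest lane (the strict vcltq/vcgtq compares guarantee the recorded cycle is the first occurrence per lane). A scalar model of the reconstruction, as an editorial sketch with made-up names:

    #include <stddef.h>
    #include <stdint.h>

    static size_t reconstruct_min_index(const int32_t lane_min[4], const uint32_t lane_cycle[4]) {
        int32_t best = lane_min[0];
        for (int lane = 1; lane < 4; ++lane)            /* global minimum across lanes */
            if (lane_min[lane] < best) best = lane_min[lane];
        uint32_t earliest_cycle = UINT32_MAX;
        for (int lane = 0; lane < 4; ++lane)            /* earliest cycle among value ties */
            if (lane_min[lane] == best && lane_cycle[lane] < earliest_cycle) earliest_cycle = lane_cycle[lane];
        for (int lane = 0; lane < 4; ++lane)            /* lowest lane within that cycle */
            if (lane_min[lane] == best && lane_cycle[lane] == earliest_cycle)
                return (size_t)earliest_cycle * 4 + (size_t)lane;
        return 0; /* unreachable for non-empty input */
    }

The vector code performs the same three passes with vminvq/vceqq/vbslq, masking non-matching lanes to NK_U32_MAX so they lose the horizontal minimum.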
+
+ NK_INTERNAL void nk_reduce_minmax_i32_neon_strided_( //
+     nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     int32x4_t min_i32x4 = vdupq_n_s32(NK_I32_MAX), max_i32x4 = vdupq_n_s32(NK_I32_MIN);
+     uint32x4_t min_iter_u32x4 = vdupq_n_u32(0), max_iter_u32x4 = vdupq_n_u32(0);
+     uint32x4_t iter_u32x4 = vdupq_n_u32(0), one_u32x4 = vdupq_n_u32(1);
+     uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                                  vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+     nk_size_t idx = 0;
+     int32x4_t data_for_min_i32x4, data_for_max_i32x4;
+
+ nk_reduce_minmax_i32_neon_cycle:
+     if (stride_elements == 2 && idx + 4 <= count) {
+         int32x4x2_t loaded = vld2q_s32(data_ptr + idx * 2);
+         data_for_min_i32x4 = loaded.val[0];
+         data_for_max_i32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (stride_elements == 3 && idx + 4 <= count) {
+         int32x4x3_t loaded = vld3q_s32(data_ptr + idx * 3);
+         data_for_min_i32x4 = loaded.val[0];
+         data_for_max_i32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (stride_elements == 4 && idx + 4 <= count) {
+         int32x4x4_t loaded = vld4q_s32(data_ptr + idx * 4);
+         data_for_min_i32x4 = loaded.val[0];
+         data_for_max_i32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (idx < count) {
+         nk_b128_vec_t tail_vec;
+         nk_strided_load_b32x4_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+         uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((uint32_t)(count - idx)));
+         data_for_min_i32x4 = vbslq_s32(valid_u32x4, tail_vec.i32x4, min_i32x4);
+         data_for_max_i32x4 = vbslq_s32(valid_u32x4, tail_vec.i32x4, max_i32x4);
+         idx = count;
+     }
+     else {
+         nk_i32_t min_value = vminvq_s32(min_i32x4), max_value = vmaxvq_s32(max_i32x4);
+         uint32x4_t min_value_match_u32x4 = vceqq_s32(min_i32x4, vdupq_n_s32(min_value));
+         uint32x4_t masked_min_iter_u32x4 = vbslq_u32(min_value_match_u32x4, min_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t earliest_min_cycle = vminvq_u32(masked_min_iter_u32x4);
+         uint32x4_t max_value_match_u32x4 = vceqq_s32(max_i32x4, vdupq_n_s32(max_value));
+         uint32x4_t masked_max_iter_u32x4 = vbslq_u32(max_value_match_u32x4, max_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t earliest_max_cycle = vminvq_u32(masked_max_iter_u32x4);
+         uint32x4_t min_cycle_match_u32x4 = vceqq_u32(min_iter_u32x4, vdupq_n_u32(earliest_min_cycle));
+         uint32x4_t min_both_match_u32x4 = vandq_u32(min_value_match_u32x4, min_cycle_match_u32x4);
+         uint32x4_t min_masked_lanes_u32x4 = vbslq_u32(min_both_match_u32x4, lane_indices_u32x4,
+                                                       vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t min_lane_offset = vminvq_u32(min_masked_lanes_u32x4);
+         nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 4 + (nk_size_t)min_lane_offset;
+         uint32x4_t max_cycle_match_u32x4 = vceqq_u32(max_iter_u32x4, vdupq_n_u32(earliest_max_cycle));
+         uint32x4_t max_both_match_u32x4 = vandq_u32(max_value_match_u32x4, max_cycle_match_u32x4);
+         uint32x4_t max_masked_lanes_u32x4 = vbslq_u32(max_both_match_u32x4, lane_indices_u32x4,
+                                                       vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t max_lane_offset = vminvq_u32(max_masked_lanes_u32x4);
+         nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 4 + (nk_size_t)max_lane_offset;
+         *min_value_ptr = min_value, *min_index_ptr = min_idx;
+         *max_value_ptr = max_value, *max_index_ptr = max_idx;
+         return;
+     }
+
+     // Shared update body
+     uint32x4_t less_u32x4 = vcltq_s32(data_for_min_i32x4, min_i32x4);
+     uint32x4_t greater_u32x4 = vcgtq_s32(data_for_max_i32x4, max_i32x4);
+     min_i32x4 = vbslq_s32(less_u32x4, data_for_min_i32x4, min_i32x4);
+     max_i32x4 = vbslq_s32(greater_u32x4, data_for_max_i32x4, max_i32x4);
+     min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+     max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+     iter_u32x4 = vaddq_u32(iter_u32x4, one_u32x4);
+     goto nk_reduce_minmax_i32_neon_cycle;
+ }
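A note on the control flow above: rather than duplicating the compare/select update for each of the three deinterleaving loads (vld2q/vld3q/vld4q) and for the masked tail, the kernel routes every chunk through one shared update body and loops back via the label. The tail branch sets idx = count, so the next pass falls into the final branch, reduces the lanes, and returns. The u32 variant further below reuses the same shape.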
+
+ NK_PUBLIC void nk_reduce_minmax_i32_neon( //
+     nk_i32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_i32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_i32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_i32_t);
+     int aligned = (stride_bytes % sizeof(nk_i32_t) == 0);
+     if (count == 0)
+         *min_value_ptr = NK_I32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I32_MIN,
+         *max_index_ptr = NK_SIZE_MAX;
+     else if (!aligned)
+         nk_reduce_minmax_i32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+     else if (count > (nk_size_t)NK_U32_MAX * 4) {
+         nk_size_t left_count = count / 2;
+         nk_i32_t left_min_value, right_min_value, left_max_value, right_max_value;
+         nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+         nk_reduce_minmax_i32_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index, &left_max_value,
+                                   &left_max_index);
+         nk_reduce_minmax_i32_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+                                   &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+         if (right_min_value < left_min_value)
+             *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+         else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+         if (right_max_value > left_max_value)
+             *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+         else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+     }
+     else if (stride_elements == 1)
+         nk_reduce_minmax_i32_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+                                               max_index_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_minmax_i32_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+                                            max_value_ptr, max_index_ptr);
+     else
+         nk_reduce_minmax_i32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+ }
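Why the count > NK_U32_MAX * 4 split above: the per-lane cycle counters are 32-bit and advance once per 4-element iteration, so index bookkeeping stays exact only while count <= (2^32 - 1) * 4 = 17,179,869,180 elements. Past that, the dispatcher halves the input, recurses on each half, and merges the two results, offsetting the right half's winning indices by left_count.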
+
+ NK_INTERNAL void nk_reduce_moments_u32_neon_contiguous_( //
+     nk_u32_t const *data_ptr, nk_size_t count, //
+     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     uint64x2_t sum_u64x2 = vdupq_n_u64(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     for (; idx + 4 <= count; idx += 4) {
+         uint32x4_t data_u32x4 = vld1q_u32(data_ptr + idx);
+         // Widen u32 -> u64 and accumulate sum
+         sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_u32(vget_low_u32(data_u32x4)));
+         sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_high_u32(data_u32x4));
+         // Sumsq: widening multiply u32*u32 -> u64, saturating add
+         uint64x2_t sq_lo = vmull_u32(vget_low_u32(data_u32x4), vget_low_u32(data_u32x4));
+         uint64x2_t sq_hi = vmull_high_u32(data_u32x4, data_u32x4);
+         sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, sq_lo);
+         sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, sq_hi);
+     }
+     nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
+     nk_u64_t sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
+     for (; idx < count; ++idx) {
+         nk_u64_t value = (nk_u64_t)data_ptr[idx];
+         sum += value;
+         nk_u64_t product = value * value;
+         sumsq = nk_u64_saturating_add_serial(sumsq, product);
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_INTERNAL void nk_reduce_moments_u32_neon_strided_( //
+     nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     uint64x2_t sum_u64x2 = vdupq_n_u64(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     if (stride_elements == 2) {
+         for (; idx + 4 <= count; idx += 4) {
+             uint32x4x2_t loaded_u32x4x2 = vld2q_u32(data_ptr + idx * 2);
+             uint32x4_t data_u32x4 = loaded_u32x4x2.val[0];
+             sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_u32(vget_low_u32(data_u32x4)));
+             sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_high_u32(data_u32x4));
+             uint64x2_t squares_lo_u64x2 = vmull_u32(vget_low_u32(data_u32x4), vget_low_u32(data_u32x4));
+             uint64x2_t squares_hi_u64x2 = vmull_high_u32(data_u32x4, data_u32x4);
+             sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_lo_u64x2);
+             sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_hi_u64x2);
+         }
+     }
+     else if (stride_elements == 3) {
+         for (; idx + 4 <= count; idx += 4) {
+             uint32x4x3_t loaded_u32x4x3 = vld3q_u32(data_ptr + idx * 3);
+             uint32x4_t data_u32x4 = loaded_u32x4x3.val[0];
+             sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_u32(vget_low_u32(data_u32x4)));
+             sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_high_u32(data_u32x4));
+             uint64x2_t squares_lo_u64x2 = vmull_u32(vget_low_u32(data_u32x4), vget_low_u32(data_u32x4));
+             uint64x2_t squares_hi_u64x2 = vmull_high_u32(data_u32x4, data_u32x4);
+             sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_lo_u64x2);
+             sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_hi_u64x2);
+         }
+     }
+     else {
+         for (; idx + 4 <= count; idx += 4) {
+             uint32x4x4_t loaded_u32x4x4 = vld4q_u32(data_ptr + idx * 4);
+             uint32x4_t data_u32x4 = loaded_u32x4x4.val[0];
+             sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_u32(vget_low_u32(data_u32x4)));
+             sum_u64x2 = vaddq_u64(sum_u64x2, vmovl_high_u32(data_u32x4));
+             uint64x2_t squares_lo_u64x2 = vmull_u32(vget_low_u32(data_u32x4), vget_low_u32(data_u32x4));
+             uint64x2_t squares_hi_u64x2 = vmull_high_u32(data_u32x4, data_u32x4);
+             sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_lo_u64x2);
+             sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, squares_hi_u64x2);
+         }
+     }
+     nk_u64_t sum = vgetq_lane_u64(sum_u64x2, 0) + vgetq_lane_u64(sum_u64x2, 1);
+     nk_u64_t sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
+     for (; idx < count; ++idx) {
+         nk_u64_t val = (nk_u64_t)*(data_ptr + idx * stride_elements);
+         sum += val;
+         nk_u64_t product = val * val;
+         sumsq = nk_u64_saturating_add_serial(sumsq, product);
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_u32_neon( //
+     nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
+     int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
+     if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+     else if (!aligned) nk_reduce_moments_u32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+     else if (count > (nk_size_t)(NK_U16_MAX + 1) * 4) {
+         nk_size_t left_count = count / 2;
+         nk_u64_t left_sum, left_sumsq, right_sum, right_sumsq;
+         nk_reduce_moments_u32_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+         nk_reduce_moments_u32_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+                                    &right_sum, &right_sumsq);
+         *sum_ptr = nk_u64_saturating_add_serial(left_sum, right_sum);
+         *sumsq_ptr = nk_u64_saturating_add_serial(left_sumsq, right_sumsq);
+     }
+     else if (stride_elements == 1) nk_reduce_moments_u32_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_moments_u32_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+     else nk_reduce_moments_u32_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
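As a usage sketch for the moments entry points above (the values are made up; it assumes the numkong headers are on the include path and that nk_u32_t/nk_u64_t alias the standard fixed-width types): strides are expressed in bytes, so visiting every second element of a u32 array means stride_bytes = 8.

    #include <stdint.h>

    void moments_example(void) {
        uint32_t samples[8] = {1, 9, 2, 9, 3, 9, 4, 9};
        uint64_t sum, sumsq;
        /* count = 4 logical elements at a stride of 2 * sizeof(uint32_t) = 8 bytes:
         * visits 1, 2, 3, 4 -> sum = 10, sumsq = 1 + 4 + 9 + 16 = 30 */
        nk_reduce_moments_u32_neon(samples, 4, 2 * sizeof(uint32_t), &sum, &sumsq);
    }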
+
+ NK_INTERNAL void nk_reduce_minmax_u32_neon_contiguous_( //
+     nk_u32_t const *data_ptr, nk_size_t count, //
+     nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     uint32x4_t min_u32x4 = vdupq_n_u32(NK_U32_MAX), max_u32x4 = vdupq_n_u32(0);
+     uint32x4_t min_iter_u32x4 = vdupq_n_u32(0), max_iter_u32x4 = vdupq_n_u32(0);
+     uint32x4_t iter_u32x4 = vdupq_n_u32(0), one_u32x4 = vdupq_n_u32(1);
+     nk_size_t idx = 0;
+     for (; idx + 4 <= count; idx += 4) {
+         uint32x4_t data_u32x4 = vld1q_u32(data_ptr + idx);
+         uint32x4_t less_u32x4 = vcltq_u32(data_u32x4, min_u32x4);
+         uint32x4_t greater_u32x4 = vcgtq_u32(data_u32x4, max_u32x4);
+         min_u32x4 = vbslq_u32(less_u32x4, data_u32x4, min_u32x4);
+         max_u32x4 = vbslq_u32(greater_u32x4, data_u32x4, max_u32x4);
+         min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+         max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+         iter_u32x4 = vaddq_u32(iter_u32x4, one_u32x4);
+     }
+     nk_size_t remaining = count - idx;
+     if (remaining > 0) {
+         nk_b128_vec_t tail_vec;
+         nk_partial_load_b32x4_serial_(data_ptr + idx, &tail_vec, remaining);
+         uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                                      vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+         uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((uint32_t)remaining));
+         uint32x4_t data_min_u32x4 = vbslq_u32(valid_u32x4, tail_vec.u32x4, vdupq_n_u32(NK_U32_MAX));
+         uint32x4_t data_max_u32x4 = vbslq_u32(valid_u32x4, tail_vec.u32x4, vdupq_n_u32(0));
+         uint32x4_t less_u32x4 = vcltq_u32(data_min_u32x4, min_u32x4);
+         uint32x4_t greater_u32x4 = vcgtq_u32(data_max_u32x4, max_u32x4);
+         min_u32x4 = vbslq_u32(less_u32x4, data_min_u32x4, min_u32x4);
+         max_u32x4 = vbslq_u32(greater_u32x4, data_max_u32x4, max_u32x4);
+         min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+         max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+     }
+     nk_u32_t min_value = vminvq_u32(min_u32x4), max_value = vmaxvq_u32(max_u32x4);
+     uint32x4_t min_value_match_u32x4 = vceqq_u32(min_u32x4, vdupq_n_u32(min_value));
+     uint32x4_t masked_min_iter_u32x4 = vbslq_u32(min_value_match_u32x4, min_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t earliest_min_cycle = vminvq_u32(masked_min_iter_u32x4);
+     uint32x4_t max_value_match_u32x4 = vceqq_u32(max_u32x4, vdupq_n_u32(max_value));
+     uint32x4_t masked_max_iter_u32x4 = vbslq_u32(max_value_match_u32x4, max_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t earliest_max_cycle = vminvq_u32(masked_max_iter_u32x4);
+     uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                                  vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+     uint32x4_t min_cycle_match_u32x4 = vceqq_u32(min_iter_u32x4, vdupq_n_u32(earliest_min_cycle));
+     uint32x4_t min_both_match_u32x4 = vandq_u32(min_value_match_u32x4, min_cycle_match_u32x4);
+     uint32x4_t min_masked_lanes_u32x4 = vbslq_u32(min_both_match_u32x4, lane_indices_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t min_lane_offset = vminvq_u32(min_masked_lanes_u32x4);
+     nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 4 + (nk_size_t)min_lane_offset;
+     uint32x4_t max_cycle_match_u32x4 = vceqq_u32(max_iter_u32x4, vdupq_n_u32(earliest_max_cycle));
+     uint32x4_t max_both_match_u32x4 = vandq_u32(max_value_match_u32x4, max_cycle_match_u32x4);
+     uint32x4_t max_masked_lanes_u32x4 = vbslq_u32(max_both_match_u32x4, lane_indices_u32x4, vdupq_n_u32(NK_U32_MAX));
+     nk_u32_t max_lane_offset = vminvq_u32(max_masked_lanes_u32x4);
+     nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 4 + (nk_size_t)max_lane_offset;
+     *min_value_ptr = min_value, *min_index_ptr = min_idx;
+     *max_value_ptr = max_value, *max_index_ptr = max_idx;
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_u32_neon_strided_( //
+     nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     uint32x4_t min_u32x4 = vdupq_n_u32(NK_U32_MAX), max_u32x4 = vdupq_n_u32(0);
+     uint32x4_t min_iter_u32x4 = vdupq_n_u32(0), max_iter_u32x4 = vdupq_n_u32(0);
+     uint32x4_t iter_u32x4 = vdupq_n_u32(0), one_u32x4 = vdupq_n_u32(1);
+     uint32x4_t lane_indices_u32x4 = vcombine_u32(vreinterpret_u32_u64(vcreate_u64(0x0000000100000000ULL)),
+                                                  vreinterpret_u32_u64(vcreate_u64(0x0000000300000002ULL)));
+     nk_size_t idx = 0;
+     uint32x4_t data_for_min_u32x4, data_for_max_u32x4;
+
+ nk_reduce_minmax_u32_neon_cycle:
+     if (stride_elements == 2 && idx + 4 <= count) {
+         uint32x4x2_t loaded = vld2q_u32(data_ptr + idx * 2);
+         data_for_min_u32x4 = loaded.val[0];
+         data_for_max_u32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (stride_elements == 3 && idx + 4 <= count) {
+         uint32x4x3_t loaded = vld3q_u32(data_ptr + idx * 3);
+         data_for_min_u32x4 = loaded.val[0];
+         data_for_max_u32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (stride_elements == 4 && idx + 4 <= count) {
+         uint32x4x4_t loaded = vld4q_u32(data_ptr + idx * 4);
+         data_for_min_u32x4 = loaded.val[0];
+         data_for_max_u32x4 = loaded.val[0];
+         idx += 4;
+     }
+     else if (idx < count) {
+         nk_b128_vec_t tail_vec;
+         nk_strided_load_b32x4_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+         uint32x4_t valid_u32x4 = vcltq_u32(lane_indices_u32x4, vdupq_n_u32((uint32_t)(count - idx)));
+         data_for_min_u32x4 = vbslq_u32(valid_u32x4, tail_vec.u32x4, min_u32x4);
+         data_for_max_u32x4 = vbslq_u32(valid_u32x4, tail_vec.u32x4, max_u32x4);
+         idx = count;
+     }
+     else {
+         nk_u32_t min_value = vminvq_u32(min_u32x4), max_value = vmaxvq_u32(max_u32x4);
+         uint32x4_t min_value_match_u32x4 = vceqq_u32(min_u32x4, vdupq_n_u32(min_value));
+         uint32x4_t masked_min_iter_u32x4 = vbslq_u32(min_value_match_u32x4, min_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t earliest_min_cycle = vminvq_u32(masked_min_iter_u32x4);
+         uint32x4_t max_value_match_u32x4 = vceqq_u32(max_u32x4, vdupq_n_u32(max_value));
+         uint32x4_t masked_max_iter_u32x4 = vbslq_u32(max_value_match_u32x4, max_iter_u32x4, vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t earliest_max_cycle = vminvq_u32(masked_max_iter_u32x4);
+         uint32x4_t min_cycle_match_u32x4 = vceqq_u32(min_iter_u32x4, vdupq_n_u32(earliest_min_cycle));
+         uint32x4_t min_both_match_u32x4 = vandq_u32(min_value_match_u32x4, min_cycle_match_u32x4);
+         uint32x4_t min_masked_lanes_u32x4 = vbslq_u32(min_both_match_u32x4, lane_indices_u32x4,
+                                                       vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t min_lane_offset = vminvq_u32(min_masked_lanes_u32x4);
+         nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 4 + (nk_size_t)min_lane_offset;
+         uint32x4_t max_cycle_match_u32x4 = vceqq_u32(max_iter_u32x4, vdupq_n_u32(earliest_max_cycle));
+         uint32x4_t max_both_match_u32x4 = vandq_u32(max_value_match_u32x4, max_cycle_match_u32x4);
+         uint32x4_t max_masked_lanes_u32x4 = vbslq_u32(max_both_match_u32x4, lane_indices_u32x4,
+                                                       vdupq_n_u32(NK_U32_MAX));
+         nk_u32_t max_lane_offset = vminvq_u32(max_masked_lanes_u32x4);
+         nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 4 + (nk_size_t)max_lane_offset;
+         *min_value_ptr = min_value, *min_index_ptr = min_idx;
+         *max_value_ptr = max_value, *max_index_ptr = max_idx;
+         return;
+     }
+
+     // Shared update body
+     uint32x4_t less_u32x4 = vcltq_u32(data_for_min_u32x4, min_u32x4);
+     uint32x4_t greater_u32x4 = vcgtq_u32(data_for_max_u32x4, max_u32x4);
+     min_u32x4 = vbslq_u32(less_u32x4, data_for_min_u32x4, min_u32x4);
+     max_u32x4 = vbslq_u32(greater_u32x4, data_for_max_u32x4, max_u32x4);
+     min_iter_u32x4 = vbslq_u32(less_u32x4, iter_u32x4, min_iter_u32x4);
+     max_iter_u32x4 = vbslq_u32(greater_u32x4, iter_u32x4, max_iter_u32x4);
+     iter_u32x4 = vaddq_u32(iter_u32x4, one_u32x4);
+     goto nk_reduce_minmax_u32_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_u32_neon( //
+     nk_u32_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_u32_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_u32_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_u32_t);
+     int aligned = (stride_bytes % sizeof(nk_u32_t) == 0);
+     if (count == 0)
+         *min_value_ptr = NK_U32_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
+     else if (!aligned)
+         nk_reduce_minmax_u32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+     else if (count > (nk_size_t)NK_U32_MAX * 4) {
+         nk_size_t left_count = count / 2;
+         nk_u32_t left_min_value, right_min_value, left_max_value, right_max_value;
+         nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+         nk_reduce_minmax_u32_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index, &left_max_value,
+                                   &left_max_index);
+         nk_reduce_minmax_u32_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+                                   &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+         if (right_min_value < left_min_value)
+             *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+         else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+         if (right_max_value > left_max_value)
+             *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+         else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+     }
+     else if (stride_elements == 1)
+         nk_reduce_minmax_u32_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+                                               max_index_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_minmax_u32_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+                                            max_value_ptr, max_index_ptr);
+     else
+         nk_reduce_minmax_u32_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_moments_i64_neon_contiguous_( //
+     nk_i64_t const *data_ptr, nk_size_t count, //
+     nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     uint64x2_t sum_lower_u64x2 = vdupq_n_u64(0);
+     int64x2_t sum_upper_i64x2 = vdupq_n_s64(0);
+     // NEON has no 64x64-bit vector multiply, so sumsq squares each lane via the scalar helper below
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     int sumsq_overflow = 0;
+     int64x2_t sign_bit_i64x2 = vdupq_n_s64((nk_i64_t)0x8000000000000000ULL);
+     nk_size_t idx = 0;
+     for (; idx + 2 <= count; idx += 2) {
+         int64x2_t data_i64x2 = vld1q_s64(data_ptr + idx);
+         // Sumsq via helper (scalar per-lane multiply)
+         uint64x2_t sq = nk_i64_smul_sq_i64x2_neon_(data_i64x2);
+         uint64x2_t sq_before = sumsq_u64x2;
+         sumsq_u64x2 = vaddq_u64(sumsq_u64x2, sq);
+         int64x2_t result_biased = veorq_s64(vreinterpretq_s64_u64(sumsq_u64x2), sign_bit_i64x2);
+         int64x2_t before_biased = veorq_s64(vreinterpretq_s64_u64(sq_before), sign_bit_i64x2);
+         sumsq_overflow |= (vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased, result_biased)), 0) |
+                           vgetq_lane_s64(vreinterpretq_s64_u64(vcgtq_s64(before_biased, result_biased)), 1));
+         // Vectorized 128-bit carry-propagating sum
+         uint64x2_t sum_before_u64x2 = sum_lower_u64x2;
+         sum_lower_u64x2 = vaddq_u64(sum_lower_u64x2, vreinterpretq_u64_s64(data_i64x2));
+         int64x2_t sb_biased = veorq_s64(vreinterpretq_s64_u64(sum_before_u64x2), sign_bit_i64x2);
+         int64x2_t sr_biased = veorq_s64(vreinterpretq_s64_u64(sum_lower_u64x2), sign_bit_i64x2);
+         uint64x2_t carry_u64x2 = vcgtq_s64(sb_biased, sr_biased);
+         sum_upper_i64x2 = vsubq_s64(sum_upper_i64x2, vreinterpretq_s64_u64(carry_u64x2));
+         int64x2_t sign_ext_i64x2 = vshrq_n_s64(data_i64x2, 63);
+         sum_upper_i64x2 = vaddq_s64(sum_upper_i64x2, sign_ext_i64x2);
+     }
+     // Horizontal reduction of 2 lanes to scalar (sum_lower, sum_upper)
+     nk_u64_t sum_lower = vgetq_lane_u64(sum_lower_u64x2, 0);
+     nk_i64_t sum_upper = vgetq_lane_s64(sum_upper_i64x2, 0);
+     {
+         nk_u64_t before = sum_lower;
+         sum_lower += vgetq_lane_u64(sum_lower_u64x2, 1);
+         if (sum_lower < before) sum_upper++;
+         sum_upper += vgetq_lane_s64(sum_upper_i64x2, 1);
+     }
+     nk_u64_t sumsq;
+     if (sumsq_overflow) sumsq = NK_U64_MAX;
+     else sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
+     for (; idx < count; ++idx) {
+         nk_i64_t val = data_ptr[idx];
+         nk_i64_t product = nk_i64_saturating_mul_serial(val, val);
+         nk_u64_t unsigned_product = (nk_u64_t)product;
+         sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
+         nk_u64_t before = sum_lower;
+         sum_lower += (nk_u64_t)val;
+         if (sum_lower < before) sum_upper++;
+         sum_upper += (val >> 63);
+     }
+     nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
+     if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
+     else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
+     else *sum_ptr = NK_I64_MIN;
+     *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_i64_neon( //
+     nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
+     int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
+     if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+     else if (!aligned) nk_reduce_moments_i64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+     else if (stride_elements == 1) nk_reduce_moments_i64_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+     else nk_reduce_moments_i64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_i64_neon_contiguous_( //
+     nk_i64_t const *data_ptr, nk_size_t count, //
+     nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     int64x2_t min_i64x2 = vdupq_n_s64(NK_I64_MAX), max_i64x2 = vdupq_n_s64(NK_I64_MIN);
+     uint64x2_t min_iter = vdupq_n_u64(0), max_iter = vdupq_n_u64(0);
+     uint64x2_t iter = vdupq_n_u64(0), one = vdupq_n_u64(1);
+     nk_size_t idx = 0;
+     for (; idx + 2 <= count; idx += 2) {
+         int64x2_t data_i64x2 = vld1q_s64(data_ptr + idx);
+         uint64x2_t less_u64x2 = vcltq_s64(data_i64x2, min_i64x2);
+         uint64x2_t greater_u64x2 = vcgtq_s64(data_i64x2, max_i64x2);
+         min_i64x2 = vbslq_s64(less_u64x2, data_i64x2, min_i64x2);
+         max_i64x2 = vbslq_s64(greater_u64x2, data_i64x2, max_i64x2);
+         min_iter = vbslq_u64(less_u64x2, iter, min_iter);
+         max_iter = vbslq_u64(greater_u64x2, iter, max_iter);
+         iter = vaddq_u64(iter, one);
+     }
+     nk_b128_vec_t min_values_vec, max_values_vec, min_indices_vec, max_indices_vec;
+     min_values_vec.i64x2 = min_i64x2;
+     min_indices_vec.u64x2 = min_iter;
+     max_values_vec.i64x2 = max_i64x2;
+     max_indices_vec.u64x2 = max_iter;
+     nk_i64_t min_value, max_value;
+     nk_size_t min_index, max_index;
+     if (min_values_vec.i64s[0] <= min_values_vec.i64s[1])
+         min_value = min_values_vec.i64s[0], min_index = (nk_size_t)min_indices_vec.u64s[0] * 2;
+     else min_value = min_values_vec.i64s[1], min_index = (nk_size_t)min_indices_vec.u64s[1] * 2 + 1;
+     if (max_values_vec.i64s[0] >= max_values_vec.i64s[1])
+         max_value = max_values_vec.i64s[0], max_index = (nk_size_t)max_indices_vec.u64s[0] * 2;
+     else max_value = max_values_vec.i64s[1], max_index = (nk_size_t)max_indices_vec.u64s[1] * 2 + 1;
+     for (; idx < count; ++idx) {
+         nk_i64_t val = data_ptr[idx];
+         if (val < min_value) min_value = val, min_index = idx;
+         if (val > max_value) max_value = val, max_index = idx;
+     }
+     *min_value_ptr = min_value, *min_index_ptr = min_index;
+     *max_value_ptr = max_value, *max_index_ptr = max_index;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_i64_neon( //
+     nk_i64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_i64_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_i64_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_i64_t);
+     int aligned = (stride_bytes % sizeof(nk_i64_t) == 0);
+     if (count == 0)
+         *min_value_ptr = NK_I64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_I64_MIN,
+         *max_index_ptr = NK_SIZE_MAX;
+     else if (!aligned)
+         nk_reduce_minmax_i64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+     else if (stride_elements == 1)
+         nk_reduce_minmax_i64_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+                                               max_index_ptr);
+     else
+         nk_reduce_minmax_i64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_moments_u64_neon_contiguous_( //
+     nk_u64_t const *data_ptr, nk_size_t count, //
+     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     uint64x2_t sum_u64x2 = vdupq_n_u64(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     for (; idx + 2 <= count; idx += 2) {
+         uint64x2_t data_u64x2 = vld1q_u64(data_ptr + idx);
+         sum_u64x2 = vqaddq_u64(sum_u64x2, data_u64x2);
+         uint64x2_t sq = nk_u64_smul_sq_u64x2_neon_(data_u64x2);
+         sumsq_u64x2 = vqaddq_u64(sumsq_u64x2, sq);
+     }
+     nk_u64_t sum = nk_reduce_sadd_u64x2_neon_(sum_u64x2);
+     nk_u64_t sumsq = nk_reduce_sadd_u64x2_neon_(sumsq_u64x2);
+     for (; idx < count; ++idx) {
+         nk_u64_t val = data_ptr[idx];
+         sum = nk_u64_saturating_add_serial(sum, val);
+         nk_u64_t product = nk_u64_saturating_mul_serial(val, val);
+         sumsq = nk_u64_saturating_add_serial(sumsq, product);
+     }
+     *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_u64_neon( //
+     nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_u64_t);
+     int aligned = (stride_bytes % sizeof(nk_u64_t) == 0);
+     if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+     else if (!aligned) nk_reduce_moments_u64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+     else if (stride_elements == 1) nk_reduce_moments_u64_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+     else nk_reduce_moments_u64_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_u64_neon_contiguous_( //
+     nk_u64_t const *data_ptr, nk_size_t count, //
+     nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     uint64x2_t min_u64x2 = vdupq_n_u64(NK_U64_MAX), max_u64x2 = vdupq_n_u64(0);
+     uint64x2_t min_iter = vdupq_n_u64(0), max_iter = vdupq_n_u64(0);
+     uint64x2_t iter = vdupq_n_u64(0), one = vdupq_n_u64(1);
+     nk_size_t idx = 0;
+     for (; idx + 2 <= count; idx += 2) {
+         uint64x2_t data_u64x2 = vld1q_u64(data_ptr + idx);
+         uint64x2_t less_u64x2 = vcltq_u64(data_u64x2, min_u64x2);
+         uint64x2_t greater_u64x2 = vcgtq_u64(data_u64x2, max_u64x2);
+         min_u64x2 = vbslq_u64(less_u64x2, data_u64x2, min_u64x2);
+         max_u64x2 = vbslq_u64(greater_u64x2, data_u64x2, max_u64x2);
+         min_iter = vbslq_u64(less_u64x2, iter, min_iter);
+         max_iter = vbslq_u64(greater_u64x2, iter, max_iter);
+         iter = vaddq_u64(iter, one);
+     }
+     nk_b128_vec_t min_values_vec, max_values_vec, min_indices_vec, max_indices_vec;
+     min_values_vec.u64x2 = min_u64x2;
+     min_indices_vec.u64x2 = min_iter;
+     max_values_vec.u64x2 = max_u64x2;
+     max_indices_vec.u64x2 = max_iter;
+     nk_u64_t min_value, max_value;
+     nk_size_t min_index, max_index;
+     if (min_values_vec.u64s[0] <= min_values_vec.u64s[1])
+         min_value = min_values_vec.u64s[0], min_index = (nk_size_t)min_indices_vec.u64s[0] * 2;
+     else min_value = min_values_vec.u64s[1], min_index = (nk_size_t)min_indices_vec.u64s[1] * 2 + 1;
+     if (max_values_vec.u64s[0] >= max_values_vec.u64s[1])
+         max_value = max_values_vec.u64s[0], max_index = (nk_size_t)max_indices_vec.u64s[0] * 2;
+     else max_value = max_values_vec.u64s[1], max_index = (nk_size_t)max_indices_vec.u64s[1] * 2 + 1;
+     for (; idx < count; ++idx) {
+         nk_u64_t val = data_ptr[idx];
+         if (val < min_value) min_value = val, min_index = idx;
+         if (val > max_value) max_value = val, max_index = idx;
+     }
+     *min_value_ptr = min_value, *min_index_ptr = min_index;
+     *max_value_ptr = max_value, *max_index_ptr = max_index;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_u64_neon( //
+     nk_u64_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_u64_t *min_value_ptr, nk_size_t *min_index_ptr, //
+     nk_u64_t *max_value_ptr, nk_size_t *max_index_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_u64_t);
+     int aligned = (stride_bytes % sizeof(nk_u64_t) == 0);
+     if (count == 0)
+         *min_value_ptr = NK_U64_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = 0, *max_index_ptr = NK_SIZE_MAX;
+     else if (!aligned)
+         nk_reduce_minmax_u64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+     else if (stride_elements == 1)
+         nk_reduce_minmax_u64_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+                                               max_index_ptr);
+     else
+         nk_reduce_minmax_u64_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+                                     max_index_ptr);
+ }
+
+ /** @brief Convert 16 raw FP6 (e2m3/e3m2) sign-magnitude bytes to unsigned-comparable bytes.
+  *  FP6: sign bit 5, 5-bit magnitude. Positive maps to [0x20..0x3F], negative to [0x00..0x1F]. */
+ NK_INTERNAL uint8x16_t nk_fp6x16_to_comparable_neon_(uint8x16_t raw_u8x16) {
+     uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+     uint8x16_t sign_mask_u8x16 = vdupq_n_u8(0x20);
+     uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, sign_mask_u8x16);
+     uint8x16_t positive_u8x16 = vorrq_u8(magnitude_u8x16, sign_mask_u8x16);
+     uint8x16_t negative_u8x16 = vsubq_u8(vdupq_n_u8(0x1F), magnitude_u8x16);
+     return vbslq_u8(is_negative_u8x16, negative_u8x16, positive_u8x16);
+ }
+
+ /** @brief Convert a single comparable byte back to raw FP6 sign-magnitude byte. */
+ NK_INTERNAL nk_u8_t nk_comparable_to_fp6_(nk_u8_t comparable) {
+     if (comparable >= 0x20) return comparable ^ 0x20; // was positive
+     else return (0x1F - comparable) | 0x20;           // was negative
+ }
+
+ /** @brief Convert 16 raw FP8 (e4m3/e5m2) sign-magnitude bytes to unsigned-comparable bytes. */
+ NK_INTERNAL uint8x16_t nk_fp8x16_to_comparable_neon_(uint8x16_t raw_u8x16) {
+     uint8x16_t sign_mask_u8x16 = vdupq_n_u8(0x80);
+     uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, sign_mask_u8x16);
+     uint8x16_t flip_positive_u8x16 = veorq_u8(raw_u8x16, sign_mask_u8x16);
+     uint8x16_t flip_negative_u8x16 = vmvnq_u8(raw_u8x16);
+     return vbslq_u8(is_negative_u8x16, flip_negative_u8x16, flip_positive_u8x16);
+ }
+
+ /** @brief Convert a single comparable byte back to raw FP8 sign-magnitude byte. */
+ NK_INTERNAL nk_u8_t nk_comparable_to_fp8_(nk_u8_t comparable) {
+     if (comparable >= 0x80) return comparable ^ 0x80; // was positive
+     else return ~comparable;                          // was negative
+ }
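A scalar model of the FP8 remap above, plus the roundtrip property the min/max kernels rely on, as an editorial sketch (the FP6 variant is the same idea with the sign bit at 0x20 and negative magnitudes complemented against 0x1F; note the resulting order is a total order on the raw bytes, not IEEE comparison semantics):

    #include <assert.h>
    #include <stdint.h>

    static uint8_t fp8_to_comparable(uint8_t raw) {
        /* positive: flip the sign bit up; negative: bitwise NOT reverses the order */
        return (raw & 0x80) ? (uint8_t)~raw : (uint8_t)(raw ^ 0x80);
    }

    static uint8_t comparable_to_fp8(uint8_t cmp) {
        return (cmp >= 0x80) ? (uint8_t)(cmp ^ 0x80) : (uint8_t)~cmp;
    }

    static void check_roundtrip(void) {
        for (unsigned b = 0; b < 256; ++b)  /* every byte survives the roundtrip */
            assert(comparable_to_fp8(fp8_to_comparable((uint8_t)b)) == (uint8_t)b);
    }

Negatives land in [0x00..0x7F] with larger magnitudes sorting lower and positives in [0x80..0xFF] sorting higher, so plain unsigned byte compares order sign-magnitude floats correctly.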
+
+ NK_INTERNAL void nk_reduce_moments_e2m3_neon_contiguous_( //
+     nk_e2m3_t const *data_ptr, nk_size_t count, //
+     nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+     // VTBL LUT: maps the 5-bit magnitude (0..31) to value×16 (unsigned), fits in u8
+     uint8x16x2_t lut_e2m3_x16;
+     // table[0]: values for magnitudes 0..15
+     // 0x0E0C0A0806040200 → bytes [0..7] = 0,2,4,6,8,10,12,14
+     // 0x1E1C1A1816141210 → bytes [8..15] = 16,18,20,22,24,26,28,30
+     lut_e2m3_x16.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0E0C0A0806040200ULL)),
+                                       vreinterpret_u8_u64(vcreate_u64(0x1E1C1A1816141210ULL)));
+     // table[1]: values for magnitudes 16..31
+     // 0x3C3834302C282420 → bytes [0..7] = 32,36,40,44,48,52,56,60
+     // 0x7870686058504840 → bytes [8..15] = 64,72,80,88,96,104,112,120
+     lut_e2m3_x16.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x3C3834302C282420ULL)),
+                                       vreinterpret_u8_u64(vcreate_u64(0x7870686058504840ULL)));
+     int32x4_t sum_i32x4 = vdupq_n_s32(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     for (; idx + 16 <= count; idx += 16) {
+         uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
+         uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+         uint8x16_t unsigned_u8x16 = vqtbl2q_u8(lut_e2m3_x16, magnitude_u8x16);
+         uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
+         int8x16_t positive_i8x16 = vreinterpretq_s8_u8(unsigned_u8x16);
+         int8x16_t negative_i8x16 = vnegq_s8(positive_i8x16);
+         int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
+         int16x8_t pairwise_i16x8 = vpaddlq_s8(scaled_i8x16);
+         sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
+         int16x8_t squares_lo_i16x8 = vmull_s8(vget_low_s8(scaled_i8x16), vget_low_s8(scaled_i8x16));
+         int16x8_t squares_hi_i16x8 = vmull_high_s8(scaled_i8x16, scaled_i8x16);
+         sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_lo_i16x8))));
+         sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_hi_i16x8))));
+     }
+     nk_i64_t sum = vaddlvq_s32(sum_i32x4);
+     nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+     for (; idx < count; ++idx) {
+         nk_f32_t value_f32;
+         nk_e2m3_to_f32_serial(&data_ptr[idx], &value_f32);
+         sum += (nk_i64_t)(value_f32 * 16.0f), sumsq += (nk_u64_t)(nk_i64_t)(value_f32 * value_f32 * 256.0f);
+     }
+     *sum_ptr = (nk_f32_t)sum / 16.0f, *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
+ }
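The two VTBL tables above are just the e2m3 decode quantized to a ×16 fixed-point grid: every representable magnitude times 16 fits in a byte (the largest entry is 7.5 × 16 = 120), so sums and squares can stay in narrow integer lanes and a single float division at the end undoes the scaling. A scalar check of the entries, assuming the usual 1-sign/2-exponent/3-mantissa layout with exponent bias 1 (editorial sketch):

    #include <stdint.h>

    /* Scaled magnitude = value * 16 for a 5-bit e2m3 magnitude (sign stripped). */
    static unsigned e2m3_magnitude_x16(unsigned mag5) {
        unsigned exponent = mag5 >> 3, mantissa = mag5 & 0x7;
        if (exponent == 0) return mantissa * 2;   /* subnormal: m * 0.125, times 16 */
        return (8 + mantissa) << exponent;        /* (1 + m/8) * 2^(e-1), times 16 */
    }
    /* Reproduces the tables: magnitudes 0..15 -> 0,2,...,30 and 16..31 -> 32,36,...,120. */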
+
+ NK_INTERNAL void nk_reduce_moments_e2m3_neon_strided_( //
+     nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+     nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+     uint8x16x2_t lut_e2m3_x16;
+     // table[0]: values for magnitudes 0..15
+     // 0x0E0C0A0806040200 → bytes [0..7] = 0,2,4,6,8,10,12,14
+     // 0x1E1C1A1816141210 → bytes [8..15] = 16,18,20,22,24,26,28,30
+     lut_e2m3_x16.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0E0C0A0806040200ULL)),
+                                       vreinterpret_u8_u64(vcreate_u64(0x1E1C1A1816141210ULL)));
+     // table[1]: values for magnitudes 16..31
+     // 0x3C3834302C282420 → bytes [0..7] = 32,36,40,44,48,52,56,60
+     // 0x7870686058504840 → bytes [8..15] = 64,72,80,88,96,104,112,120
+     lut_e2m3_x16.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x3C3834302C282420ULL)),
+                                       vreinterpret_u8_u64(vcreate_u64(0x7870686058504840ULL)));
+     int32x4_t sum_i32x4 = vdupq_n_s32(0);
+     uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+     nk_size_t idx = 0;
+     if (stride_elements == 2) {
+         for (; idx + 16 <= count; idx += 16) {
+             uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
+             uint8x16_t raw_u8x16 = loaded_u8x16x2.val[0];
+             uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+             uint8x16_t unsigned_u8x16 = vqtbl2q_u8(lut_e2m3_x16, magnitude_u8x16);
+             uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
+             int8x16_t positive_i8x16 = vreinterpretq_s8_u8(unsigned_u8x16);
+             int8x16_t negative_i8x16 = vnegq_s8(positive_i8x16);
+             int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
+             int16x8_t pairwise_i16x8 = vpaddlq_s8(scaled_i8x16);
+             sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
+             int16x8_t squares_lo_i16x8 = vmull_s8(vget_low_s8(scaled_i8x16), vget_low_s8(scaled_i8x16));
+             int16x8_t squares_hi_i16x8 = vmull_high_s8(scaled_i8x16, scaled_i8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_lo_i16x8))));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_hi_i16x8))));
+         }
+     }
+     else if (stride_elements == 3) {
+         for (; idx + 16 <= count; idx += 16) {
+             uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
+             uint8x16_t raw_u8x16 = loaded_u8x16x3.val[0];
+             uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+             uint8x16_t unsigned_u8x16 = vqtbl2q_u8(lut_e2m3_x16, magnitude_u8x16);
+             uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
+             int8x16_t positive_i8x16 = vreinterpretq_s8_u8(unsigned_u8x16);
+             int8x16_t negative_i8x16 = vnegq_s8(positive_i8x16);
+             int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
+             int16x8_t pairwise_i16x8 = vpaddlq_s8(scaled_i8x16);
+             sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
+             int16x8_t squares_lo_i16x8 = vmull_s8(vget_low_s8(scaled_i8x16), vget_low_s8(scaled_i8x16));
+             int16x8_t squares_hi_i16x8 = vmull_high_s8(scaled_i8x16, scaled_i8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_lo_i16x8))));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_hi_i16x8))));
+         }
+     }
+     else {
+         for (; idx + 16 <= count; idx += 16) {
+             uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
+             uint8x16_t raw_u8x16 = loaded_u8x16x4.val[0];
+             uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+             uint8x16_t unsigned_u8x16 = vqtbl2q_u8(lut_e2m3_x16, magnitude_u8x16);
+             uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
+             int8x16_t positive_i8x16 = vreinterpretq_s8_u8(unsigned_u8x16);
+             int8x16_t negative_i8x16 = vnegq_s8(positive_i8x16);
+             int8x16_t scaled_i8x16 = vbslq_s8(is_negative_u8x16, negative_i8x16, positive_i8x16);
+             int16x8_t pairwise_i16x8 = vpaddlq_s8(scaled_i8x16);
+             sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(pairwise_i16x8));
+             int16x8_t squares_lo_i16x8 = vmull_s8(vget_low_s8(scaled_i8x16), vget_low_s8(scaled_i8x16));
+             int16x8_t squares_hi_i16x8 = vmull_high_s8(scaled_i8x16, scaled_i8x16);
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_lo_i16x8))));
+             sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vpaddlq_u16(vreinterpretq_u16_s16(squares_hi_i16x8))));
+         }
+     }
+     nk_i64_t sum = vaddlvq_s32(sum_i32x4);
+     nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+     for (; idx < count; ++idx) {
+         nk_f32_t val;
+         nk_e2m3_to_f32_serial((nk_e2m3_t const *)(data_ptr + idx * stride_elements), &val);
+         sum += (nk_i64_t)(val * 16.0f), sumsq += (nk_u64_t)(nk_i64_t)(val * val * 256.0f);
+     }
+     *sum_ptr = (nk_f32_t)sum / 16.0f, *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_e2m3_neon( //
+     nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+     nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+     nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
+     int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
+     if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+     else if (!aligned) nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+     else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
+         nk_size_t left_count = count / 2;
+         nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
+         nk_reduce_moments_e2m3_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+         nk_reduce_moments_e2m3_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+                                     &right_sum, &right_sumsq);
+         *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
+     }
+     else if (stride_elements == 1) nk_reduce_moments_e2m3_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+     else if (stride_elements <= 4)
+         nk_reduce_moments_e2m3_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+     else nk_reduce_moments_e2m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
2613
+
2614
+ NK_INTERNAL void nk_reduce_minmax_e2m3_neon_contiguous_( //
+ nk_e2m3_t const *data_ptr, nk_size_t count, //
+ nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ // Handle the initial chunk: partial or full
+ uint8x16_t first_comparable_u8x16;
+ nk_size_t first_count = count < 16 ? count : 16;
+ if (count < 16) {
+ nk_b128_vec_t first_vec;
+ nk_partial_load_b8x16_serial_(data_ptr, &first_vec, count);
+ first_comparable_u8x16 = nk_fp6x16_to_comparable_neon_(first_vec.u8x16);
+ // Mask invalid lanes: min gets 0xFF (won't be selected), max gets 0x00
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)count));
+ first_comparable_u8x16 = vbslq_u8(valid_u8x16, first_comparable_u8x16, vdupq_n_u8(0));
+ }
+ else {
+ uint8x16_t first_raw_u8x16 = vld1q_u8((nk_u8_t const *)data_ptr);
+ first_comparable_u8x16 = nk_fp6x16_to_comparable_neon_(first_raw_u8x16);
+ }
+ // For min: invalid lanes (0x00) should not win, so initialize min from masked data where invalid = 0xFF
+ // For max: invalid lanes (0x00) should not win, which is already correct since 0x00 won't beat real data
+ uint8x16_t lane_indices_init_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_init_u8x16 = vcltq_u8(lane_indices_init_u8x16, vdupq_n_u8((uint8_t)first_count));
+ uint8x16_t min_u8x16 = vbslq_u8(valid_init_u8x16, first_comparable_u8x16, vdupq_n_u8(0xFF));
+ uint8x16_t max_u8x16 = first_comparable_u8x16; // invalid lanes are 0x00, safe for max
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(1), one_u8x16 = vdupq_n_u8(1);
+ nk_size_t idx = first_count;
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(raw_u8x16);
+ uint8x16_t less_u8x16 = vcltq_u8(comparable_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(comparable_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, comparable_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, comparable_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ }
+ nk_size_t remaining = count - idx;
+ if (remaining > 0) {
+ nk_b128_vec_t tail_vec;
+ nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(tail_vec.u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
+ uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
+ uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ }
+ nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = nk_comparable_to_fp6_(min_comparable), *min_index_ptr = min_idx;
+ *max_value_ptr = nk_comparable_to_fp6_(max_comparable), *max_index_ptr = max_idx;
+ }
+
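All the minmax kernels lean on a "comparable" byte transform (nk_fp6x16_to_comparable_neon_ / nk_fp8x16_to_comparable_neon_, defined elsewhere in this package) that makes plain unsigned byte comparisons agree with numeric order on sign-magnitude encodings. A plausible scalar model for the 8-bit case follows; the 6-bit variant presumably positions the sign bit at bit 7 before applying the same fold:

```c
#include <stdint.h>

/* Hypothetical scalar model of the order-preserving remap: negative codes
 * are bitwise-inverted (bigger magnitude -> smaller key), positive codes
 * get the top bit set (bigger magnitude -> bigger key), so vcltq_u8 and
 * vcgtq_u8 order the results numerically. */
static inline uint8_t fp8_to_comparable(uint8_t raw) {
    return (raw & 0x80) ? (uint8_t)~raw : (uint8_t)(raw | 0x80);
}
```

Under this model, raw bytes 0x80..0xFF map onto keys 0x7F..0x00 and 0x00..0x7F map onto 0x80..0xFF, a single monotonic unsigned scale.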
+ NK_INTERNAL void nk_reduce_minmax_e2m3_neon_strided_( //
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ nk_size_t idx = 0;
+ uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
+
+ nk_reduce_minmax_e2m3_neon_cycle:
+ if (stride_elements == 2 && idx + 16 <= count) {
+ uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
+ data_for_min_u8x16 = comparable_u8x16;
+ data_for_max_u8x16 = comparable_u8x16;
+ idx += 16;
+ }
+ else if (stride_elements == 3 && idx + 16 <= count) {
+ uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
+ data_for_min_u8x16 = comparable_u8x16;
+ data_for_max_u8x16 = comparable_u8x16;
+ idx += 16;
+ }
+ else if (stride_elements == 4 && idx + 16 <= count) {
+ uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
+ data_for_min_u8x16 = comparable_u8x16;
+ data_for_max_u8x16 = comparable_u8x16;
+ idx += 16;
+ }
+ else if (idx < count) {
+ nk_b128_vec_t tail_vec;
+ nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(tail_vec.u8x16);
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
+ data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
+ data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
+ idx = count;
+ }
+ else {
+ nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = nk_comparable_to_fp6_(min_comparable), *min_index_ptr = min_idx;
+ *max_value_ptr = nk_comparable_to_fp6_(max_comparable), *max_index_ptr = max_idx;
+ return;
+ }
+
+ // Shared update body
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ goto nk_reduce_minmax_e2m3_neon_cycle;
+ }
+
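The closing block in both minmax kernels recovers argmin/argmax positions from per-lane state: each lane records the 8-bit cycle on which it last strictly improved (the strict vcltq/vcgtq comparisons guarantee ties keep the earliest cycle), and the first occurrence of the winning value is the lexicographically smallest (cycle, lane) pair among lanes that hold it, flattened as cycle * 16 + lane. This is also why the dispatchers below cap direct calls at 256 * 16 elements: the cycle counter is a single byte. A scalar model of the recovery, with illustrative names:

```c
#include <stdint.h>
#include <stddef.h>

/* Scalar reference for the two-stage vminvq reduction above: first find the
 * earliest cycle among lanes holding the winning value, then the lowest
 * lane within that cycle. */
size_t recover_first_index(const uint8_t values[16], const uint8_t cycles[16],
                           uint8_t winning_value) {
    uint8_t best_cycle = 0xFF, best_lane = 0xFF;
    for (int lane = 0; lane < 16; ++lane)  /* vminvq over masked iters */
        if (values[lane] == winning_value && cycles[lane] < best_cycle)
            best_cycle = cycles[lane];
    for (int lane = 0; lane < 16; ++lane)  /* vminvq over masked lane ids */
        if (values[lane] == winning_value && cycles[lane] == best_cycle &&
            (uint8_t)lane < best_lane)
            best_lane = (uint8_t)lane;
    return (size_t)best_cycle * 16 + best_lane;
}
```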
+ NK_PUBLIC void nk_reduce_minmax_e2m3_neon( //
+ nk_e2m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_e2m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e2m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e2m3_t);
+ int aligned = (stride_bytes % sizeof(nk_e2m3_t) == 0);
+ if (count == 0)
+ *min_value_ptr = NK_E2M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E2M3_MIN,
+ *max_index_ptr = NK_SIZE_MAX;
+ else if (!aligned)
+ nk_reduce_minmax_e2m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (count > (nk_size_t)256 * 16) {
+ nk_size_t left_count = count / 2;
+ nk_e2m3_t left_min_value, right_min_value, left_max_value, right_max_value;
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+ nk_reduce_minmax_e2m3_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
+ &left_max_value, &left_max_index);
+ nk_reduce_minmax_e2m3_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+ if (nk_e2m3_order_serial(right_min_value, left_min_value) < 0)
+ *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+ else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+ if (nk_e2m3_order_serial(right_max_value, left_max_value) > 0)
+ *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+ else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+ }
+ else if (stride_elements == 1)
+ nk_reduce_minmax_e2m3_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_minmax_e2m3_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+ max_value_ptr, max_index_ptr);
+ else
+ nk_reduce_minmax_e2m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ }
+
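All the strided variants in this file share one load trick: for element strides of 2, 3, or 4 bytes, the structured loads vld2q_u8/vld3q_u8/vld4q_u8 de-interleave the stream, so `.val[0]` alone holds 16 strided elements after a single instruction. A minimal illustration for stride 2; `gather_stride2` is a hypothetical helper, not package API:

```c
#include <arm_neon.h>
#include <stdint.h>

/* vld2q_u8 reads 32 consecutive bytes e0 p0 e1 p1 ... e15 p15 and splits
 * them into two registers; val[0] is exactly the stride-2 element stream. */
static inline uint8x16_t gather_stride2(const uint8_t *src) {
    uint8x16x2_t pair = vld2q_u8(src);
    return pair.val[0];
}
```

Strides above 4 have no structured-load equivalent, which is why those cases fall back to the serial kernels.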
+ NK_INTERNAL void nk_reduce_moments_e3m2_neon_contiguous_( //
+ nk_e3m2_t const *data_ptr, nk_size_t count, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ // VTBL LUT: maps the 5-bit magnitude (0..31) to the (value×16) low byte; max value×16 = 448 needs i16
+ uint8x16x2_t lut_e3m2_lo;
+ // table[0]: low bytes for magnitudes 0..15
+ // 0x0706050403020100 → bytes [0..7] = 0,1,2,3,4,5,6,7
+ // 0x1C1814100E0C0A08 → bytes [8..15] = 8,10,12,14,16,20,24,28
+ lut_e3m2_lo.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x1C1814100E0C0A08ULL)));
+ // table[1]: low bytes for magnitudes 16..31
+ // 0x7060504038302820 → bytes [0..7] = 32,40,48,56,64,80,96,112
+ // 0xC0804000E0C0A080 → bytes [8..15] = 128,160,192,224,0,64,128,192
+ lut_e3m2_lo.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x7060504038302820ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0xC0804000E0C0A080ULL)));
+ int32x4_t sum_i32x4 = vdupq_n_s32(0);
+ uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+ nk_size_t idx = 0;
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
+ uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+ uint8x16_t low_byte_u8x16 = vqtbl2q_u8(lut_e3m2_lo, magnitude_u8x16);
+ uint8x16_t high_byte_u8x16 = vandq_u8(vcgtq_u8(magnitude_u8x16, vdupq_n_u8(27)), vdupq_n_u8(1));
+ uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
+ // Interleave low+high bytes into i16 values (two halves of 8 each)
+ uint16x8_t unsigned_lo_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(low_byte_u8x16, high_byte_u8x16));
+ uint16x8_t unsigned_hi_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(low_byte_u8x16, high_byte_u8x16));
+ // Sign-extend the per-byte negative mask to per-i16 lanes
+ int8x8_t is_negative_lo_i8x8 = vreinterpret_s8_u8(vget_low_u8(is_negative_u8x16));
+ int8x8_t is_negative_hi_i8x8 = vreinterpret_s8_u8(vget_high_u8(is_negative_u8x16));
+ uint16x8_t is_negative_lo_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_lo_i8x8));
+ uint16x8_t is_negative_hi_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_hi_i8x8));
+ // Apply sign via conditional negate
+ int16x8_t positive_lo_i16x8 = vreinterpretq_s16_u16(unsigned_lo_u16x8);
+ int16x8_t scaled_lo_i16x8 = vbslq_s16(is_negative_lo_u16x8, vnegq_s16(positive_lo_i16x8), positive_lo_i16x8);
+ int16x8_t positive_hi_i16x8 = vreinterpretq_s16_u16(unsigned_hi_u16x8);
+ int16x8_t scaled_hi_i16x8 = vbslq_s16(is_negative_hi_u16x8, vnegq_s16(positive_hi_i16x8), positive_hi_i16x8);
+ // Sum: i16→i32 widening, accumulate in i32x4
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_lo_i16x8));
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_hi_i16x8));
+ // Sumsq: vmull_s16→i32 (always positive as squares), widen to u64
+ int32x4_t squares_lo_a_i32x4 = vmull_s16(vget_low_s16(scaled_lo_i16x8), vget_low_s16(scaled_lo_i16x8));
+ int32x4_t squares_lo_b_i32x4 = vmull_high_s16(scaled_lo_i16x8, scaled_lo_i16x8);
+ int32x4_t squares_hi_a_i32x4 = vmull_s16(vget_low_s16(scaled_hi_i16x8), vget_low_s16(scaled_hi_i16x8));
+ int32x4_t squares_hi_b_i32x4 = vmull_high_s16(scaled_hi_i16x8, scaled_hi_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_b_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_b_i32x4)));
+ }
+ nk_i64_t sum = vaddlvq_s32(sum_i32x4);
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+ for (; idx < count; ++idx) {
+ nk_f32_t value_f32;
+ nk_e3m2_to_f32_serial(&data_ptr[idx], &value_f32);
+ sum += (nk_i64_t)(value_f32 * 16.0f), sumsq += (nk_u64_t)(nk_i64_t)(value_f32 * value_f32 * 256.0f);
+ }
+ *sum_ptr = (nk_f32_t)sum / 16.0f, *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
+ }
+
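The hex constants above are easiest to audit by regenerating them. The following throwaway program reproduces the lut_e3m2_lo bytes under the e3m2 layout those constants imply (5-bit magnitude = 3 exponent bits and 2 mantissa bits, bias 3, exponent 0 denoting subnormals); it is a verification aid, not part of the package:

```c
#include <stdio.h>

/* Decode each 5-bit e3m2 magnitude and scale by 16. Every value*16 is an
 * integer; only magnitudes 28..31 (256..448) need the extra high byte,
 * which is why the kernel tests `magnitude > 27`. */
int main(void) {
    for (unsigned mag = 0; mag < 32; ++mag) {
        unsigned exponent = mag >> 2, mantissa = mag & 3;
        unsigned value_x16 =
            exponent == 0 ? mantissa                          /* mantissa * 2^-4, times 16 */
                          : (4 + mantissa) << (exponent - 1); /* (1 + mantissa/4) * 2^(exponent-3), times 16 */
        printf("magnitude %2u -> value*16 = %3u (low 0x%02X, high %u)\n",
               mag, value_x16, value_x16 & 0xFF, value_x16 >> 8);
    }
    return 0;
}
```

Its output matches the comments byte for byte: 0..7, then 8,10,12,14,16,20,24,28 for table[0], and 32..224 plus the wrapped 0,64,128,192 low bytes for table[1].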
+ NK_INTERNAL void nk_reduce_moments_e3m2_neon_strided_( //
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ uint8x16x2_t lut_e3m2_lo;
+ // table[0]: low bytes for magnitudes 0..15
+ // 0x0706050403020100 → bytes [0..7] = 0,1,2,3,4,5,6,7
+ // 0x1C1814100E0C0A08 → bytes [8..15] = 8,10,12,14,16,20,24,28
+ lut_e3m2_lo.val[0] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x1C1814100E0C0A08ULL)));
+ // table[1]: low bytes for magnitudes 16..31
+ // 0x7060504038302820 → bytes [0..7] = 32,40,48,56,64,80,96,112
+ // 0xC0804000E0C0A080 → bytes [8..15] = 128,160,192,224,0,64,128,192
+ lut_e3m2_lo.val[1] = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x7060504038302820ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0xC0804000E0C0A080ULL)));
+ int32x4_t sum_i32x4 = vdupq_n_s32(0);
+ uint64x2_t sumsq_u64x2 = vdupq_n_u64(0);
+ nk_size_t idx = 0;
+ if (stride_elements == 2) {
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
+ uint8x16_t raw_u8x16 = loaded_u8x16x2.val[0];
+ uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+ uint8x16_t low_byte_u8x16 = vqtbl2q_u8(lut_e3m2_lo, magnitude_u8x16);
+ uint8x16_t high_byte_u8x16 = vandq_u8(vcgtq_u8(magnitude_u8x16, vdupq_n_u8(27)), vdupq_n_u8(1));
+ uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
+ uint16x8_t unsigned_lo_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(low_byte_u8x16, high_byte_u8x16));
+ uint16x8_t unsigned_hi_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(low_byte_u8x16, high_byte_u8x16));
+ int8x8_t is_negative_lo_i8x8 = vreinterpret_s8_u8(vget_low_u8(is_negative_u8x16));
+ int8x8_t is_negative_hi_i8x8 = vreinterpret_s8_u8(vget_high_u8(is_negative_u8x16));
+ uint16x8_t is_negative_lo_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_lo_i8x8));
+ uint16x8_t is_negative_hi_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_hi_i8x8));
+ int16x8_t positive_lo_i16x8 = vreinterpretq_s16_u16(unsigned_lo_u16x8);
+ int16x8_t scaled_lo_i16x8 = vbslq_s16(is_negative_lo_u16x8, vnegq_s16(positive_lo_i16x8),
+ positive_lo_i16x8);
+ int16x8_t positive_hi_i16x8 = vreinterpretq_s16_u16(unsigned_hi_u16x8);
+ int16x8_t scaled_hi_i16x8 = vbslq_s16(is_negative_hi_u16x8, vnegq_s16(positive_hi_i16x8),
+ positive_hi_i16x8);
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_lo_i16x8));
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_hi_i16x8));
+ int32x4_t squares_lo_a_i32x4 = vmull_s16(vget_low_s16(scaled_lo_i16x8), vget_low_s16(scaled_lo_i16x8));
+ int32x4_t squares_lo_b_i32x4 = vmull_high_s16(scaled_lo_i16x8, scaled_lo_i16x8);
+ int32x4_t squares_hi_a_i32x4 = vmull_s16(vget_low_s16(scaled_hi_i16x8), vget_low_s16(scaled_hi_i16x8));
+ int32x4_t squares_hi_b_i32x4 = vmull_high_s16(scaled_hi_i16x8, scaled_hi_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_b_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_b_i32x4)));
+ }
+ }
+ else if (stride_elements == 3) {
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
+ uint8x16_t raw_u8x16 = loaded_u8x16x3.val[0];
+ uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+ uint8x16_t low_byte_u8x16 = vqtbl2q_u8(lut_e3m2_lo, magnitude_u8x16);
+ uint8x16_t high_byte_u8x16 = vandq_u8(vcgtq_u8(magnitude_u8x16, vdupq_n_u8(27)), vdupq_n_u8(1));
+ uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
+ uint16x8_t unsigned_lo_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(low_byte_u8x16, high_byte_u8x16));
+ uint16x8_t unsigned_hi_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(low_byte_u8x16, high_byte_u8x16));
+ int8x8_t is_negative_lo_i8x8 = vreinterpret_s8_u8(vget_low_u8(is_negative_u8x16));
+ int8x8_t is_negative_hi_i8x8 = vreinterpret_s8_u8(vget_high_u8(is_negative_u8x16));
+ uint16x8_t is_negative_lo_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_lo_i8x8));
+ uint16x8_t is_negative_hi_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_hi_i8x8));
+ int16x8_t positive_lo_i16x8 = vreinterpretq_s16_u16(unsigned_lo_u16x8);
+ int16x8_t scaled_lo_i16x8 = vbslq_s16(is_negative_lo_u16x8, vnegq_s16(positive_lo_i16x8),
+ positive_lo_i16x8);
+ int16x8_t positive_hi_i16x8 = vreinterpretq_s16_u16(unsigned_hi_u16x8);
+ int16x8_t scaled_hi_i16x8 = vbslq_s16(is_negative_hi_u16x8, vnegq_s16(positive_hi_i16x8),
+ positive_hi_i16x8);
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_lo_i16x8));
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_hi_i16x8));
+ int32x4_t squares_lo_a_i32x4 = vmull_s16(vget_low_s16(scaled_lo_i16x8), vget_low_s16(scaled_lo_i16x8));
+ int32x4_t squares_lo_b_i32x4 = vmull_high_s16(scaled_lo_i16x8, scaled_lo_i16x8);
+ int32x4_t squares_hi_a_i32x4 = vmull_s16(vget_low_s16(scaled_hi_i16x8), vget_low_s16(scaled_hi_i16x8));
+ int32x4_t squares_hi_b_i32x4 = vmull_high_s16(scaled_hi_i16x8, scaled_hi_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_b_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_b_i32x4)));
+ }
+ }
+ else {
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
+ uint8x16_t raw_u8x16 = loaded_u8x16x4.val[0];
+ uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
+ uint8x16_t low_byte_u8x16 = vqtbl2q_u8(lut_e3m2_lo, magnitude_u8x16);
+ uint8x16_t high_byte_u8x16 = vandq_u8(vcgtq_u8(magnitude_u8x16, vdupq_n_u8(27)), vdupq_n_u8(1));
+ uint8x16_t is_negative_u8x16 = vtstq_u8(raw_u8x16, vdupq_n_u8(0x20));
+ uint16x8_t unsigned_lo_u16x8 = vreinterpretq_u16_u8(vzip1q_u8(low_byte_u8x16, high_byte_u8x16));
+ uint16x8_t unsigned_hi_u16x8 = vreinterpretq_u16_u8(vzip2q_u8(low_byte_u8x16, high_byte_u8x16));
+ int8x8_t is_negative_lo_i8x8 = vreinterpret_s8_u8(vget_low_u8(is_negative_u8x16));
+ int8x8_t is_negative_hi_i8x8 = vreinterpret_s8_u8(vget_high_u8(is_negative_u8x16));
+ uint16x8_t is_negative_lo_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_lo_i8x8));
+ uint16x8_t is_negative_hi_u16x8 = vreinterpretq_u16_s16(vmovl_s8(is_negative_hi_i8x8));
+ int16x8_t positive_lo_i16x8 = vreinterpretq_s16_u16(unsigned_lo_u16x8);
+ int16x8_t scaled_lo_i16x8 = vbslq_s16(is_negative_lo_u16x8, vnegq_s16(positive_lo_i16x8),
+ positive_lo_i16x8);
+ int16x8_t positive_hi_i16x8 = vreinterpretq_s16_u16(unsigned_hi_u16x8);
+ int16x8_t scaled_hi_i16x8 = vbslq_s16(is_negative_hi_u16x8, vnegq_s16(positive_hi_i16x8),
+ positive_hi_i16x8);
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_lo_i16x8));
+ sum_i32x4 = vaddq_s32(sum_i32x4, vpaddlq_s16(scaled_hi_i16x8));
+ int32x4_t squares_lo_a_i32x4 = vmull_s16(vget_low_s16(scaled_lo_i16x8), vget_low_s16(scaled_lo_i16x8));
+ int32x4_t squares_lo_b_i32x4 = vmull_high_s16(scaled_lo_i16x8, scaled_lo_i16x8);
+ int32x4_t squares_hi_a_i32x4 = vmull_s16(vget_low_s16(scaled_hi_i16x8), vget_low_s16(scaled_hi_i16x8));
+ int32x4_t squares_hi_b_i32x4 = vmull_high_s16(scaled_hi_i16x8, scaled_hi_i16x8);
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_lo_b_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_a_i32x4)));
+ sumsq_u64x2 = vaddq_u64(sumsq_u64x2, vpaddlq_u32(vreinterpretq_u32_s32(squares_hi_b_i32x4)));
+ }
+ }
+ nk_i64_t sum = vaddlvq_s32(sum_i32x4);
+ nk_u64_t sumsq = vgetq_lane_u64(sumsq_u64x2, 0) + vgetq_lane_u64(sumsq_u64x2, 1);
+ for (; idx < count; ++idx) {
+ nk_f32_t val;
+ nk_e3m2_to_f32_serial((nk_e3m2_t const *)(data_ptr + idx * stride_elements), &val);
+ sum += (nk_i64_t)(val * 16.0f), sumsq += (nk_u64_t)(nk_i64_t)(val * val * 256.0f);
+ }
+ *sum_ptr = (nk_f32_t)sum / 16.0f, *sumsq_ptr = (nk_f32_t)sumsq / 256.0f;
+ }
+
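One non-obvious step in these 16-bit paths is widening the per-byte 0x00/0xFF mask from vtstq_u8 into a per-16-bit-lane mask: reinterpreting the mask bytes as signed and sign-extending with vmovl_s8 turns 0xFF into 0xFFFF, exactly the shape vbslq_s16 needs for the conditional negate. A small standalone illustration with a hypothetical helper name:

```c
#include <arm_neon.h>

/* Widen the low 8 bytes of an all-zeros/all-ones byte mask into an
 * all-zeros/all-ones 16-bit-lane mask: sign extension of (int8_t)0xFF
 * yields (int16_t)0xFFFF, and 0x00 stays 0x0000. */
static inline uint16x8_t widen_byte_mask_low(uint8x16_t byte_mask) {
    int8x8_t low_bytes = vreinterpret_s8_u8(vget_low_u8(byte_mask));
    return vreinterpretq_u16_s16(vmovl_s8(low_bytes));
}
```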
+ NK_PUBLIC void nk_reduce_moments_e3m2_neon( //
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
+ int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+ else if (!aligned) nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
+ nk_size_t left_count = count / 2;
+ nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
+ nk_reduce_moments_e3m2_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+ nk_reduce_moments_e3m2_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_sum, &right_sumsq);
+ *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
+ }
+ else if (stride_elements == 1) nk_reduce_moments_e3m2_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_moments_e3m2_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+ else nk_reduce_moments_e3m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_e3m2_neon_contiguous_( //
+ nk_e3m2_t const *data_ptr, nk_size_t count, //
+ nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ // Handle the initial chunk: partial or full
+ uint8x16_t first_comparable_u8x16;
+ nk_size_t first_count = count < 16 ? count : 16;
+ if (count < 16) {
+ nk_b128_vec_t first_vec;
+ nk_partial_load_b8x16_serial_(data_ptr, &first_vec, count);
+ first_comparable_u8x16 = nk_fp6x16_to_comparable_neon_(first_vec.u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)count));
+ first_comparable_u8x16 = vbslq_u8(valid_u8x16, first_comparable_u8x16, vdupq_n_u8(0));
+ }
+ else {
+ uint8x16_t first_raw_u8x16 = vld1q_u8((nk_u8_t const *)data_ptr);
+ first_comparable_u8x16 = nk_fp6x16_to_comparable_neon_(first_raw_u8x16);
+ }
+ uint8x16_t lane_indices_init_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_init_u8x16 = vcltq_u8(lane_indices_init_u8x16, vdupq_n_u8((uint8_t)first_count));
+ uint8x16_t min_u8x16 = vbslq_u8(valid_init_u8x16, first_comparable_u8x16, vdupq_n_u8(0xFF));
+ uint8x16_t max_u8x16 = first_comparable_u8x16;
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(1), one_u8x16 = vdupq_n_u8(1);
+ nk_size_t idx = first_count;
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(raw_u8x16);
+ uint8x16_t less_u8x16 = vcltq_u8(comparable_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(comparable_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, comparable_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, comparable_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ }
+ nk_size_t remaining = count - idx;
+ if (remaining > 0) {
+ nk_b128_vec_t tail_vec;
+ nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(tail_vec.u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
+ uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
+ uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ }
+ nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = nk_comparable_to_fp6_(min_comparable), *min_index_ptr = min_idx;
+ *max_value_ptr = nk_comparable_to_fp6_(max_comparable), *max_index_ptr = max_idx;
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_e3m2_neon_strided_( //
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ nk_size_t idx = 0;
+ uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
+
+ nk_reduce_minmax_e3m2_neon_cycle:
+ if (stride_elements == 2 && idx + 16 <= count) {
+ uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
+ data_for_min_u8x16 = comparable_u8x16;
+ data_for_max_u8x16 = comparable_u8x16;
+ idx += 16;
+ }
+ else if (stride_elements == 3 && idx + 16 <= count) {
+ uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
+ data_for_min_u8x16 = comparable_u8x16;
+ data_for_max_u8x16 = comparable_u8x16;
+ idx += 16;
+ }
+ else if (stride_elements == 4 && idx + 16 <= count) {
+ uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(loaded.val[0]);
+ data_for_min_u8x16 = comparable_u8x16;
+ data_for_max_u8x16 = comparable_u8x16;
+ idx += 16;
+ }
+ else if (idx < count) {
+ nk_b128_vec_t tail_vec;
+ nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+ uint8x16_t comparable_u8x16 = nk_fp6x16_to_comparable_neon_(tail_vec.u8x16);
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
+ data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
+ data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
+ idx = count;
+ }
+ else {
+ nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = nk_comparable_to_fp6_(min_comparable), *min_index_ptr = min_idx;
+ *max_value_ptr = nk_comparable_to_fp6_(max_comparable), *max_index_ptr = max_idx;
+ return;
+ }
+
+ // Shared update body
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ goto nk_reduce_minmax_e3m2_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_e3m2_neon( //
+ nk_e3m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_e3m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e3m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e3m2_t);
+ int aligned = (stride_bytes % sizeof(nk_e3m2_t) == 0);
+ if (count == 0)
+ *min_value_ptr = NK_E3M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E3M2_MIN,
+ *max_index_ptr = NK_SIZE_MAX;
+ else if (!aligned)
+ nk_reduce_minmax_e3m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (count > (nk_size_t)256 * 16) {
+ nk_size_t left_count = count / 2;
+ nk_e3m2_t left_min_value, right_min_value, left_max_value, right_max_value;
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+ nk_reduce_minmax_e3m2_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
+ &left_max_value, &left_max_index);
+ nk_reduce_minmax_e3m2_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+ if (nk_e3m2_order_serial(right_min_value, left_min_value) < 0)
+ *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+ else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+ if (nk_e3m2_order_serial(right_max_value, left_max_value) > 0)
+ *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+ else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+ }
+ else if (stride_elements == 1)
+ nk_reduce_minmax_e3m2_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_minmax_e3m2_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+ max_value_ptr, max_index_ptr);
+ else
+ nk_reduce_minmax_e3m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_moments_e4m3_neon_contiguous_( //
+ nk_e4m3_t const *data_ptr, nk_size_t count, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ float32x4_t sum_f32x4 = vdupq_n_f32(0), sumsq_f32x4 = vdupq_n_f32(0);
+ nk_size_t idx = 0;
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
+ float16x8_t half_lo_f16x8, half_hi_f16x8;
+ nk_e4m3x16_to_f16x8x2_neon_(raw_u8x16, &half_lo_f16x8, &half_hi_f16x8);
+ float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(half_lo_f16x8));
+ float32x4_t b_f32x4 = vcvt_high_f32_f16(half_lo_f16x8);
+ float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(half_hi_f16x8));
+ float32x4_t d_f32x4 = vcvt_high_f32_f16(half_hi_f16x8);
+ sum_f32x4 = vaddq_f32(vaddq_f32(sum_f32x4, vaddq_f32(a_f32x4, b_f32x4)), vaddq_f32(c_f32x4, d_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32( //
+ sumsq_f32x4, a_f32x4, a_f32x4),
+ b_f32x4, b_f32x4),
+ c_f32x4, c_f32x4),
+ d_f32x4, d_f32x4);
+ }
+ nk_f32_t sum = vaddvq_f32(sum_f32x4), sumsq = vaddvq_f32(sumsq_f32x4);
+ for (; idx < count; ++idx) {
+ nk_f32_t value_f32;
+ nk_e4m3_to_f32_serial(&data_ptr[idx], &value_f32);
+ sum += value_f32, sumsq += value_f32 * value_f32;
+ }
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
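The (sum, sumsq) pair these kernels return is enough to derive mean and variance in constant time, which is the usual consumer of a moments reduction. A hypothetical caller-side helper, not part of the package:

```c
#include <stddef.h>

/* Derive mean and variance from the raw moments produced by a kernel such
 * as nk_reduce_moments_e4m3_neon. Uses E[x^2] - E[x]^2, which is adequate
 * here because fp8/fp6 magnitudes are small. */
void mean_and_variance(float sum, float sumsq, size_t count,
                       float *mean_ptr, float *variance_ptr) {
    float mean = sum / (float)count;
    *mean_ptr = mean;
    *variance_ptr = sumsq / (float)count - mean * mean;
}
```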
+ NK_INTERNAL void nk_reduce_moments_e4m3_neon_strided_( //
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ float32x4_t sum_f32x4 = vdupq_n_f32(0), sumsq_f32x4 = vdupq_n_f32(0);
+ nk_size_t idx = 0;
+ if (stride_elements == 2) {
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
+ float16x8_t half_lo_f16x8, half_hi_f16x8;
+ nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x2.val[0], &half_lo_f16x8, &half_hi_f16x8);
+ float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(half_lo_f16x8));
+ float32x4_t b_f32x4 = vcvt_high_f32_f16(half_lo_f16x8);
+ float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(half_hi_f16x8));
+ float32x4_t d_f32x4 = vcvt_high_f32_f16(half_hi_f16x8);
+ sum_f32x4 = vaddq_f32(vaddq_f32(sum_f32x4, vaddq_f32(a_f32x4, b_f32x4)), vaddq_f32(c_f32x4, d_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32( //
+ sumsq_f32x4, a_f32x4, a_f32x4),
+ b_f32x4, b_f32x4),
+ c_f32x4, c_f32x4),
+ d_f32x4, d_f32x4);
+ }
+ }
+ else if (stride_elements == 3) {
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
+ float16x8_t half_lo_f16x8, half_hi_f16x8;
+ nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x3.val[0], &half_lo_f16x8, &half_hi_f16x8);
+ float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(half_lo_f16x8));
+ float32x4_t b_f32x4 = vcvt_high_f32_f16(half_lo_f16x8);
+ float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(half_hi_f16x8));
+ float32x4_t d_f32x4 = vcvt_high_f32_f16(half_hi_f16x8);
+ sum_f32x4 = vaddq_f32(vaddq_f32(sum_f32x4, vaddq_f32(a_f32x4, b_f32x4)), vaddq_f32(c_f32x4, d_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32( //
+ sumsq_f32x4, a_f32x4, a_f32x4),
+ b_f32x4, b_f32x4),
+ c_f32x4, c_f32x4),
+ d_f32x4, d_f32x4);
+ }
+ }
+ else {
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
+ float16x8_t half_lo_f16x8, half_hi_f16x8;
+ nk_e4m3x16_to_f16x8x2_neon_(loaded_u8x16x4.val[0], &half_lo_f16x8, &half_hi_f16x8);
+ float32x4_t a_f32x4 = vcvt_f32_f16(vget_low_f16(half_lo_f16x8));
+ float32x4_t b_f32x4 = vcvt_high_f32_f16(half_lo_f16x8);
+ float32x4_t c_f32x4 = vcvt_f32_f16(vget_low_f16(half_hi_f16x8));
+ float32x4_t d_f32x4 = vcvt_high_f32_f16(half_hi_f16x8);
+ sum_f32x4 = vaddq_f32(vaddq_f32(sum_f32x4, vaddq_f32(a_f32x4, b_f32x4)), vaddq_f32(c_f32x4, d_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(vfmaq_f32(vfmaq_f32( //
+ sumsq_f32x4, a_f32x4, a_f32x4),
+ b_f32x4, b_f32x4),
+ c_f32x4, c_f32x4),
+ d_f32x4, d_f32x4);
+ }
+ }
+ nk_f32_t sum = vaddvq_f32(sum_f32x4), sumsq = vaddvq_f32(sumsq_f32x4);
+ for (; idx < count; ++idx) {
+ nk_f32_t val;
+ nk_e4m3_to_f32_serial((nk_e4m3_t const *)(data_ptr + idx * stride_elements), &val);
+ sum += val, sumsq += val * val;
+ }
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_e4m3_neon( //
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
+ int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+ else if (!aligned) nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 16) {
+ nk_size_t left_count = count / 2;
+ nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
+ nk_reduce_moments_e4m3_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+ nk_reduce_moments_e4m3_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_sum, &right_sumsq);
+ *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
+ }
+ else if (stride_elements == 1) nk_reduce_moments_e4m3_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_moments_e4m3_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+ else nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_e4m3_neon_contiguous_( //
+ nk_e4m3_t const *data_ptr, nk_size_t count, //
+ nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+ nk_size_t idx = 0;
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(raw_u8x16);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
+ vceqq_u8(comparable_u8x16, vdupq_n_u8(0xFF)));
+ uint8x16_t data_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ uint8x16_t data_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ uint8x16_t less_u8x16 = vcltq_u8(data_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ }
+ nk_size_t remaining = count - idx;
+ if (remaining > 0) {
+ nk_b128_vec_t tail_vec;
+ nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
+ vceqq_u8(comparable_u8x16, vdupq_n_u8(0xFF)));
+ uint8x16_t nan_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ uint8x16_t nan_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
+ uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, nan_min_u8x16, vdupq_n_u8(0xFF));
+ uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, nan_max_u8x16, vdupq_n_u8(0));
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ }
+ nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
+ // If min stayed at 0xFF, all values were NaN
+ if (min_comparable == 0xFF) {
+ *min_value_ptr = (nk_e4m3_t)NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX;
+ *max_value_ptr = (nk_e4m3_t)NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX;
+ return;
+ }
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
+ *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
+ }
+
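The 0x00/0xFF checks above are the NaN filter. Assuming e4m3 follows the common OCP FP8 convention (no infinities, NaN encoded as raw 0x7F and 0xFF) and an order-preserving remap of the shape sketched earlier, the two NaN encodings land exactly on the extreme comparable codes, so substituting 0xFF for min candidates and 0x00 for max candidates makes NaN lanes inert:

```c
#include <stdint.h>

/* With comparable(x) = (x & 0x80) ? ~x : (x | 0x80):
 *   comparable(0x7F) == 0xFF  (+NaN)
 *   comparable(0xFF) == 0x00  (-NaN)
 * No finite value can reach either extreme, so this predicate is exact
 * under the assumed encoding. */
static inline int e4m3_comparable_is_nan(uint8_t comparable) {
    return comparable == 0x00 || comparable == 0xFF;
}
```

This is also why a final min of 0xFF proves every input was NaN: real values always produce a comparable key strictly below 0xFF after the substitution.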
+ NK_INTERNAL void nk_reduce_minmax_e4m3_neon_strided_( //
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ nk_size_t idx = 0;
+ uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
+
+ nk_reduce_minmax_e4m3_neon_cycle:
+ if (stride_elements == 2 && idx + 16 <= count) {
+ uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
+ vceqq_u8(comparable_u8x16, vdupq_n_u8(0xFF)));
+ data_for_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ idx += 16;
+ }
+ else if (stride_elements == 3 && idx + 16 <= count) {
+ uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
+ vceqq_u8(comparable_u8x16, vdupq_n_u8(0xFF)));
+ data_for_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ idx += 16;
+ }
+ else if (stride_elements == 4 && idx + 16 <= count) {
+ uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
+ vceqq_u8(comparable_u8x16, vdupq_n_u8(0xFF)));
+ data_for_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ idx += 16;
+ }
+ else if (idx < count) {
+ nk_b128_vec_t tail_vec;
+ nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vceqq_u8(comparable_u8x16, vdupq_n_u8(0x00)),
+ vceqq_u8(comparable_u8x16, vdupq_n_u8(0xFF)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
+ uint8x16_t invalid_or_nan_u8x16 = vornq_u8(is_nan_u8x16, valid_u8x16);
+ data_for_min_u8x16 = vbslq_u8(invalid_or_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ data_for_max_u8x16 = vbslq_u8(invalid_or_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ idx = count;
+ }
+ else {
+ nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
+ if (min_comparable == 0xFF) {
+ *min_value_ptr = (nk_e4m3_t)NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX;
+ *max_value_ptr = (nk_e4m3_t)NK_E4M3_MIN, *max_index_ptr = NK_SIZE_MAX;
+ return;
+ }
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
+ *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
+ return;
+ }
+
+ // Shared update body
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ goto nk_reduce_minmax_e4m3_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_e4m3_neon( //
3489
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
3490
+ nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
3491
+ nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
3492
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
3493
+ int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
3494
+ if (count == 0)
3495
+ *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E4M3_MIN,
3496
+ *max_index_ptr = NK_SIZE_MAX;
3497
+ else if (!aligned)
3498
+ nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
3499
+ max_index_ptr);
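+ // The per-lane cycle counters are 8-bit, so one kernel invocation can track
+ // at most 256 cycles of 16 lanes; longer inputs are split recursively.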
+ else if (count > (nk_size_t)256 * 16) {
+ nk_size_t left_count = count / 2;
+ nk_e4m3_t left_min_value, right_min_value, left_max_value, right_max_value;
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+ nk_reduce_minmax_e4m3_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
+ &left_max_value, &left_max_index);
+ nk_reduce_minmax_e4m3_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
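+ // Merge the halves: an index of NK_SIZE_MAX marks an all-NaN half, and ties
+ // resolve toward the left half to keep the lowest matching index.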
+ if (left_min_index == NK_SIZE_MAX)
+ *min_value_ptr = right_min_value,
+ *min_index_ptr = right_min_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_min_index;
+ else if (right_min_index == NK_SIZE_MAX || nk_e4m3_order_serial(left_min_value, right_min_value) <= 0)
+ *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+ else *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+ if (left_max_index == NK_SIZE_MAX)
+ *max_value_ptr = right_max_value,
+ *max_index_ptr = right_max_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_max_index;
+ else if (right_max_index == NK_SIZE_MAX || nk_e4m3_order_serial(right_max_value, left_max_value) <= 0)
+ *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+ else *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+ }
+ else if (stride_elements == 1)
+ nk_reduce_minmax_e4m3_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_minmax_e4m3_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+ max_value_ptr, max_index_ptr);
+ else
+ nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ }
+
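+ // Moments accumulate in f32 after widening e5m2 -> f16 -> f32; the scalar
+ // remainder loop reuses the serial conversion helper.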
+ NK_INTERNAL void nk_reduce_moments_e5m2_neon_contiguous_( //
+ nk_e5m2_t const *data_ptr, nk_size_t count, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ float32x4_t sum_f32x4 = vdupq_n_f32(0), sumsq_f32x4 = vdupq_n_f32(0);
+ nk_size_t idx = 0;
+ for (; idx + 8 <= count; idx += 8) {
+ uint8x8_t raw_u8x8 = vld1_u8((nk_u8_t const *)(data_ptr + idx));
+ float16x8_t half_f16x8 = nk_e5m2x8_to_f16x8_neon_(raw_u8x8);
+ float32x4_t lo_f32x4 = vcvt_f32_f16(vget_low_f16(half_f16x8));
+ float32x4_t hi_f32x4 = vcvt_high_f32_f16(half_f16x8);
+ sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(lo_f32x4, hi_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4, lo_f32x4, lo_f32x4), hi_f32x4, hi_f32x4);
+ }
+ nk_f32_t sum = vaddvq_f32(sum_f32x4), sumsq = vaddvq_f32(sumsq_f32x4);
+ for (; idx < count; ++idx) {
+ nk_f32_t value_f32;
+ nk_e5m2_to_f32_serial(&data_ptr[idx], &value_f32);
+ sum += value_f32, sumsq += value_f32 * value_f32;
+ }
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_INTERNAL void nk_reduce_moments_e5m2_neon_strided_( //
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ float32x4_t sum_f32x4 = vdupq_n_f32(0), sumsq_f32x4 = vdupq_n_f32(0);
+ nk_size_t idx = 0;
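+ // vld2/vld3/vld4 de-interleave the strided bytes, so val[0] holds the
+ // requested elements and the remaining planes are discarded.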
+ if (stride_elements == 2) {
+ for (; idx + 8 <= count; idx += 8) {
+ uint8x8x2_t loaded_u8x8x2 = vld2_u8((nk_u8_t const *)(data_ptr + idx * 2));
+ float16x8_t half_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded_u8x8x2.val[0]);
+ float32x4_t lo_f32x4 = vcvt_f32_f16(vget_low_f16(half_f16x8));
+ float32x4_t hi_f32x4 = vcvt_high_f32_f16(half_f16x8);
+ sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(lo_f32x4, hi_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4, lo_f32x4, lo_f32x4), hi_f32x4, hi_f32x4);
+ }
+ }
+ else if (stride_elements == 3) {
+ for (; idx + 8 <= count; idx += 8) {
+ uint8x8x3_t loaded_u8x8x3 = vld3_u8((nk_u8_t const *)(data_ptr + idx * 3));
+ float16x8_t half_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded_u8x8x3.val[0]);
+ float32x4_t lo_f32x4 = vcvt_f32_f16(vget_low_f16(half_f16x8));
+ float32x4_t hi_f32x4 = vcvt_high_f32_f16(half_f16x8);
+ sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(lo_f32x4, hi_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4, lo_f32x4, lo_f32x4), hi_f32x4, hi_f32x4);
+ }
+ }
+ else {
+ for (; idx + 8 <= count; idx += 8) {
+ uint8x8x4_t loaded_u8x8x4 = vld4_u8((nk_u8_t const *)(data_ptr + idx * 4));
+ float16x8_t half_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded_u8x8x4.val[0]);
+ float32x4_t lo_f32x4 = vcvt_f32_f16(vget_low_f16(half_f16x8));
+ float32x4_t hi_f32x4 = vcvt_high_f32_f16(half_f16x8);
+ sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(lo_f32x4, hi_f32x4));
+ sumsq_f32x4 = vfmaq_f32(vfmaq_f32(sumsq_f32x4, lo_f32x4, lo_f32x4), hi_f32x4, hi_f32x4);
+ }
+ }
+ nk_f32_t sum = vaddvq_f32(sum_f32x4), sumsq = vaddvq_f32(sumsq_f32x4);
+ for (; idx < count; ++idx) {
+ nk_f32_t val;
+ nk_e5m2_to_f32_serial((nk_e5m2_t const *)(data_ptr + idx * stride_elements), &val);
+ sum += val, sumsq += val * val;
+ }
+ *sum_ptr = sum, *sumsq_ptr = sumsq;
+ }
+
+ NK_PUBLIC void nk_reduce_moments_e5m2_neon( //
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
+ int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
+ else if (!aligned) nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
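+ // Very long inputs are halved recursively, which keeps each f32 accumulator
+ // short and bounds the rounding error of the summation.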
+ else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
+ nk_size_t left_count = count / 2;
+ nk_f32_t left_sum, left_sumsq, right_sum, right_sumsq;
+ nk_reduce_moments_e5m2_neon(data_ptr, left_count, stride_bytes, &left_sum, &left_sumsq);
+ nk_reduce_moments_e5m2_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_sum, &right_sumsq);
+ *sum_ptr = left_sum + right_sum, *sumsq_ptr = left_sumsq + right_sumsq;
+ }
+ else if (stride_elements == 1) nk_reduce_moments_e5m2_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_moments_e5m2_neon_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
+ else nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_e5m2_neon_contiguous_( //
+ nk_e5m2_t const *data_ptr, nk_size_t count, //
+ nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+ nk_size_t idx = 0;
+ for (; idx + 16 <= count; idx += 16) {
+ uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(raw_u8x16);
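+ // In the order-preserving byte encoding, e5m2 NaN payloads land at
+ // 0x00-0x02 (negative) and 0xFD-0xFF (positive), hence the range checks.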
+ uint8x16_t is_nan_low_u8x16 = vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02));
+ uint8x16_t is_nan_high_u8x16 = vcgeq_u8(comparable_u8x16, vdupq_n_u8(0xFD));
+ uint8x16_t is_nan_u8x16 = vorrq_u8(is_nan_low_u8x16, is_nan_high_u8x16);
+ uint8x16_t data_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ uint8x16_t data_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ uint8x16_t less_u8x16 = vcltq_u8(data_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ }
+ nk_size_t remaining = count - idx;
+ if (remaining > 0) {
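+ // Tail: partially load the remainder, then force out-of-range lanes to
+ // 0xFF for the min and 0x00 for the max so they can never win.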
+ nk_b128_vec_t tail_vec;
+ nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
+ uint8x16_t is_nan_low_u8x16 = vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02));
+ uint8x16_t is_nan_high_u8x16 = vcgeq_u8(comparable_u8x16, vdupq_n_u8(0xFD));
+ uint8x16_t is_nan_u8x16 = vorrq_u8(is_nan_low_u8x16, is_nan_high_u8x16);
+ uint8x16_t nan_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ uint8x16_t nan_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
+ uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, nan_min_u8x16, vdupq_n_u8(0xFF));
+ uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, nan_max_u8x16, vdupq_n_u8(0));
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ }
+ nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
+ // If min stayed at 0xFF, all values were NaN
+ if (min_comparable == 0xFF) {
+ *min_value_ptr = (nk_e5m2_t)NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX;
+ *max_value_ptr = (nk_e5m2_t)NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX;
+ return;
+ }
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
+ *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
+ }
+
+ NK_INTERNAL void nk_reduce_minmax_e5m2_neon_strided_( //
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
+ nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
+ uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
+ uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
+ uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
+ vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
+ nk_size_t idx = 0;
+ uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
+
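+ // A single goto-driven loop lets the three stride-specialized loads and the
+ // masked tail share one min/max update body below.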
+ nk_reduce_minmax_e5m2_neon_cycle:
+ if (stride_elements == 2 && idx + 16 <= count) {
+ uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02)),
+ vcgeq_u8(comparable_u8x16, vdupq_n_u8(0xFD)));
+ data_for_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ idx += 16;
+ }
+ else if (stride_elements == 3 && idx + 16 <= count) {
+ uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02)),
+ vcgeq_u8(comparable_u8x16, vdupq_n_u8(0xFD)));
+ data_for_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ idx += 16;
+ }
+ else if (stride_elements == 4 && idx + 16 <= count) {
+ uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02)),
+ vcgeq_u8(comparable_u8x16, vdupq_n_u8(0xFD)));
+ data_for_min_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ data_for_max_u8x16 = vbslq_u8(is_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ idx += 16;
+ }
+ else if (idx < count) {
+ nk_b128_vec_t tail_vec;
+ nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
+ uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
+ uint8x16_t is_nan_u8x16 = vorrq_u8(vcleq_u8(comparable_u8x16, vdupq_n_u8(0x02)),
+ vcgeq_u8(comparable_u8x16, vdupq_n_u8(0xFD)));
+ uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
+ uint8x16_t invalid_or_nan_u8x16 = vornq_u8(is_nan_u8x16, valid_u8x16);
+ data_for_min_u8x16 = vbslq_u8(invalid_or_nan_u8x16, vdupq_n_u8(0xFF), comparable_u8x16);
+ data_for_max_u8x16 = vbslq_u8(invalid_or_nan_u8x16, vdupq_n_u8(0x00), comparable_u8x16);
+ idx = count;
+ }
+ else {
+ nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
+ if (min_comparable == 0xFF) {
+ *min_value_ptr = (nk_e5m2_t)NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX;
+ *max_value_ptr = (nk_e5m2_t)NK_E5M2_MIN, *max_index_ptr = NK_SIZE_MAX;
+ return;
+ }
+ uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
+ uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
+ uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
+ uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
+ uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
+ uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
+ uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
+ nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
+ uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
+ uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
+ uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
+ nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
+ nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
+ *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
+ *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
+ return;
+ }
+
+ // Shared update body
+ uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
+ uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
+ min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
+ max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
+ min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
+ max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
+ iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
+ goto nk_reduce_minmax_e5m2_neon_cycle;
+ }
+
+ NK_PUBLIC void nk_reduce_minmax_e5m2_neon( //
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
+ nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
+ nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
+ int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
+ if (count == 0)
+ *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E5M2_MIN,
+ *max_index_ptr = NK_SIZE_MAX;
+ else if (!aligned)
+ nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (count > (nk_size_t)256 * 16) {
+ nk_size_t left_count = count / 2;
+ nk_e5m2_t left_min_value, right_min_value, left_max_value, right_max_value;
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
+ nk_reduce_minmax_e5m2_neon(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
+ &left_max_value, &left_max_index);
+ nk_reduce_minmax_e5m2_neon(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
+ if (left_min_index == NK_SIZE_MAX)
+ *min_value_ptr = right_min_value,
+ *min_index_ptr = right_min_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_min_index;
+ else if (right_min_index == NK_SIZE_MAX || nk_e5m2_order_serial(left_min_value, right_min_value) <= 0)
+ *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
+ else *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
+ if (left_max_index == NK_SIZE_MAX)
+ *max_value_ptr = right_max_value,
+ *max_index_ptr = right_max_index == NK_SIZE_MAX ? NK_SIZE_MAX : left_count + right_max_index;
+ else if (right_max_index == NK_SIZE_MAX || nk_e5m2_order_serial(right_max_value, left_max_value) <= 0)
+ *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
+ else *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
+ }
+ else if (stride_elements == 1)
+ nk_reduce_minmax_e5m2_neon_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ else if (stride_elements <= 4)
+ nk_reduce_minmax_e5m2_neon_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
+ max_value_ptr, max_index_ptr);
+ else
+ nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
+ max_index_ptr);
+ }
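+
+ // Illustrative usage (not part of the shipped header): a contiguous min/max
+ // reduction over a hypothetical caller-side buffer `values` of length `n`,
+ // where stride_bytes equals sizeof(nk_e5m2_t) for densely packed data:
+ //
+ //   nk_e5m2_t min_value, max_value;
+ //   nk_size_t min_index, max_index;
+ //   nk_reduce_minmax_e5m2_neon(values, n, sizeof(nk_e5m2_t),
+ //                              &min_value, &min_index, &max_value, &max_index);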
+
+ #if defined(__clang__)
+ #pragma clang attribute pop
+ #elif defined(__GNUC__)
+ #pragma GCC pop_options
+ #endif
+
+ #if defined(__cplusplus)
+ } // extern "C"
+ #endif
+
+ #endif // NK_TARGET_NEON
+ #endif // NK_TARGET_ARM_
+ #endif // NK_REDUCE_NEON_H