numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,665 @@
1
+ /**
2
+ * @brief ARMv8.4-FHM implementations for the redesigned reduction API (moments + minmax).
3
+ * @file include/numkong/reduce/neonfhm.h
4
+ * @author Ash Vardanian
5
+ * @date February 13, 2026
6
+ *
7
+ * @sa include/numkong/reduce.h
8
+ */
9
+ #ifndef NK_REDUCE_NEONFHM_H
10
+ #define NK_REDUCE_NEONFHM_H
11
+
12
+ #if NK_TARGET_ARM_
13
+ #if NK_TARGET_NEONFHM
14
+
15
+ #include "numkong/types.h" // `nk_e4m3_t`
16
+ #include "numkong/cast/serial.h" // `nk_f16_to_f32_serial`
17
+ #include "numkong/cast/neon.h" // `nk_e4m3x8_to_f16x8_neon_`
18
+ #include "numkong/reduce/serial.h" // `nk_reduce_moments_e4m3_serial`
19
+
20
+ #if defined(__cplusplus)
21
+ extern "C" {
22
+ #endif
23
+
24
+ #if defined(__clang__)
25
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16+fp16fml"))), apply_to = function)
26
+ #elif defined(__GNUC__)
27
+ #pragma GCC push_options
28
+ #pragma GCC target("arch=armv8.2-a+simd+fp16+fp16fml")
29
+ #endif
30
+
31
/**
 *  @brief Sum and sum-of-squares over a contiguous `e4m3` buffer using FP16FML widening FMAs.
 *  Each 8-byte chunk is widened to `f16x8`, then accumulated into two `f32x4` registers via
 *  `vfmlalq_{low,high}_f16`, so no precision is lost to f16 accumulation.
 */
NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_contiguous_( //
    nk_e4m3_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    float32x4_t sum_f32x4 = vdupq_n_f32(0);
    float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
    // 0x3C00 is the IEEE f16 bit pattern of 1.0; multiplying by it turns FMLAL into a widening add.
    float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
    nk_size_t idx = 0;

    for (; idx + 8 <= count; idx += 8) {
        uint8x8_t data_u8x8 = vld1_u8((uint8_t const *)(data_ptr + idx));
        float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(data_u8x8);
        // Low and high halves of the f16 vector each feed 4 f32 accumulator lanes.
        sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
    }

    // Tail: partial load for remaining elements (< 8)
    // NOTE(review): assumes `nk_partial_load_b8x8_serial_` zero-fills unused lanes, so the
    // padding contributes 0 to both accumulators — confirm against its definition in cast/serial.h.
    if (idx < count) {
        nk_b64_vec_t tail_vec;
        nk_partial_load_b8x8_serial_(data_ptr + idx, &tail_vec, count - idx);
        float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(tail_vec.u8x8);
        sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
    }

    // Horizontal reduction of the 4 partial lanes into the scalar outputs.
    *sum_ptr = vaddvq_f32(sum_f32x4);
    *sumsq_ptr = vaddvq_f32(sumsq_f32x4);
}
63
+
64
/**
 *  @brief Strided variant of the `e4m3` moments kernel.
 *  Strides of 2, 3, and 4 elements use NEON de-interleaving loads (`vld2/vld3/vld4`) and keep
 *  only lane-set 0; any other stride falls back to a scalar gather into a 64-bit staging vector.
 *  @param stride_elements Distance between consecutive elements, in elements (not bytes).
 */
NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
    nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    float32x4_t sum_f32x4 = vdupq_n_f32(0);
    float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
    // 0x3C00 is f16 1.0, making FMLAL act as a widening accumulate.
    float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
    nk_size_t idx = 0;

    if (stride_elements == 2) {
        for (; idx + 8 <= count; idx += 8) {
            // De-interleave 16 bytes into 2 streams; val[0] holds the wanted elements.
            uint8x8x2_t loaded = vld2_u8((nk_u8_t const *)(data_ptr + idx * 2));
            float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
            sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
            sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        }
    }
    else if (stride_elements == 3) {
        for (; idx + 8 <= count; idx += 8) {
            uint8x8x3_t loaded = vld3_u8((nk_u8_t const *)(data_ptr + idx * 3));
            float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
            sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
            sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        }
    }
    else if (stride_elements == 4) {
        for (; idx + 8 <= count; idx += 8) {
            uint8x8x4_t loaded = vld4_u8((nk_u8_t const *)(data_ptr + idx * 4));
            float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
            sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
            sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        }
    }
    else {
        // Generic stride: scalar gather of 8 elements into a staging vector, then one SIMD step.
        nk_e4m3_t const *ptr = data_ptr;
        for (; idx + 8 <= count; idx += 8) {
            nk_b64_vec_t data_vec = {0};
            for (nk_size_t i = 0; i < 8; ++i) {
                data_vec.u8s[i] = *ptr;
                ptr += stride_elements;
            }
            float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(data_vec.u8x8);
            sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
            sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        }
    }

    // Tail (< 8 elements): gather into a zero-initialized vector so padding lanes add 0.
    if (idx < count) {
        nk_b64_vec_t data_vec = {0};
        nk_e4m3_t const *ptr = data_ptr + idx * stride_elements;
        for (nk_size_t i = 0; idx + i < count; ++i) {
            data_vec.u8s[i] = *ptr;
            ptr += stride_elements;
        }
        float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(data_vec.u8x8);
        sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
    }

    *sum_ptr = vaddvq_f32(sum_f32x4);
    *sumsq_ptr = vaddvq_f32(sumsq_f32x4);
}
136
+
137
+ NK_PUBLIC void nk_reduce_moments_e4m3_neonfhm( //
138
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
139
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
140
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
141
+ int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
142
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
143
+ else if (!aligned) nk_reduce_moments_e4m3_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
144
+ else if (count > (nk_size_t)5000 * 8) {
145
+ nk_size_t left_count = count / 2;
146
+ nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
147
+ nk_reduce_moments_e4m3_neonfhm(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
148
+ nk_reduce_moments_e4m3_neonfhm(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
149
+ &right_sum_value, &right_sumsq_value);
150
+ *sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
151
+ }
152
+ else if (stride_elements == 1) nk_reduce_moments_e4m3_neonfhm_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
153
+ else nk_reduce_moments_e4m3_neonfhm_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
154
+ }
155
+
156
/**
 *  @brief Sum and sum-of-squares over a contiguous `e5m2` buffer using FP16FML widening FMAs.
 *  Mirrors the `e4m3` contiguous kernel, differing only in the fp8-to-f16 conversion helper.
 */
NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_contiguous_( //
    nk_e5m2_t const *data_ptr, nk_size_t count, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    float32x4_t sum_f32x4 = vdupq_n_f32(0);
    float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
    // 0x3C00 is f16 1.0; multiplying by it makes FMLAL a widening accumulate.
    float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
    nk_size_t idx = 0;

    for (; idx + 8 <= count; idx += 8) {
        uint8x8_t data_u8x8 = vld1_u8((uint8_t const *)(data_ptr + idx));
        float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(data_u8x8);
        sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
    }

    // Tail: partial load for remaining elements (< 8)
    // NOTE(review): assumes zero-filled padding lanes contribute 0 — confirm in cast/serial.h.
    if (idx < count) {
        nk_b64_vec_t tail_vec;
        nk_partial_load_b8x8_serial_(data_ptr + idx, &tail_vec, count - idx);
        float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(tail_vec.u8x8);
        sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
    }

    // Fold the 4 partial f32 lanes into the scalar outputs.
    *sum_ptr = vaddvq_f32(sum_f32x4);
    *sumsq_ptr = vaddvq_f32(sumsq_f32x4);
}
188
+
189
/**
 *  @brief Strided variant of the `e5m2` moments kernel.
 *  Strides of 2, 3, 4 elements use `vld2/vld3/vld4` de-interleaving loads (keeping lane-set 0);
 *  other strides fall back to a scalar gather into a 64-bit staging vector.
 *  @param stride_elements Distance between consecutive elements, in elements (not bytes).
 */
NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
    nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {

    float32x4_t sum_f32x4 = vdupq_n_f32(0);
    float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
    // 0x3C00 is f16 1.0, making FMLAL a widening accumulate.
    float16x8_t ones_f16x8 = vreinterpretq_f16_u16(vdupq_n_u16(0x3C00));
    nk_size_t idx = 0;

    if (stride_elements == 2) {
        for (; idx + 8 <= count; idx += 8) {
            uint8x8x2_t loaded = vld2_u8((nk_u8_t const *)(data_ptr + idx * 2));
            float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
            sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
            sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        }
    }
    else if (stride_elements == 3) {
        for (; idx + 8 <= count; idx += 8) {
            uint8x8x3_t loaded = vld3_u8((nk_u8_t const *)(data_ptr + idx * 3));
            float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
            sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
            sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        }
    }
    else if (stride_elements == 4) {
        for (; idx + 8 <= count; idx += 8) {
            uint8x8x4_t loaded = vld4_u8((nk_u8_t const *)(data_ptr + idx * 4));
            float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
            sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
            sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        }
    }
    else {
        // Generic stride: scalar gather of 8 elements, then a single SIMD accumulation step.
        nk_e5m2_t const *ptr = data_ptr;
        for (; idx + 8 <= count; idx += 8) {
            nk_b64_vec_t data_vec = {0};
            for (nk_size_t i = 0; i < 8; ++i) {
                data_vec.u8s[i] = *ptr;
                ptr += stride_elements;
            }
            float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(data_vec.u8x8);
            sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
            sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
            sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        }
    }

    // Tail (< 8 elements): gather into a zero-initialized vector so padding lanes add 0.
    if (idx < count) {
        nk_b64_vec_t data_vec = {0};
        nk_e5m2_t const *ptr = data_ptr + idx * stride_elements;
        for (nk_size_t i = 0; idx + i < count; ++i) {
            data_vec.u8s[i] = *ptr;
            ptr += stride_elements;
        }
        float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(data_vec.u8x8);
        sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
        sumsq_f32x4 = vfmlalq_low_f16(sumsq_f32x4, data_f16x8, data_f16x8);
        sumsq_f32x4 = vfmlalq_high_f16(sumsq_f32x4, data_f16x8, data_f16x8);
    }

    *sum_ptr = vaddvq_f32(sum_f32x4);
    *sumsq_ptr = vaddvq_f32(sumsq_f32x4);
}
261
+
262
+ NK_PUBLIC void nk_reduce_moments_e5m2_neonfhm( //
263
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
264
+ nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
265
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
266
+ int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
267
+ if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
268
+ else if (!aligned) nk_reduce_moments_e5m2_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
269
+ else if (count > (nk_size_t)5000 * 8) {
270
+ nk_size_t left_count = count / 2;
271
+ nk_f32_t left_sum_value, left_sumsq_value, right_sum_value, right_sumsq_value;
272
+ nk_reduce_moments_e5m2_neonfhm(data_ptr, left_count, stride_bytes, &left_sum_value, &left_sumsq_value);
273
+ nk_reduce_moments_e5m2_neonfhm(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
274
+ &right_sum_value, &right_sumsq_value);
275
+ *sum_ptr = left_sum_value + right_sum_value, *sumsq_ptr = left_sumsq_value + right_sumsq_value;
276
+ }
277
+ else if (stride_elements == 1) nk_reduce_moments_e5m2_neonfhm_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
278
+ else nk_reduce_moments_e5m2_neonfhm_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
279
+ }
280
+
281
/**
 *  @brief Min/max with indices over a contiguous `e4m3` buffer.
 *  Values are first remapped to a monotonically-comparable unsigned byte space
 *  (`nk_fp8x16_to_comparable_neon_`), so plain `u8` compares order them correctly.
 *  Per-lane 8-bit "cycle" counters record which 16-element block a lane's extreme came from;
 *  the final scatter/gather of matches recovers the earliest (lowest) index for ties.
 *  NOTE(review): the cycle counter is 8-bit and 0xFF doubles as the "no match" sentinel in the
 *  tie-break masks, which presumably bounds `count` to 255 blocks (4080 elements) per call —
 *  confirm that the public dispatcher chunks inputs accordingly.
 */
NK_INTERNAL void nk_reduce_minmax_e4m3_neonfhm_contiguous_( //
    nk_e4m3_t const *data_ptr, nk_size_t count, //
    nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
    nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
    // 0xFF / 0x00 are the extreme sentinels in comparable space: any real value beats them.
    uint8x16_t min_u8x16 = vdupq_n_u8(0xFF);
    uint8x16_t max_u8x16 = vdupq_n_u8(0x00);
    uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
    uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
    nk_size_t idx = 0;
    for (; idx + 16 <= count; idx += 16) {
        uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(raw_u8x16);
        uint8x16_t less_u8x16 = vcltq_u8(comparable_u8x16, min_u8x16);
        uint8x16_t greater_u8x16 = vcgtq_u8(comparable_u8x16, max_u8x16);
        // Strict compares keep the FIRST occurrence on ties within a lane.
        min_u8x16 = vbslq_u8(less_u8x16, comparable_u8x16, min_u8x16);
        max_u8x16 = vbslq_u8(greater_u8x16, comparable_u8x16, max_u8x16);
        // Remember the block number in which each lane last improved its extreme.
        min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
        max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
        iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        nk_b128_vec_t tail_vec;
        nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
        // Constant 0..15 lane indices, built from two packed 64-bit literals.
        uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
                                                    vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
        uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
        // Invalid lanes are neutralized with the sentinel each reduction ignores.
        uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
        uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
        uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
        uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
        min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
        max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
        min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
        max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
    }
    nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
    // All-NaN early return: both sentinels unchanged means no valid data was found
    if (min_comparable == 0xFF && max_comparable == 0x00) {
        *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    // Tie-break step 1: among lanes that hold the winning value, find the earliest cycle.
    uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
    uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
    nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
    uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
    uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
    nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
    uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
                                                vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
    // Tie-break step 2: among lanes matching both value and cycle, take the lowest lane.
    uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
    uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
    uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
    nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
    // Flat index = block number * 16 + lane offset within the block.
    nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
    uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
    uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
    uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
    nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
    nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
    // Map the winning comparable bytes back to fp8 bit patterns for the caller.
    *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
    *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
}
346
+
347
// Min/max (with first-occurrence indices) over a strided e4m3 buffer.
// Works on the byte-"comparable" transform of the fp8 values so plain
// unsigned byte compares can stand in for floating-point ordering.
// NOTE(review): assumes `nk_fp8x16_to_comparable_neon_` is monotone w.r.t.
// e4m3 ordering and keeps NaN lanes from displacing the 0xFF/0x00
// sentinels — confirm against its definition.
// Strides 2..4 use NEON de-interleaving loads (vld2q/vld3q/vld4q); any
// leftover chunk goes through a serial strided gather into one vector.
// Per-lane running state:
//   min_u8x16 / max_u8x16           - best comparable value seen per lane
//   min_iter_u8x16 / max_iter_u8x16 - 8-bit cycle number of that best
// The 8-bit cycle counter limits this kernel to 256 cycles x 16 lanes;
// the public dispatcher splits larger inputs before calling it.
NK_INTERNAL void nk_reduce_minmax_e4m3_neonfhm_strided_(                   //
    nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr,                    //
    nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
    // Sentinels: 0xFF can never lose a `<` compare, 0x00 can never lose a `>`.
    uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
    uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
    uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
    // Constant vector {0,1,...,15}: per-lane positions within a 16-wide cycle.
    uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
                                                vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
    nk_size_t idx = 0;
    uint8x16_t data_for_min_u8x16, data_for_max_u8x16;

// One pass per 16 elements; the `goto` lets every load variant funnel
// into the single shared compare-and-update body at the bottom.
nk_reduce_minmax_e4m3_neonfhm_cycle:
    if (stride_elements == 2 && idx + 16 <= count) {
        // vld2q de-interleaves pairs; val[0] is the strided component we want.
        uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
        data_for_min_u8x16 = comparable_u8x16;
        data_for_max_u8x16 = comparable_u8x16;
        idx += 16;
    }
    else if (stride_elements == 3 && idx + 16 <= count) {
        uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
        data_for_min_u8x16 = comparable_u8x16;
        data_for_max_u8x16 = comparable_u8x16;
        idx += 16;
    }
    else if (stride_elements == 4 && idx + 16 <= count) {
        uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
        data_for_min_u8x16 = comparable_u8x16;
        data_for_max_u8x16 = comparable_u8x16;
        idx += 16;
    }
    else if (idx < count) {
        // Partial (or oddly-strided) tail: gather serially, then neutralize
        // invalid lanes with sentinels so they can never win a compare.
        nk_b128_vec_t tail_vec;
        nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
        uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
        data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
        data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
        idx = count;
    }
    else {
        // Finalization: reduce per-lane bests, then recover the earliest
        // flat index of each winner.
        nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
        // Both sentinels unchanged means no valid data was found (all-NaN input).
        if (min_comparable == 0xFF && max_comparable == 0x00) {
            *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
            *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
            return;
        }
        // Among lanes that hold the winning value, find the earliest cycle;
        // masking losers with 0xFF keeps them out of the vminvq.
        uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
        uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
        nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
        uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
        uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
        nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
        // Tie-break within the earliest cycle by the lowest lane number, so the
        // reported index is the FIRST occurrence: cycle * 16 + lane.
        uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
        uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
        uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
        nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
        nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
        uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
        uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
        uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
        nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
        nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
        *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
        *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
        return;
    }

    // Shared update body: record new per-lane winners and the cycle in
    // which each was first seen, then advance the cycle counter.
    uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
    uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
    min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
    max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
    min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
    max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
    iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
    goto nk_reduce_minmax_e4m3_neonfhm_cycle;
}
428
+
429
+ NK_PUBLIC void nk_reduce_minmax_e4m3_neonfhm( //
430
+ nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
431
+ nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
432
+ nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
433
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
434
+ int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
435
+ if (count == 0)
436
+ *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E4M3_MIN,
437
+ *max_index_ptr = NK_SIZE_MAX;
438
+ else if (!aligned)
439
+ nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
440
+ max_index_ptr);
441
+ else if (count > (nk_size_t)256 * 16) {
442
+ nk_size_t left_count = count / 2;
443
+ nk_e4m3_t left_min_value, right_min_value, left_max_value, right_max_value;
444
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
445
+ nk_reduce_minmax_e4m3_neonfhm(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
446
+ &left_max_value, &left_max_index);
447
+ nk_reduce_minmax_e4m3_neonfhm(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
448
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
449
+ if (nk_e4m3_order_serial(right_min_value, left_min_value) < 0)
450
+ *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
451
+ else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
452
+ if (nk_e4m3_order_serial(right_max_value, left_max_value) > 0)
453
+ *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
454
+ else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
455
+ }
456
+ else if (stride_elements == 1)
457
+ nk_reduce_minmax_e4m3_neonfhm_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
458
+ max_index_ptr);
459
+ else if (stride_elements <= 4)
460
+ nk_reduce_minmax_e4m3_neonfhm_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
461
+ max_value_ptr, max_index_ptr);
462
+ else
463
+ nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
464
+ max_index_ptr);
465
+ }
466
+
467
// Min/max (with first-occurrence indices) over a contiguous e5m2 buffer.
// Compares the byte-"comparable" transform of the fp8 values with plain
// unsigned byte ops.
// NOTE(review): assumes `nk_fp8x16_to_comparable_neon_` is monotone w.r.t.
// e5m2 ordering and keeps NaN lanes from displacing the 0xFF/0x00
// sentinels — confirm against its definition.
// Each lane tracks its best value plus the 8-bit cycle in which that best
// appeared; the public dispatcher caps count at 256 * 16 elements so the
// cycle counter cannot wrap.
NK_INTERNAL void nk_reduce_minmax_e5m2_neonfhm_contiguous_( //
    nk_e5m2_t const *data_ptr, nk_size_t count,             //
    nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr,     //
    nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
    // Sentinels: 0xFF can never lose a `<` compare, 0x00 can never lose a `>`.
    uint8x16_t min_u8x16 = vdupq_n_u8(0xFF);
    uint8x16_t max_u8x16 = vdupq_n_u8(0x00);
    uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
    uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
    nk_size_t idx = 0;
    // Main loop: 16 elements per cycle, tracking winners and their cycle.
    for (; idx + 16 <= count; idx += 16) {
        uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(raw_u8x16);
        uint8x16_t less_u8x16 = vcltq_u8(comparable_u8x16, min_u8x16);
        uint8x16_t greater_u8x16 = vcgtq_u8(comparable_u8x16, max_u8x16);
        min_u8x16 = vbslq_u8(less_u8x16, comparable_u8x16, min_u8x16);
        max_u8x16 = vbslq_u8(greater_u8x16, comparable_u8x16, max_u8x16);
        min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
        max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
        iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
    }
    nk_size_t remaining = count - idx;
    if (remaining > 0) {
        // Tail: partial serial load, then neutralize invalid lanes with
        // sentinels so they can never win a compare.
        nk_b128_vec_t tail_vec;
        nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
        // Constant {0,1,...,15}: per-lane positions within the 16-wide cycle.
        uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
                                                    vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
        uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
        uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
        uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
        uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
        uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
        min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
        max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
        min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
        max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
    }
    nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
    // All-NaN early return: both sentinels unchanged means no valid data was found
    if (min_comparable == 0xFF && max_comparable == 0x00) {
        *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
        *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
        return;
    }
    // Among lanes holding the winning value, find the earliest cycle;
    // masking losers with 0xFF keeps them out of the vminvq reductions.
    uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
    uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
    nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
    uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
    uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
    nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
    uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
                                                vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
    // Tie-break within the earliest cycle by the lowest lane number, so the
    // reported index is the FIRST occurrence: cycle * 16 + lane.
    uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
    uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
    uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
    nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
    nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
    uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
    uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
    uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
    nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
    nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
    *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
    *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
}
532
+
533
// Min/max (with first-occurrence indices) over a strided e5m2 buffer.
// Works on the byte-"comparable" transform of the fp8 values so plain
// unsigned byte compares can stand in for floating-point ordering.
// NOTE(review): assumes `nk_fp8x16_to_comparable_neon_` is monotone w.r.t.
// e5m2 ordering and keeps NaN lanes from displacing the 0xFF/0x00
// sentinels — confirm against its definition.
// Strides 2..4 use NEON de-interleaving loads (vld2q/vld3q/vld4q); any
// leftover chunk goes through a serial strided gather into one vector.
// Per-lane running state:
//   min_u8x16 / max_u8x16           - best comparable value seen per lane
//   min_iter_u8x16 / max_iter_u8x16 - 8-bit cycle number of that best
// The 8-bit cycle counter limits this kernel to 256 cycles x 16 lanes;
// the public dispatcher splits larger inputs before calling it.
NK_INTERNAL void nk_reduce_minmax_e5m2_neonfhm_strided_(                   //
    nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
    nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr,                    //
    nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
    // Sentinels: 0xFF can never lose a `<` compare, 0x00 can never lose a `>`.
    uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
    uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
    uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
    // Constant vector {0,1,...,15}: per-lane positions within a 16-wide cycle.
    uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
                                                vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
    nk_size_t idx = 0;
    uint8x16_t data_for_min_u8x16, data_for_max_u8x16;

// One pass per 16 elements; the `goto` lets every load variant funnel
// into the single shared compare-and-update body at the bottom.
nk_reduce_minmax_e5m2_neonfhm_cycle:
    if (stride_elements == 2 && idx + 16 <= count) {
        // vld2q de-interleaves pairs; val[0] is the strided component we want.
        uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
        data_for_min_u8x16 = comparable_u8x16;
        data_for_max_u8x16 = comparable_u8x16;
        idx += 16;
    }
    else if (stride_elements == 3 && idx + 16 <= count) {
        uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
        data_for_min_u8x16 = comparable_u8x16;
        data_for_max_u8x16 = comparable_u8x16;
        idx += 16;
    }
    else if (stride_elements == 4 && idx + 16 <= count) {
        uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
        data_for_min_u8x16 = comparable_u8x16;
        data_for_max_u8x16 = comparable_u8x16;
        idx += 16;
    }
    else if (idx < count) {
        // Partial (or oddly-strided) tail: gather serially, then neutralize
        // invalid lanes with sentinels so they can never win a compare.
        nk_b128_vec_t tail_vec;
        nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
        uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
        uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
        data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
        data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
        idx = count;
    }
    else {
        // Finalization: reduce per-lane bests, then recover the earliest
        // flat index of each winner.
        nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
        // Both sentinels unchanged means no valid data was found (all-NaN input).
        if (min_comparable == 0xFF && max_comparable == 0x00) {
            *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
            *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
            return;
        }
        // Among lanes that hold the winning value, find the earliest cycle;
        // masking losers with 0xFF keeps them out of the vminvq.
        uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
        uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
        nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
        uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
        uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
        nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
        // Tie-break within the earliest cycle by the lowest lane number, so the
        // reported index is the FIRST occurrence: cycle * 16 + lane.
        uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
        uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
        uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
        nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
        nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
        uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
        uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
        uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
        nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
        nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
        *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
        *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
        return;
    }

    // Shared update body: record new per-lane winners and the cycle in
    // which each was first seen, then advance the cycle counter.
    uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
    uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
    min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
    max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
    min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
    max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
    iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
    goto nk_reduce_minmax_e5m2_neonfhm_cycle;
}
614
+
615
+ NK_PUBLIC void nk_reduce_minmax_e5m2_neonfhm( //
616
+ nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
617
+ nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
618
+ nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
619
+ nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
620
+ int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
621
+ if (count == 0)
622
+ *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E5M2_MIN,
623
+ *max_index_ptr = NK_SIZE_MAX;
624
+ else if (!aligned)
625
+ nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
626
+ max_index_ptr);
627
+ else if (count > (nk_size_t)256 * 16) {
628
+ nk_size_t left_count = count / 2;
629
+ nk_e5m2_t left_min_value, right_min_value, left_max_value, right_max_value;
630
+ nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
631
+ nk_reduce_minmax_e5m2_neonfhm(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
632
+ &left_max_value, &left_max_index);
633
+ nk_reduce_minmax_e5m2_neonfhm(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
634
+ &right_min_value, &right_min_index, &right_max_value, &right_max_index);
635
+ if (nk_e5m2_order_serial(right_min_value, left_min_value) < 0)
636
+ *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
637
+ else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
638
+ if (nk_e5m2_order_serial(right_max_value, left_max_value) > 0)
639
+ *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
640
+ else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
641
+ }
642
+ else if (stride_elements == 1)
643
+ nk_reduce_minmax_e5m2_neonfhm_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
644
+ max_index_ptr);
645
+ else if (stride_elements <= 4)
646
+ nk_reduce_minmax_e5m2_neonfhm_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
647
+ max_value_ptr, max_index_ptr);
648
+ else
649
+ nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
650
+ max_index_ptr);
651
+ }
652
+
653
+ #if defined(__clang__)
654
+ #pragma clang attribute pop
655
+ #elif defined(__GNUC__)
656
+ #pragma GCC pop_options
657
+ #endif
658
+
659
+ #if defined(__cplusplus)
660
+ } // extern "C"
661
+ #endif
662
+
663
+ #endif // NK_TARGET_NEONFHM
664
+ #endif // NK_TARGET_ARM_
665
+ #endif // NK_REDUCE_NEONFHM_H