numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -1,5 +1,5 @@
1
1
  /**
2
- * @brief ARMv8.4-FHM implementations for the redesigned reduction API (moments + minmax).
2
+ * @brief ARMv8.4-FHM implementations for the redesigned reduction API.
3
3
  * @file include/numkong/reduce/neonfhm.h
4
4
  * @author Ash Vardanian
5
5
  * @date February 13, 2026
@@ -38,7 +38,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_contiguous_( //
38
38
  nk_size_t idx = 0;
39
39
 
40
40
  for (; idx + 8 <= count; idx += 8) {
41
- uint8x8_t data_u8x8 = vld1_u8((uint8_t const *)(data_ptr + idx));
41
+ uint8x8_t data_u8x8 = vld1_u8((nk_u8_t const *)(data_ptr + idx));
42
42
  float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(data_u8x8);
43
43
  sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
44
44
  sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
@@ -71,7 +71,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
71
71
  nk_size_t idx = 0;
72
72
 
73
73
  if (stride_elements == 2) {
74
- for (; idx + 8 <= count; idx += 8) {
74
+ for (; idx + 8 < count; idx += 8) {
75
75
  uint8x8x2_t loaded = vld2_u8((nk_u8_t const *)(data_ptr + idx * 2));
76
76
  float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
77
77
  sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
@@ -81,7 +81,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
81
81
  }
82
82
  }
83
83
  else if (stride_elements == 3) {
84
- for (; idx + 8 <= count; idx += 8) {
84
+ for (; idx + 8 < count; idx += 8) {
85
85
  uint8x8x3_t loaded = vld3_u8((nk_u8_t const *)(data_ptr + idx * 3));
86
86
  float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
87
87
  sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
@@ -91,7 +91,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
91
91
  }
92
92
  }
93
93
  else if (stride_elements == 4) {
94
- for (; idx + 8 <= count; idx += 8) {
94
+ for (; idx + 8 < count; idx += 8) {
95
95
  uint8x8x4_t loaded = vld4_u8((nk_u8_t const *)(data_ptr + idx * 4));
96
96
  float16x8_t data_f16x8 = nk_e4m3x8_to_f16x8_neon_(loaded.val[0]);
97
97
  sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
@@ -163,7 +163,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_contiguous_( //
163
163
  nk_size_t idx = 0;
164
164
 
165
165
  for (; idx + 8 <= count; idx += 8) {
166
- uint8x8_t data_u8x8 = vld1_u8((uint8_t const *)(data_ptr + idx));
166
+ uint8x8_t data_u8x8 = vld1_u8((nk_u8_t const *)(data_ptr + idx));
167
167
  float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(data_u8x8);
168
168
  sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
169
169
  sum_f32x4 = vfmlalq_high_f16(sum_f32x4, data_f16x8, ones_f16x8);
@@ -196,7 +196,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
196
196
  nk_size_t idx = 0;
197
197
 
198
198
  if (stride_elements == 2) {
199
- for (; idx + 8 <= count; idx += 8) {
199
+ for (; idx + 8 < count; idx += 8) {
200
200
  uint8x8x2_t loaded = vld2_u8((nk_u8_t const *)(data_ptr + idx * 2));
201
201
  float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
202
202
  sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
@@ -206,7 +206,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
206
206
  }
207
207
  }
208
208
  else if (stride_elements == 3) {
209
- for (; idx + 8 <= count; idx += 8) {
209
+ for (; idx + 8 < count; idx += 8) {
210
210
  uint8x8x3_t loaded = vld3_u8((nk_u8_t const *)(data_ptr + idx * 3));
211
211
  float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
212
212
  sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
@@ -216,7 +216,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
216
216
  }
217
217
  }
218
218
  else if (stride_elements == 4) {
219
- for (; idx + 8 <= count; idx += 8) {
219
+ for (; idx + 8 < count; idx += 8) {
220
220
  uint8x8x4_t loaded = vld4_u8((nk_u8_t const *)(data_ptr + idx * 4));
221
221
  float16x8_t data_f16x8 = nk_e5m2x8_to_f16x8_neon_(loaded.val[0]);
222
222
  sum_f32x4 = vfmlalq_low_f16(sum_f32x4, data_f16x8, ones_f16x8);
@@ -278,378 +278,6 @@ NK_PUBLIC void nk_reduce_moments_e5m2_neonfhm( //
278
278
  else nk_reduce_moments_e5m2_neonfhm_strided_(data_ptr, count, stride_elements, sum_ptr, sumsq_ptr);
279
279
  }
280
280
 
281
- NK_INTERNAL void nk_reduce_minmax_e4m3_neonfhm_contiguous_( //
282
- nk_e4m3_t const *data_ptr, nk_size_t count, //
283
- nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
284
- nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
285
- uint8x16_t min_u8x16 = vdupq_n_u8(0xFF);
286
- uint8x16_t max_u8x16 = vdupq_n_u8(0x00);
287
- uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
288
- uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
289
- nk_size_t idx = 0;
290
- for (; idx + 16 <= count; idx += 16) {
291
- uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
292
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(raw_u8x16);
293
- uint8x16_t less_u8x16 = vcltq_u8(comparable_u8x16, min_u8x16);
294
- uint8x16_t greater_u8x16 = vcgtq_u8(comparable_u8x16, max_u8x16);
295
- min_u8x16 = vbslq_u8(less_u8x16, comparable_u8x16, min_u8x16);
296
- max_u8x16 = vbslq_u8(greater_u8x16, comparable_u8x16, max_u8x16);
297
- min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
298
- max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
299
- iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
300
- }
301
- nk_size_t remaining = count - idx;
302
- if (remaining > 0) {
303
- nk_b128_vec_t tail_vec;
304
- nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
305
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
306
- uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
307
- vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
308
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
309
- uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
310
- uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
311
- uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
312
- uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
313
- min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
314
- max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
315
- min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
316
- max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
317
- }
318
- nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
319
- // All-NaN early return: both sentinels unchanged means no valid data was found
320
- if (min_comparable == 0xFF && max_comparable == 0x00) {
321
- *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
322
- *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
323
- return;
324
- }
325
- uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
326
- uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
327
- nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
328
- uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
329
- uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
330
- nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
331
- uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
332
- vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
333
- uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
334
- uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
335
- uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
336
- nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
337
- nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
338
- uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
339
- uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
340
- uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
341
- nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
342
- nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
343
- *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
344
- *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
345
- }
346
-
347
- NK_INTERNAL void nk_reduce_minmax_e4m3_neonfhm_strided_( //
348
- nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
349
- nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
350
- nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
351
- uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
352
- uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
353
- uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
354
- uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
355
- vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
356
- nk_size_t idx = 0;
357
- uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
358
-
359
- nk_reduce_minmax_e4m3_neonfhm_cycle:
360
- if (stride_elements == 2 && idx + 16 <= count) {
361
- uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
362
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
363
- data_for_min_u8x16 = comparable_u8x16;
364
- data_for_max_u8x16 = comparable_u8x16;
365
- idx += 16;
366
- }
367
- else if (stride_elements == 3 && idx + 16 <= count) {
368
- uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
369
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
370
- data_for_min_u8x16 = comparable_u8x16;
371
- data_for_max_u8x16 = comparable_u8x16;
372
- idx += 16;
373
- }
374
- else if (stride_elements == 4 && idx + 16 <= count) {
375
- uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
376
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
377
- data_for_min_u8x16 = comparable_u8x16;
378
- data_for_max_u8x16 = comparable_u8x16;
379
- idx += 16;
380
- }
381
- else if (idx < count) {
382
- nk_b128_vec_t tail_vec;
383
- nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
384
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
385
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
386
- data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
387
- data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
388
- idx = count;
389
- }
390
- else {
391
- nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
392
- if (min_comparable == 0xFF && max_comparable == 0x00) {
393
- *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
394
- *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
395
- return;
396
- }
397
- uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
398
- uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
399
- nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
400
- uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
401
- uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
402
- nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
403
- uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
404
- uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
405
- uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
406
- nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
407
- nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
408
- uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
409
- uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
410
- uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
411
- nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
412
- nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
413
- *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
414
- *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
415
- return;
416
- }
417
-
418
- // Shared update body
419
- uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
420
- uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
421
- min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
422
- max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
423
- min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
424
- max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
425
- iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
426
- goto nk_reduce_minmax_e4m3_neonfhm_cycle;
427
- }
428
-
429
- NK_PUBLIC void nk_reduce_minmax_e4m3_neonfhm( //
430
- nk_e4m3_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
431
- nk_e4m3_t *min_value_ptr, nk_size_t *min_index_ptr, //
432
- nk_e4m3_t *max_value_ptr, nk_size_t *max_index_ptr) {
433
- nk_size_t stride_elements = stride_bytes / sizeof(nk_e4m3_t);
434
- int aligned = (stride_bytes % sizeof(nk_e4m3_t) == 0);
435
- if (count == 0)
436
- *min_value_ptr = NK_E4M3_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E4M3_MIN,
437
- *max_index_ptr = NK_SIZE_MAX;
438
- else if (!aligned)
439
- nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
440
- max_index_ptr);
441
- else if (count > (nk_size_t)256 * 16) {
442
- nk_size_t left_count = count / 2;
443
- nk_e4m3_t left_min_value, right_min_value, left_max_value, right_max_value;
444
- nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
445
- nk_reduce_minmax_e4m3_neonfhm(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
446
- &left_max_value, &left_max_index);
447
- nk_reduce_minmax_e4m3_neonfhm(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
448
- &right_min_value, &right_min_index, &right_max_value, &right_max_index);
449
- if (nk_e4m3_order_serial(right_min_value, left_min_value) < 0)
450
- *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
451
- else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
452
- if (nk_e4m3_order_serial(right_max_value, left_max_value) > 0)
453
- *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
454
- else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
455
- }
456
- else if (stride_elements == 1)
457
- nk_reduce_minmax_e4m3_neonfhm_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
458
- max_index_ptr);
459
- else if (stride_elements <= 4)
460
- nk_reduce_minmax_e4m3_neonfhm_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
461
- max_value_ptr, max_index_ptr);
462
- else
463
- nk_reduce_minmax_e4m3_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
464
- max_index_ptr);
465
- }
466
-
467
- NK_INTERNAL void nk_reduce_minmax_e5m2_neonfhm_contiguous_( //
468
- nk_e5m2_t const *data_ptr, nk_size_t count, //
469
- nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
470
- nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
471
- uint8x16_t min_u8x16 = vdupq_n_u8(0xFF);
472
- uint8x16_t max_u8x16 = vdupq_n_u8(0x00);
473
- uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
474
- uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
475
- nk_size_t idx = 0;
476
- for (; idx + 16 <= count; idx += 16) {
477
- uint8x16_t raw_u8x16 = vld1q_u8((nk_u8_t const *)(data_ptr + idx));
478
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(raw_u8x16);
479
- uint8x16_t less_u8x16 = vcltq_u8(comparable_u8x16, min_u8x16);
480
- uint8x16_t greater_u8x16 = vcgtq_u8(comparable_u8x16, max_u8x16);
481
- min_u8x16 = vbslq_u8(less_u8x16, comparable_u8x16, min_u8x16);
482
- max_u8x16 = vbslq_u8(greater_u8x16, comparable_u8x16, max_u8x16);
483
- min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
484
- max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
485
- iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
486
- }
487
- nk_size_t remaining = count - idx;
488
- if (remaining > 0) {
489
- nk_b128_vec_t tail_vec;
490
- nk_partial_load_b8x16_serial_(data_ptr + idx, &tail_vec, remaining);
491
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
492
- uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
493
- vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
494
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)remaining));
495
- uint8x16_t data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
496
- uint8x16_t data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0));
497
- uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
498
- uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
499
- min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
500
- max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
501
- min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
502
- max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
503
- }
504
- nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
505
- // All-NaN early return: both sentinels unchanged means no valid data was found
506
- if (min_comparable == 0xFF && max_comparable == 0x00) {
507
- *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
508
- *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
509
- return;
510
- }
511
- uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
512
- uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
513
- nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
514
- uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
515
- uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
516
- nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
517
- uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
518
- vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
519
- uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
520
- uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
521
- uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
522
- nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
523
- nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
524
- uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
525
- uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
526
- uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
527
- nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
528
- nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
529
- *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
530
- *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
531
- }
532
-
533
- NK_INTERNAL void nk_reduce_minmax_e5m2_neonfhm_strided_( //
534
- nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
535
- nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
536
- nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
537
- uint8x16_t min_u8x16 = vdupq_n_u8(0xFF), max_u8x16 = vdupq_n_u8(0);
538
- uint8x16_t min_iter_u8x16 = vdupq_n_u8(0), max_iter_u8x16 = vdupq_n_u8(0);
539
- uint8x16_t iter_u8x16 = vdupq_n_u8(0), one_u8x16 = vdupq_n_u8(1);
540
- uint8x16_t lane_indices_u8x16 = vcombine_u8(vreinterpret_u8_u64(vcreate_u64(0x0706050403020100ULL)),
541
- vreinterpret_u8_u64(vcreate_u64(0x0F0E0D0C0B0A0908ULL)));
542
- nk_size_t idx = 0;
543
- uint8x16_t data_for_min_u8x16, data_for_max_u8x16;
544
-
545
- nk_reduce_minmax_e5m2_neonfhm_cycle:
546
- if (stride_elements == 2 && idx + 16 <= count) {
547
- uint8x16x2_t loaded = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
548
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
549
- data_for_min_u8x16 = comparable_u8x16;
550
- data_for_max_u8x16 = comparable_u8x16;
551
- idx += 16;
552
- }
553
- else if (stride_elements == 3 && idx + 16 <= count) {
554
- uint8x16x3_t loaded = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
555
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
556
- data_for_min_u8x16 = comparable_u8x16;
557
- data_for_max_u8x16 = comparable_u8x16;
558
- idx += 16;
559
- }
560
- else if (stride_elements == 4 && idx + 16 <= count) {
561
- uint8x16x4_t loaded = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
562
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(loaded.val[0]);
563
- data_for_min_u8x16 = comparable_u8x16;
564
- data_for_max_u8x16 = comparable_u8x16;
565
- idx += 16;
566
- }
567
- else if (idx < count) {
568
- nk_b128_vec_t tail_vec;
569
- nk_strided_load_b8x16_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
570
- uint8x16_t comparable_u8x16 = nk_fp8x16_to_comparable_neon_(tail_vec.u8x16);
571
- uint8x16_t valid_u8x16 = vcltq_u8(lane_indices_u8x16, vdupq_n_u8((uint8_t)(count - idx)));
572
- data_for_min_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0xFF));
573
- data_for_max_u8x16 = vbslq_u8(valid_u8x16, comparable_u8x16, vdupq_n_u8(0x00));
574
- idx = count;
575
- }
576
- else {
577
- nk_u8_t min_comparable = vminvq_u8(min_u8x16), max_comparable = vmaxvq_u8(max_u8x16);
578
- if (min_comparable == 0xFF && max_comparable == 0x00) {
579
- *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
580
- *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
581
- return;
582
- }
583
- uint8x16_t min_value_match_u8x16 = vceqq_u8(min_u8x16, vdupq_n_u8(min_comparable));
584
- uint8x16_t masked_min_iter_u8x16 = vbslq_u8(min_value_match_u8x16, min_iter_u8x16, vdupq_n_u8(0xFF));
585
- nk_u8_t earliest_min_cycle = vminvq_u8(masked_min_iter_u8x16);
586
- uint8x16_t max_value_match_u8x16 = vceqq_u8(max_u8x16, vdupq_n_u8(max_comparable));
587
- uint8x16_t masked_max_iter_u8x16 = vbslq_u8(max_value_match_u8x16, max_iter_u8x16, vdupq_n_u8(0xFF));
588
- nk_u8_t earliest_max_cycle = vminvq_u8(masked_max_iter_u8x16);
589
- uint8x16_t min_cycle_match_u8x16 = vceqq_u8(min_iter_u8x16, vdupq_n_u8(earliest_min_cycle));
590
- uint8x16_t min_both_match_u8x16 = vandq_u8(min_value_match_u8x16, min_cycle_match_u8x16);
591
- uint8x16_t min_masked_lanes_u8x16 = vbslq_u8(min_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
592
- nk_u8_t min_lane_offset = vminvq_u8(min_masked_lanes_u8x16);
593
- nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 16 + (nk_size_t)min_lane_offset;
594
- uint8x16_t max_cycle_match_u8x16 = vceqq_u8(max_iter_u8x16, vdupq_n_u8(earliest_max_cycle));
595
- uint8x16_t max_both_match_u8x16 = vandq_u8(max_value_match_u8x16, max_cycle_match_u8x16);
596
- uint8x16_t max_masked_lanes_u8x16 = vbslq_u8(max_both_match_u8x16, lane_indices_u8x16, vdupq_n_u8(0xFF));
597
- nk_u8_t max_lane_offset = vminvq_u8(max_masked_lanes_u8x16);
598
- nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 16 + (nk_size_t)max_lane_offset;
599
- *min_value_ptr = nk_comparable_to_fp8_(min_comparable), *min_index_ptr = min_idx;
600
- *max_value_ptr = nk_comparable_to_fp8_(max_comparable), *max_index_ptr = max_idx;
601
- return;
602
- }
603
-
604
- // Shared update body
605
- uint8x16_t less_u8x16 = vcltq_u8(data_for_min_u8x16, min_u8x16);
606
- uint8x16_t greater_u8x16 = vcgtq_u8(data_for_max_u8x16, max_u8x16);
607
- min_u8x16 = vbslq_u8(less_u8x16, data_for_min_u8x16, min_u8x16);
608
- max_u8x16 = vbslq_u8(greater_u8x16, data_for_max_u8x16, max_u8x16);
609
- min_iter_u8x16 = vbslq_u8(less_u8x16, iter_u8x16, min_iter_u8x16);
610
- max_iter_u8x16 = vbslq_u8(greater_u8x16, iter_u8x16, max_iter_u8x16);
611
- iter_u8x16 = vaddq_u8(iter_u8x16, one_u8x16);
612
- goto nk_reduce_minmax_e5m2_neonfhm_cycle;
613
- }
614
-
615
- NK_PUBLIC void nk_reduce_minmax_e5m2_neonfhm( //
616
- nk_e5m2_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
617
- nk_e5m2_t *min_value_ptr, nk_size_t *min_index_ptr, //
618
- nk_e5m2_t *max_value_ptr, nk_size_t *max_index_ptr) {
619
- nk_size_t stride_elements = stride_bytes / sizeof(nk_e5m2_t);
620
- int aligned = (stride_bytes % sizeof(nk_e5m2_t) == 0);
621
- if (count == 0)
622
- *min_value_ptr = NK_E5M2_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_E5M2_MIN,
623
- *max_index_ptr = NK_SIZE_MAX;
624
- else if (!aligned)
625
- nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
626
- max_index_ptr);
627
- else if (count > (nk_size_t)256 * 16) {
628
- nk_size_t left_count = count / 2;
629
- nk_e5m2_t left_min_value, right_min_value, left_max_value, right_max_value;
630
- nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
631
- nk_reduce_minmax_e5m2_neonfhm(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
632
- &left_max_value, &left_max_index);
633
- nk_reduce_minmax_e5m2_neonfhm(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
634
- &right_min_value, &right_min_index, &right_max_value, &right_max_index);
635
- if (nk_e5m2_order_serial(right_min_value, left_min_value) < 0)
636
- *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
637
- else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
638
- if (nk_e5m2_order_serial(right_max_value, left_max_value) > 0)
639
- *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
640
- else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
641
- }
642
- else if (stride_elements == 1)
643
- nk_reduce_minmax_e5m2_neonfhm_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
644
- max_index_ptr);
645
- else if (stride_elements <= 4)
646
- nk_reduce_minmax_e5m2_neonfhm_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
647
- max_value_ptr, max_index_ptr);
648
- else
649
- nk_reduce_minmax_e5m2_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
650
- max_index_ptr);
651
- }
652
-
653
281
  #if defined(__clang__)
654
282
  #pragma clang attribute pop
655
283
  #elif defined(__GNUC__)
@@ -60,7 +60,7 @@ NK_INTERNAL void nk_reduce_moments_i8_neonsdot_strided_( //
60
60
  int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
61
61
  nk_size_t idx = 0;
62
62
  if (stride_elements == 2) {
63
- for (; idx + 16 <= count; idx += 16) {
63
+ for (; idx + 16 < count; idx += 16) {
64
64
  int8x16x2_t loaded = vld2q_s8(data_ptr + idx * 2);
65
65
  int8x16_t data_i8x16 = loaded.val[0];
66
66
  sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
@@ -68,7 +68,7 @@ NK_INTERNAL void nk_reduce_moments_i8_neonsdot_strided_( //
68
68
  }
69
69
  }
70
70
  else if (stride_elements == 3) {
71
- for (; idx + 16 <= count; idx += 16) {
71
+ for (; idx + 16 < count; idx += 16) {
72
72
  int8x16x3_t loaded = vld3q_s8(data_ptr + idx * 3);
73
73
  int8x16_t data_i8x16 = loaded.val[0];
74
74
  sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
@@ -76,7 +76,7 @@ NK_INTERNAL void nk_reduce_moments_i8_neonsdot_strided_( //
76
76
  }
77
77
  }
78
78
  else if (stride_elements == 4) {
79
- for (; idx + 16 <= count; idx += 16) {
79
+ for (; idx + 16 < count; idx += 16) {
80
80
  int8x16x4_t loaded = vld4q_s8(data_ptr + idx * 4);
81
81
  int8x16_t data_i8x16 = loaded.val[0];
82
82
  sum_i32x4 = vdotq_s32(sum_i32x4, data_i8x16, ones_i8x16);
@@ -151,7 +151,7 @@ NK_INTERNAL void nk_reduce_moments_u8_neonsdot_strided_( //
151
151
  uint32x4_t sumsq_u32x4 = vdupq_n_u32(0);
152
152
  nk_size_t idx = 0;
153
153
  if (stride_elements == 2) {
154
- for (; idx + 16 <= count; idx += 16) {
154
+ for (; idx + 16 < count; idx += 16) {
155
155
  uint8x16x2_t loaded = vld2q_u8(data_ptr + idx * 2);
156
156
  uint8x16_t data_u8x16 = loaded.val[0];
157
157
  sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
@@ -159,7 +159,7 @@ NK_INTERNAL void nk_reduce_moments_u8_neonsdot_strided_( //
159
159
  }
160
160
  }
161
161
  else if (stride_elements == 3) {
162
- for (; idx + 16 <= count; idx += 16) {
162
+ for (; idx + 16 < count; idx += 16) {
163
163
  uint8x16x3_t loaded = vld3q_u8(data_ptr + idx * 3);
164
164
  uint8x16_t data_u8x16 = loaded.val[0];
165
165
  sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
@@ -167,7 +167,7 @@ NK_INTERNAL void nk_reduce_moments_u8_neonsdot_strided_( //
167
167
  }
168
168
  }
169
169
  else if (stride_elements == 4) {
170
- for (; idx + 16 <= count; idx += 16) {
170
+ for (; idx + 16 < count; idx += 16) {
171
171
  uint8x16x4_t loaded = vld4q_u8(data_ptr + idx * 4);
172
172
  uint8x16_t data_u8x16 = loaded.val[0];
173
173
  sum_u32x4 = vdotq_u32(sum_u32x4, data_u8x16, ones_u8x16);
@@ -268,7 +268,7 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_strided_( //
268
268
  int32x4_t sumsq_i32x4 = vdupq_n_s32(0);
269
269
  nk_size_t idx = 0;
270
270
  if (stride_elements == 2) {
271
- for (; idx + 16 <= count; idx += 16) {
271
+ for (; idx + 16 < count; idx += 16) {
272
272
  uint8x16x2_t loaded_u8x16x2 = vld2q_u8((nk_u8_t const *)(data_ptr + idx * 2));
273
273
  uint8x16_t raw_u8x16 = loaded_u8x16x2.val[0];
274
274
  uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
@@ -282,7 +282,7 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_strided_( //
282
282
  }
283
283
  }
284
284
  else if (stride_elements == 3) {
285
- for (; idx + 16 <= count; idx += 16) {
285
+ for (; idx + 16 < count; idx += 16) {
286
286
  uint8x16x3_t loaded_u8x16x3 = vld3q_u8((nk_u8_t const *)(data_ptr + idx * 3));
287
287
  uint8x16_t raw_u8x16 = loaded_u8x16x3.val[0];
288
288
  uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));
@@ -296,7 +296,7 @@ NK_INTERNAL void nk_reduce_moments_e2m3_neonsdot_strided_( //
296
296
  }
297
297
  }
298
298
  else if (stride_elements == 4) {
299
- for (; idx + 16 <= count; idx += 16) {
299
+ for (; idx + 16 < count; idx += 16) {
300
300
  uint8x16x4_t loaded_u8x16x4 = vld4q_u8((nk_u8_t const *)(data_ptr + idx * 4));
301
301
  uint8x16_t raw_u8x16 = loaded_u8x16x4.val[0];
302
302
  uint8x16_t magnitude_u8x16 = vandq_u8(raw_u8x16, vdupq_n_u8(0x1F));