numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -1,5 +1,5 @@
1
1
  /**
2
- * @brief ARMv8.6-BF16 implementations for the redesigned reduction API (moments + minmax).
2
+ * @brief ARMv8.6-BF16 implementations for the redesigned reduction API.
3
3
  * @file include/numkong/reduce/neonbfdot.h
4
4
  * @author Ash Vardanian
5
5
  * @date February 13, 2026
@@ -67,24 +67,24 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_strided_( //
67
67
  nk_size_t idx = 0;
68
68
 
69
69
  if (stride_elements == 2) {
70
- for (; idx + 8 <= count; idx += 8) {
71
- uint16x8x2_t loaded_u16x8x2 = vld2q_u16((uint16_t const *)(data_ptr + idx * 2));
70
+ for (; idx + 8 < count; idx += 8) {
71
+ uint16x8x2_t loaded_u16x8x2 = vld2q_u16((nk_u16_t const *)(data_ptr + idx * 2));
72
72
  bfloat16x8_t data_bf16x8 = vreinterpretq_bf16_u16(loaded_u16x8x2.val[0]);
73
73
  sum_f32x4 = vbfdotq_f32(sum_f32x4, data_bf16x8, ones_bf16x8);
74
74
  sumsq_f32x4 = vbfdotq_f32(sumsq_f32x4, data_bf16x8, data_bf16x8);
75
75
  }
76
76
  }
77
77
  else if (stride_elements == 3) {
78
- for (; idx + 8 <= count; idx += 8) {
79
- uint16x8x3_t loaded_u16x8x3 = vld3q_u16((uint16_t const *)(data_ptr + idx * 3));
78
+ for (; idx + 8 < count; idx += 8) {
79
+ uint16x8x3_t loaded_u16x8x3 = vld3q_u16((nk_u16_t const *)(data_ptr + idx * 3));
80
80
  bfloat16x8_t data_bf16x8 = vreinterpretq_bf16_u16(loaded_u16x8x3.val[0]);
81
81
  sum_f32x4 = vbfdotq_f32(sum_f32x4, data_bf16x8, ones_bf16x8);
82
82
  sumsq_f32x4 = vbfdotq_f32(sumsq_f32x4, data_bf16x8, data_bf16x8);
83
83
  }
84
84
  }
85
85
  else if (stride_elements == 4) {
86
- for (; idx + 8 <= count; idx += 8) {
87
- uint16x8x4_t loaded_u16x8x4 = vld4q_u16((uint16_t const *)(data_ptr + idx * 4));
86
+ for (; idx + 8 < count; idx += 8) {
87
+ uint16x8x4_t loaded_u16x8x4 = vld4q_u16((nk_u16_t const *)(data_ptr + idx * 4));
88
88
  bfloat16x8_t data_bf16x8 = vreinterpretq_bf16_u16(loaded_u16x8x4.val[0]);
89
89
  sum_f32x4 = vbfdotq_f32(sum_f32x4, data_bf16x8, ones_bf16x8);
90
90
  sumsq_f32x4 = vbfdotq_f32(sumsq_f32x4, data_bf16x8, data_bf16x8);
@@ -127,217 +127,6 @@ NK_PUBLIC void nk_reduce_moments_bf16_neonbfdot( //
127
127
  else nk_reduce_moments_bf16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
128
128
  }
129
129
 
130
- /** @brief Convert 8 raw bf16 sign-magnitude u16 to order-preserving comparable i16.
131
- * Positive bf16 values (sign=0) are left as-is: they already sort correctly as i16.
132
- * Negative bf16 values (sign=1) have their magnitude bits flipped (XOR 0x7FFF)
133
- * so that more-negative values map to more-negative i16 values. */
134
- NK_INTERNAL int16x8_t nk_bf16x8_to_comparable_i16x8_neon_(uint16x8_t raw_u16x8) {
135
- int16x8_t raw_i16x8 = vreinterpretq_s16_u16(raw_u16x8);
136
- uint16x8_t is_negative_u16x8 = vtstq_u16(raw_u16x8, vdupq_n_u16(0x8000));
137
- int16x8_t flipped_i16x8 = veorq_s16(raw_i16x8, vdupq_n_s16(0x7FFF));
138
- return vbslq_s16(is_negative_u16x8, flipped_i16x8, raw_i16x8);
139
- }
140
-
141
- /** @brief Convert a comparable i16 value back to raw bf16 u16 bits.
142
- * Reverses the transformation from nk_bf16x8_to_comparable_i16x8_neon_. */
143
- NK_INTERNAL nk_u16_t nk_comparable_i16_to_bf16_raw_(nk_i16_t comparable) {
144
- if (comparable < 0) return (nk_u16_t)(comparable ^ 0x7FFF);
145
- return (nk_u16_t)comparable;
146
- }
147
-
148
- NK_INTERNAL void nk_reduce_minmax_bf16_neonbfdot_contiguous_( //
149
- nk_bf16_t const *data_ptr, nk_size_t count, //
150
- nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
151
- nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
152
- int16x8_t min_i16x8 = vdupq_n_s16(NK_I16_MAX), max_i16x8 = vdupq_n_s16(NK_I16_MIN);
153
- uint16x8_t min_iter_u16x8 = vdupq_n_u16(0), max_iter_u16x8 = vdupq_n_u16(0);
154
- uint16x8_t iter_u16x8 = vdupq_n_u16(0), one_u16x8 = vdupq_n_u16(1);
155
- nk_size_t idx = 0;
156
- for (; idx + 8 <= count; idx += 8) {
157
- uint16x8_t raw_u16x8 = vld1q_u16((uint16_t const *)(data_ptr + idx));
158
- int16x8_t comparable_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(raw_u16x8);
159
- uint16x8_t less_u16x8 = vcltq_s16(comparable_i16x8, min_i16x8);
160
- uint16x8_t greater_u16x8 = vcgtq_s16(comparable_i16x8, max_i16x8);
161
- min_i16x8 = vbslq_s16(less_u16x8, comparable_i16x8, min_i16x8);
162
- max_i16x8 = vbslq_s16(greater_u16x8, comparable_i16x8, max_i16x8);
163
- min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
164
- max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
165
- iter_u16x8 = vaddq_u16(iter_u16x8, one_u16x8);
166
- }
167
- // Handle tail with partial load and identity masking
168
- nk_size_t remaining = count - idx;
169
- if (remaining > 0) {
170
- nk_b128_vec_t tail_vec;
171
- nk_partial_load_b16x8_serial_(data_ptr + idx, &tail_vec, remaining);
172
- int16x8_t comparable_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(tail_vec.u16x8);
173
- uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
174
- vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
175
- uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((uint16_t)remaining));
176
- int16x8_t data_for_min_i16x8 = vbslq_s16(valid_u16x8, comparable_i16x8, vdupq_n_s16(NK_I16_MAX));
177
- int16x8_t data_for_max_i16x8 = vbslq_s16(valid_u16x8, comparable_i16x8, vdupq_n_s16(NK_I16_MIN));
178
- uint16x8_t less_u16x8 = vcltq_s16(data_for_min_i16x8, min_i16x8);
179
- uint16x8_t greater_u16x8 = vcgtq_s16(data_for_max_i16x8, max_i16x8);
180
- min_i16x8 = vbslq_s16(less_u16x8, data_for_min_i16x8, min_i16x8);
181
- max_i16x8 = vbslq_s16(greater_u16x8, data_for_max_i16x8, max_i16x8);
182
- min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
183
- max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
184
- }
185
- // Horizontal reduction
186
- nk_i16_t min_comparable = vminvq_s16(min_i16x8), max_comparable = vmaxvq_s16(max_i16x8);
187
- // All-NaN early return: both sentinels unchanged means no valid data was found
188
- if (min_comparable == NK_I16_MAX && max_comparable == NK_I16_MIN) {
189
- *(nk_u16_t *)min_value_ptr = nk_comparable_i16_to_bf16_raw_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
190
- *(nk_u16_t *)max_value_ptr = nk_comparable_i16_to_bf16_raw_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
191
- return;
192
- }
193
- uint16x8_t min_value_match_u16x8 = vceqq_s16(min_i16x8, vdupq_n_s16(min_comparable));
194
- uint16x8_t masked_min_iter_u16x8 = vbslq_u16(min_value_match_u16x8, min_iter_u16x8, vdupq_n_u16(0xFFFF));
195
- nk_u16_t earliest_min_cycle = vminvq_u16(masked_min_iter_u16x8);
196
- uint16x8_t max_value_match_u16x8 = vceqq_s16(max_i16x8, vdupq_n_s16(max_comparable));
197
- uint16x8_t masked_max_iter_u16x8 = vbslq_u16(max_value_match_u16x8, max_iter_u16x8, vdupq_n_u16(0xFFFF));
198
- nk_u16_t earliest_max_cycle = vminvq_u16(masked_max_iter_u16x8);
199
- uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
200
- vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
201
- uint16x8_t min_cycle_match_u16x8 = vceqq_u16(min_iter_u16x8, vdupq_n_u16(earliest_min_cycle));
202
- uint16x8_t min_both_match_u16x8 = vandq_u16(min_value_match_u16x8, min_cycle_match_u16x8);
203
- uint16x8_t min_masked_lanes_u16x8 = vbslq_u16(min_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
204
- nk_u16_t min_lane_offset = vminvq_u16(min_masked_lanes_u16x8);
205
- nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 8 + (nk_size_t)min_lane_offset;
206
- uint16x8_t max_cycle_match_u16x8 = vceqq_u16(max_iter_u16x8, vdupq_n_u16(earliest_max_cycle));
207
- uint16x8_t max_both_match_u16x8 = vandq_u16(max_value_match_u16x8, max_cycle_match_u16x8);
208
- uint16x8_t max_masked_lanes_u16x8 = vbslq_u16(max_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
209
- nk_u16_t max_lane_offset = vminvq_u16(max_masked_lanes_u16x8);
210
- nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 8 + (nk_size_t)max_lane_offset;
211
- // Convert comparable back to bf16 raw bits
212
- nk_u16_t min_raw = nk_comparable_i16_to_bf16_raw_(min_comparable);
213
- nk_u16_t max_raw = nk_comparable_i16_to_bf16_raw_(max_comparable);
214
- *(nk_u16_t *)min_value_ptr = min_raw, *min_index_ptr = min_idx;
215
- *(nk_u16_t *)max_value_ptr = max_raw, *max_index_ptr = max_idx;
216
- }
217
-
218
- NK_INTERNAL void nk_reduce_minmax_bf16_neonbfdot_strided_( //
219
- nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
220
- nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
221
- nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
222
- int16x8_t min_i16x8 = vdupq_n_s16(NK_I16_MAX), max_i16x8 = vdupq_n_s16(NK_I16_MIN);
223
- uint16x8_t min_iter_u16x8 = vdupq_n_u16(0), max_iter_u16x8 = vdupq_n_u16(0);
224
- uint16x8_t iter_u16x8 = vdupq_n_u16(0), one_u16x8 = vdupq_n_u16(1);
225
- uint16x8_t lane_indices_u16x8 = vcombine_u16(vreinterpret_u16_u64(vcreate_u64(0x0003000200010000ULL)),
226
- vreinterpret_u16_u64(vcreate_u64(0x0007000600050004ULL)));
227
- nk_size_t idx = 0;
228
- int16x8_t data_for_min_i16x8, data_for_max_i16x8;
229
-
230
- nk_reduce_minmax_bf16_neonbfdot_cycle:
231
- if (stride_elements == 2 && idx + 8 <= count) {
232
- uint16x8x2_t loaded = vld2q_u16((uint16_t const *)(data_ptr + idx * 2));
233
- int16x8_t comparable_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(loaded.val[0]);
234
- data_for_min_i16x8 = comparable_i16x8;
235
- data_for_max_i16x8 = comparable_i16x8;
236
- idx += 8;
237
- }
238
- else if (stride_elements == 3 && idx + 8 <= count) {
239
- uint16x8x3_t loaded = vld3q_u16((uint16_t const *)(data_ptr + idx * 3));
240
- int16x8_t comparable_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(loaded.val[0]);
241
- data_for_min_i16x8 = comparable_i16x8;
242
- data_for_max_i16x8 = comparable_i16x8;
243
- idx += 8;
244
- }
245
- else if (stride_elements == 4 && idx + 8 <= count) {
246
- uint16x8x4_t loaded = vld4q_u16((uint16_t const *)(data_ptr + idx * 4));
247
- int16x8_t comparable_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(loaded.val[0]);
248
- data_for_min_i16x8 = comparable_i16x8;
249
- data_for_max_i16x8 = comparable_i16x8;
250
- idx += 8;
251
- }
252
- else if (idx < count) {
253
- nk_b128_vec_t tail_vec;
254
- nk_strided_load_b16x8_serial_(data_ptr + idx * stride_elements, stride_elements, &tail_vec, count - idx);
255
- int16x8_t comparable_i16x8 = nk_bf16x8_to_comparable_i16x8_neon_(tail_vec.u16x8);
256
- uint16x8_t valid_u16x8 = vcltq_u16(lane_indices_u16x8, vdupq_n_u16((uint16_t)(count - idx)));
257
- data_for_min_i16x8 = vbslq_s16(valid_u16x8, comparable_i16x8, vdupq_n_s16(NK_I16_MAX));
258
- data_for_max_i16x8 = vbslq_s16(valid_u16x8, comparable_i16x8, vdupq_n_s16(NK_I16_MIN));
259
- idx = count;
260
- }
261
- else {
262
- nk_i16_t min_comparable = vminvq_s16(min_i16x8), max_comparable = vmaxvq_s16(max_i16x8);
263
- if (min_comparable == NK_I16_MAX && max_comparable == NK_I16_MIN) {
264
- *(nk_u16_t *)min_value_ptr = nk_comparable_i16_to_bf16_raw_(min_comparable), *min_index_ptr = NK_SIZE_MAX;
265
- *(nk_u16_t *)max_value_ptr = nk_comparable_i16_to_bf16_raw_(max_comparable), *max_index_ptr = NK_SIZE_MAX;
266
- return;
267
- }
268
- uint16x8_t min_value_match_u16x8 = vceqq_s16(min_i16x8, vdupq_n_s16(min_comparable));
269
- uint16x8_t masked_min_iter_u16x8 = vbslq_u16(min_value_match_u16x8, min_iter_u16x8, vdupq_n_u16(0xFFFF));
270
- nk_u16_t earliest_min_cycle = vminvq_u16(masked_min_iter_u16x8);
271
- uint16x8_t max_value_match_u16x8 = vceqq_s16(max_i16x8, vdupq_n_s16(max_comparable));
272
- uint16x8_t masked_max_iter_u16x8 = vbslq_u16(max_value_match_u16x8, max_iter_u16x8, vdupq_n_u16(0xFFFF));
273
- nk_u16_t earliest_max_cycle = vminvq_u16(masked_max_iter_u16x8);
274
- uint16x8_t min_cycle_match_u16x8 = vceqq_u16(min_iter_u16x8, vdupq_n_u16(earliest_min_cycle));
275
- uint16x8_t min_both_match_u16x8 = vandq_u16(min_value_match_u16x8, min_cycle_match_u16x8);
276
- uint16x8_t min_masked_lanes_u16x8 = vbslq_u16(min_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
277
- nk_u16_t min_lane_offset = vminvq_u16(min_masked_lanes_u16x8);
278
- nk_size_t min_idx = (nk_size_t)earliest_min_cycle * 8 + (nk_size_t)min_lane_offset;
279
- uint16x8_t max_cycle_match_u16x8 = vceqq_u16(max_iter_u16x8, vdupq_n_u16(earliest_max_cycle));
280
- uint16x8_t max_both_match_u16x8 = vandq_u16(max_value_match_u16x8, max_cycle_match_u16x8);
281
- uint16x8_t max_masked_lanes_u16x8 = vbslq_u16(max_both_match_u16x8, lane_indices_u16x8, vdupq_n_u16(0xFFFF));
282
- nk_u16_t max_lane_offset = vminvq_u16(max_masked_lanes_u16x8);
283
- nk_size_t max_idx = (nk_size_t)earliest_max_cycle * 8 + (nk_size_t)max_lane_offset;
284
- nk_u16_t min_raw = nk_comparable_i16_to_bf16_raw_(min_comparable);
285
- nk_u16_t max_raw = nk_comparable_i16_to_bf16_raw_(max_comparable);
286
- *(nk_u16_t *)min_value_ptr = min_raw, *min_index_ptr = min_idx;
287
- *(nk_u16_t *)max_value_ptr = max_raw, *max_index_ptr = max_idx;
288
- return;
289
- }
290
-
291
- // Shared update body
292
- uint16x8_t less_u16x8 = vcltq_s16(data_for_min_i16x8, min_i16x8);
293
- uint16x8_t greater_u16x8 = vcgtq_s16(data_for_max_i16x8, max_i16x8);
294
- min_i16x8 = vbslq_s16(less_u16x8, data_for_min_i16x8, min_i16x8);
295
- max_i16x8 = vbslq_s16(greater_u16x8, data_for_max_i16x8, max_i16x8);
296
- min_iter_u16x8 = vbslq_u16(less_u16x8, iter_u16x8, min_iter_u16x8);
297
- max_iter_u16x8 = vbslq_u16(greater_u16x8, iter_u16x8, max_iter_u16x8);
298
- iter_u16x8 = vaddq_u16(iter_u16x8, one_u16x8);
299
- goto nk_reduce_minmax_bf16_neonbfdot_cycle;
300
- }
301
-
302
- NK_PUBLIC void nk_reduce_minmax_bf16_neonbfdot( //
303
- nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
304
- nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
305
- nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
306
- nk_size_t stride_elements = stride_bytes / sizeof(nk_bf16_t);
307
- int aligned = (stride_bytes % sizeof(nk_bf16_t) == 0);
308
- if (count == 0) {
309
- *(nk_u16_t *)min_value_ptr = NK_BF16_MAX, *min_index_ptr = NK_SIZE_MAX;
310
- *(nk_u16_t *)max_value_ptr = NK_BF16_MIN, *max_index_ptr = NK_SIZE_MAX;
311
- }
312
- else if (!aligned)
313
- nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
314
- max_index_ptr);
315
- else if (count > (nk_size_t)(NK_U16_MAX + 1) * 8) {
316
- nk_size_t left_count = count / 2;
317
- nk_bf16_t left_min_value, right_min_value, left_max_value, right_max_value;
318
- nk_size_t left_min_index, right_min_index, left_max_index, right_max_index;
319
- nk_reduce_minmax_bf16_neonbfdot(data_ptr, left_count, stride_bytes, &left_min_value, &left_min_index,
320
- &left_max_value, &left_max_index);
321
- nk_reduce_minmax_bf16_neonbfdot(data_ptr + left_count * stride_elements, count - left_count, stride_bytes,
322
- &right_min_value, &right_min_index, &right_max_value, &right_max_index);
323
- if (nk_bf16_order_serial(right_min_value, left_min_value) < 0)
324
- *min_value_ptr = right_min_value, *min_index_ptr = left_count + right_min_index;
325
- else *min_value_ptr = left_min_value, *min_index_ptr = left_min_index;
326
- if (nk_bf16_order_serial(right_max_value, left_max_value) > 0)
327
- *max_value_ptr = right_max_value, *max_index_ptr = left_count + right_max_index;
328
- else *max_value_ptr = left_max_value, *max_index_ptr = left_max_index;
329
- }
330
- else if (stride_elements == 1)
331
- nk_reduce_minmax_bf16_neonbfdot_contiguous_(data_ptr, count, min_value_ptr, min_index_ptr, max_value_ptr,
332
- max_index_ptr);
333
- else if (stride_elements <= 4)
334
- nk_reduce_minmax_bf16_neonbfdot_strided_(data_ptr, count, stride_elements, min_value_ptr, min_index_ptr,
335
- max_value_ptr, max_index_ptr);
336
- else
337
- nk_reduce_minmax_bf16_serial(data_ptr, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
338
- max_index_ptr);
339
- }
340
-
341
130
  #if defined(__clang__)
342
131
  #pragma clang attribute pop
343
132
  #elif defined(__GNUC__)