numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315) hide show
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -140,22 +140,22 @@ NK_PUBLIC void nk_reduce_moments_u16_serial( //
140
140
  NK_PUBLIC void nk_reduce_moments_i32_serial( //
141
141
  nk_i32_t const *data, nk_size_t count, nk_size_t stride_bytes, //
142
142
  nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
143
- nk_u64_t sum_lower = 0;
144
- nk_i64_t sum_upper = 0;
143
+ nk_u64_t sum_low = 0;
144
+ nk_i64_t sum_high = 0;
145
145
  nk_u64_t sumsq = 0;
146
146
  unsigned char const *ptr = (unsigned char const *)data;
147
147
  for (nk_size_t i = 0; i < count; ++i, ptr += stride_bytes) {
148
148
  nk_i64_t val = (nk_i64_t)(*(nk_i32_t const *)ptr);
149
149
  nk_u64_t product = (nk_u64_t)(val * val);
150
- nk_u64_t sum_before = sum_lower;
151
- sum_lower += (nk_u64_t)val;
152
- if (sum_lower < sum_before) sum_upper++;
153
- sum_upper += (val >> 63);
150
+ nk_u64_t sum_before = sum_low;
151
+ sum_low += (nk_u64_t)val;
152
+ if (sum_low < sum_before) sum_high++;
153
+ sum_high += (val >> 63);
154
154
  sumsq = nk_u64_saturating_add_serial(sumsq, product);
155
155
  }
156
- nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
157
- if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
158
- else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
156
+ nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
157
+ if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
158
+ else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
159
159
  else *sum_ptr = NK_I64_MIN;
160
160
  *sumsq_ptr = sumsq;
161
161
  }
@@ -177,8 +177,8 @@ NK_PUBLIC void nk_reduce_moments_u32_serial( //
177
177
  NK_PUBLIC void nk_reduce_moments_i64_serial( //
178
178
  nk_i64_t const *data, nk_size_t count, nk_size_t stride_bytes, //
179
179
  nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
180
- nk_u64_t sum_lower = 0;
181
- nk_i64_t sum_upper = 0;
180
+ nk_u64_t sum_low = 0;
181
+ nk_i64_t sum_high = 0;
182
182
  nk_u64_t sumsq = 0;
183
183
  unsigned char const *ptr = (unsigned char const *)data;
184
184
  for (nk_size_t i = 0; i < count; ++i, ptr += stride_bytes) {
@@ -186,14 +186,14 @@ NK_PUBLIC void nk_reduce_moments_i64_serial( //
186
186
  nk_i64_t product = nk_i64_saturating_mul_serial(val, val);
187
187
  nk_u64_t unsigned_product = (nk_u64_t)product;
188
188
  sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
189
- nk_u64_t sum_before = sum_lower;
190
- sum_lower += (nk_u64_t)val;
191
- if (sum_lower < sum_before) sum_upper++;
192
- sum_upper += (val >> 63);
193
- }
194
- nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
195
- if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
196
- else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
189
+ nk_u64_t sum_before = sum_low;
190
+ sum_low += (nk_u64_t)val;
191
+ if (sum_low < sum_before) sum_high++;
192
+ sum_high += (val >> 63);
193
+ }
194
+ nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
195
+ if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
196
+ else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
197
197
  else *sum_ptr = NK_I64_MIN;
198
198
  *sumsq_ptr = sumsq;
199
199
  }
@@ -572,13 +572,11 @@ NK_PUBLIC void nk_reduce_minmax_f16_serial( //
572
572
  nk_f16_t *min_value_ptr, nk_size_t *min_index_ptr, //
573
573
  nk_f16_t *max_value_ptr, nk_size_t *max_index_ptr) {
574
574
  unsigned char const *ptr = (unsigned char const *)data;
575
- nk_f16_t min_value = nk_f16_from_u16_(NK_F16_MAX), max_value = nk_f16_from_u16_(NK_F16_MIN);
575
+ nk_f16_t min_value = NK_F16_MAX, max_value = NK_F16_MIN;
576
576
  nk_size_t min_idx = NK_SIZE_MAX, max_idx = NK_SIZE_MAX;
577
577
  for (nk_size_t i = 0; i < count; ++i, ptr += stride_bytes) {
578
578
  nk_f16_t raw_value = *(nk_f16_t const *)ptr;
579
- nk_fui16_t raw_fui;
580
- raw_fui.f = raw_value;
581
- if (nk_f16_is_nan_(raw_fui.u)) continue;
579
+ if (nk_f16_is_nan_(raw_value)) continue;
582
580
  if (min_idx == NK_SIZE_MAX || nk_f16_order_serial(raw_value, min_value) < 0) min_value = raw_value, min_idx = i;
583
581
  if (max_idx == NK_SIZE_MAX || nk_f16_order_serial(raw_value, max_value) > 0) max_value = raw_value, max_idx = i;
584
582
  }
@@ -591,13 +589,11 @@ NK_PUBLIC void nk_reduce_minmax_bf16_serial( //
591
589
  nk_bf16_t *min_value_ptr, nk_size_t *min_index_ptr, //
592
590
  nk_bf16_t *max_value_ptr, nk_size_t *max_index_ptr) {
593
591
  unsigned char const *ptr = (unsigned char const *)data;
594
- nk_bf16_t min_value = nk_bf16_from_u16_(NK_BF16_MAX), max_value = nk_bf16_from_u16_(NK_BF16_MIN);
592
+ nk_bf16_t min_value = NK_BF16_MAX, max_value = NK_BF16_MIN;
595
593
  nk_size_t min_idx = NK_SIZE_MAX, max_idx = NK_SIZE_MAX;
596
594
  for (nk_size_t i = 0; i < count; ++i, ptr += stride_bytes) {
597
595
  nk_bf16_t raw_value = *(nk_bf16_t const *)ptr;
598
- nk_fui16_t raw_fui;
599
- raw_fui.bf = raw_value;
600
- if (nk_bf16_is_nan_(raw_fui.u)) continue;
596
+ if (nk_bf16_is_nan_(raw_value)) continue;
601
597
  if (min_idx == NK_SIZE_MAX || nk_bf16_order_serial(raw_value, min_value) < 0)
602
598
  min_value = raw_value, min_idx = i;
603
599
  if (max_idx == NK_SIZE_MAX || nk_bf16_order_serial(raw_value, max_value) > 0)
@@ -7,8 +7,8 @@
7
7
  * @sa include/numkong/reduce.h
8
8
  *
9
9
  * Uses AVX-VNNI-INT8 (256-bit) for efficient widening dot-products on i8, u8, and e2m3:
10
- * - `_mm256_dpbssd_epi32`: i8 x i8 -> i32 signed dot product (AVXVNNIINT8)
11
- * - `_mm256_dpbuud_epi32`: u8 x u8 -> u32 unsigned dot product (AVXVNNIINT8)
10
+ * - `_mm256_dpbssd_epi32`: i8 × i8 i32 signed dot product (AVXVNNIINT8)
11
+ * - `_mm256_dpbuud_epi32`: u8 × u8 u32 unsigned dot product (AVXVNNIINT8)
12
12
  */
13
13
  #ifndef NK_REDUCE_SIERRA_H
14
14
  #define NK_REDUCE_SIERRA_H
@@ -68,7 +68,7 @@ NK_INTERNAL void nk_reduce_moments_i8_sierra_strided_( //
68
68
  nk_size_t idx_scalars = 0;
69
69
  nk_size_t total_scalars = count * stride_elements;
70
70
  nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
71
- for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
71
+ for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
72
72
  __m256i data_i8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
73
73
  data_i8x32 = _mm256_and_si256(data_i8x32, stride_mask_i8x32);
74
74
  sum_i32x8 = _mm256_dpbssd_epi32(sum_i32x8, data_i8x32, ones_i8x32);
@@ -109,7 +109,7 @@ NK_PUBLIC void nk_reduce_moments_i8_sierra( //
109
109
  }
110
110
 
111
111
  /**
112
- * @section u8 moments via VPDPBUUD (unsigned u8 x u8 -> u32)
112
+ * @section u8 moments via VPDPBUUD (unsigned u8 × u8 u32)
113
113
  *
114
114
  * Sierra's `_mm256_dpbuud_epi32` provides native u8×u8→u32 dot product, replacing
115
115
  * Haswell's 8-instruction SAD+widen+MADD sequence with 3 instructions per 32 elements.
@@ -153,7 +153,7 @@ NK_INTERNAL void nk_reduce_moments_u8_sierra_strided_( //
153
153
  nk_size_t idx_scalars = 0;
154
154
  nk_size_t total_scalars = count * stride_elements;
155
155
  nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
156
- for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
156
+ for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
157
157
  __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
158
158
  data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
159
159
  sum_i32x8 = _mm256_dpbuud_epi32(sum_i32x8, data_u8x32, ones_u8x32);
@@ -203,10 +203,10 @@ NK_PUBLIC void nk_reduce_moments_u8_sierra( //
203
203
  NK_INTERNAL void nk_reduce_moments_e2m3_sierra_contiguous_( //
204
204
  nk_e2m3_t const *data, nk_size_t count, //
205
205
  nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
206
- __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
207
- 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
208
- __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
209
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
206
+ __m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
207
+ 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
208
+ __m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
209
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
210
210
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
211
211
  __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
212
212
  __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
@@ -221,8 +221,8 @@ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_contiguous_( //
221
221
  __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
222
222
  __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
223
223
  half_select_u8x32);
224
- __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
225
- _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
224
+ __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
225
+ _mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
226
226
  upper_select_u8x32);
227
227
  __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
228
228
  __m256i signed_i8x32 = _mm256_blendv_epi8(
@@ -241,8 +241,8 @@ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_contiguous_( //
241
241
  __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
242
242
  __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
243
243
  half_select_u8x32);
244
- __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
245
- _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
244
+ __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
245
+ _mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
246
246
  upper_select_u8x32);
247
247
  __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
248
248
  __m256i signed_i8x32 = _mm256_blendv_epi8(
@@ -258,10 +258,10 @@ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_strided_( //
258
258
  nk_e2m3_t const *data, nk_size_t count, nk_size_t stride_elements, //
259
259
  nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
260
260
  __m256i stride_mask_u8x32 = nk_stride_blend_u1x32_(stride_elements);
261
- __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
262
- 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
263
- __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
264
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
261
+ __m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, //
262
+ 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
263
+ __m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
264
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
265
265
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
266
266
  __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
267
267
  __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
@@ -272,15 +272,15 @@ NK_INTERNAL void nk_reduce_moments_e2m3_sierra_strided_( //
272
272
  nk_size_t idx_scalars = 0;
273
273
  nk_size_t total_scalars = count * stride_elements;
274
274
  nk_size_t step = nk_size_round_up_to_multiple_(32, stride_elements);
275
- for (; idx_scalars + step <= total_scalars; idx_scalars += step) {
275
+ for (; idx_scalars + stride_elements + 31 <= total_scalars; idx_scalars += step) {
276
276
  __m256i data_u8x32 = _mm256_loadu_si256((__m256i const *)(data + idx_scalars));
277
277
  data_u8x32 = _mm256_and_si256(data_u8x32, stride_mask_u8x32);
278
278
  __m256i magnitude_u8x32 = _mm256_and_si256(data_u8x32, magnitude_mask_u8x32);
279
279
  __m256i shuffle_idx_u8x32 = _mm256_and_si256(magnitude_u8x32, nibble_mask_u8x32);
280
280
  __m256i upper_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(magnitude_u8x32, half_select_u8x32),
281
281
  half_select_u8x32);
282
- __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, shuffle_idx_u8x32),
283
- _mm256_shuffle_epi8(lut_upper_u8x32, shuffle_idx_u8x32),
282
+ __m256i unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, shuffle_idx_u8x32),
283
+ _mm256_shuffle_epi8(lut_high_u8x32, shuffle_idx_u8x32),
284
284
  upper_select_u8x32);
285
285
  __m256i negate_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(data_u8x32, sign_mask_u8x32), sign_mask_u8x32);
286
286
  __m256i signed_i8x32 = _mm256_blendv_epi8(