numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -620,8 +620,8 @@ NK_INTERNAL void nk_reduce_minmax_f16_v128relaxed_contiguous_( //
620
620
  if (val > max_value_f32) max_value_f32 = val, max_idx = idx;
621
621
  }
622
622
  if (min_value_f32 == NK_F32_MAX && max_value_f32 == NK_F32_MIN) {
623
- *min_value_ptr = nk_f16_from_u16_(NK_F16_MAX), *min_index_ptr = NK_SIZE_MAX,
624
- *max_value_ptr = nk_f16_from_u16_(NK_F16_MIN), *max_index_ptr = NK_SIZE_MAX;
623
+ *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
624
+ *max_index_ptr = NK_SIZE_MAX;
625
625
  return;
626
626
  }
627
627
  *min_value_ptr = data[min_idx], *min_index_ptr = min_idx;
@@ -635,8 +635,8 @@ NK_PUBLIC void nk_reduce_minmax_f16_v128relaxed( //
635
635
  nk_size_t stride_elements = stride_bytes / sizeof(nk_f16_t);
636
636
  int aligned = (stride_bytes % sizeof(nk_f16_t) == 0);
637
637
  if (count == 0)
638
- *min_value_ptr = nk_f16_from_u16_(NK_F16_MAX), *min_index_ptr = NK_SIZE_MAX,
639
- *max_value_ptr = nk_f16_from_u16_(NK_F16_MIN), *max_index_ptr = NK_SIZE_MAX;
638
+ *min_value_ptr = NK_F16_MAX, *min_index_ptr = NK_SIZE_MAX, *max_value_ptr = NK_F16_MIN,
639
+ *max_index_ptr = NK_SIZE_MAX;
640
640
  else if (!aligned)
641
641
  nk_reduce_minmax_f16_serial(data, count, stride_bytes, min_value_ptr, min_index_ptr, max_value_ptr,
642
642
  max_index_ptr);
@@ -856,8 +856,8 @@ NK_PUBLIC void nk_reduce_moments_u16_v128relaxed( //
856
856
  NK_INTERNAL void nk_reduce_moments_i32_v128relaxed_contiguous_( //
857
857
  nk_i32_t const *data, nk_size_t count, //
858
858
  nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
859
- v128_t sum_lower_u64x2 = wasm_i64x2_splat(0);
860
- v128_t sum_upper_i64x2 = wasm_i64x2_splat(0);
859
+ v128_t sum_low_u64x2 = wasm_i64x2_splat(0);
860
+ v128_t sum_high_i64x2 = wasm_i64x2_splat(0);
861
861
  v128_t sumsq_u64x2 = wasm_i64x2_splat(0);
862
862
  v128_t sumsq_overflow_u64x2 = wasm_i64x2_splat(0);
863
863
  v128_t sign_bit_i64x2 = wasm_i64x2_splat((nk_i64_t)0x8000000000000000LL);
@@ -865,21 +865,21 @@ NK_INTERNAL void nk_reduce_moments_i32_v128relaxed_contiguous_( //
865
865
  for (; idx + 4 <= count; idx += 4) {
866
866
  v128_t data_i32x4 = wasm_v128_load(data + idx);
867
867
  v128_t data_low_i64x2 = wasm_i64x2_extend_low_i32x4(data_i32x4);
868
- v128_t before_u64x2 = sum_lower_u64x2;
869
- sum_lower_u64x2 = wasm_i64x2_add(sum_lower_u64x2, data_low_i64x2);
870
- v128_t result_biased_i64x2 = wasm_v128_xor(sum_lower_u64x2, sign_bit_i64x2);
868
+ v128_t before_u64x2 = sum_low_u64x2;
869
+ sum_low_u64x2 = wasm_i64x2_add(sum_low_u64x2, data_low_i64x2);
870
+ v128_t result_biased_i64x2 = wasm_v128_xor(sum_low_u64x2, sign_bit_i64x2);
871
871
  v128_t before_biased_i64x2 = wasm_v128_xor(before_u64x2, sign_bit_i64x2);
872
872
  v128_t carry_u64x2 = wasm_i64x2_gt(before_biased_i64x2, result_biased_i64x2);
873
- sum_upper_i64x2 = wasm_i64x2_sub(sum_upper_i64x2, carry_u64x2);
874
- sum_upper_i64x2 = wasm_i64x2_add(sum_upper_i64x2, wasm_i64x2_shr(data_low_i64x2, 63));
873
+ sum_high_i64x2 = wasm_i64x2_sub(sum_high_i64x2, carry_u64x2);
874
+ sum_high_i64x2 = wasm_i64x2_add(sum_high_i64x2, wasm_i64x2_shr(data_low_i64x2, 63));
875
875
  v128_t data_high_i64x2 = wasm_i64x2_extend_high_i32x4(data_i32x4);
876
- before_u64x2 = sum_lower_u64x2;
877
- sum_lower_u64x2 = wasm_i64x2_add(sum_lower_u64x2, data_high_i64x2);
878
- result_biased_i64x2 = wasm_v128_xor(sum_lower_u64x2, sign_bit_i64x2);
876
+ before_u64x2 = sum_low_u64x2;
877
+ sum_low_u64x2 = wasm_i64x2_add(sum_low_u64x2, data_high_i64x2);
878
+ result_biased_i64x2 = wasm_v128_xor(sum_low_u64x2, sign_bit_i64x2);
879
879
  before_biased_i64x2 = wasm_v128_xor(before_u64x2, sign_bit_i64x2);
880
880
  carry_u64x2 = wasm_i64x2_gt(before_biased_i64x2, result_biased_i64x2);
881
- sum_upper_i64x2 = wasm_i64x2_sub(sum_upper_i64x2, carry_u64x2);
882
- sum_upper_i64x2 = wasm_i64x2_add(sum_upper_i64x2, wasm_i64x2_shr(data_high_i64x2, 63));
881
+ sum_high_i64x2 = wasm_i64x2_sub(sum_high_i64x2, carry_u64x2);
882
+ sum_high_i64x2 = wasm_i64x2_add(sum_high_i64x2, wasm_i64x2_shr(data_high_i64x2, 63));
883
883
  v128_t sq_low_i64x2 = wasm_i64x2_extmul_low_i32x4(data_i32x4, data_i32x4);
884
884
  v128_t sq_high_i64x2 = wasm_i64x2_extmul_high_i32x4(data_i32x4, data_i32x4);
885
885
  v128_t sq_before_u64x2 = sumsq_u64x2;
@@ -897,26 +897,26 @@ NK_INTERNAL void nk_reduce_moments_i32_v128relaxed_contiguous_( //
897
897
  wasm_i64x2_extract_lane(sumsq_overflow_u64x2, 1));
898
898
  nk_u64_t sumsq = sumsq_overflow ? NK_U64_MAX : nk_reduce_sadd_u64x2_v128relaxed_(sumsq_u64x2);
899
899
  nk_b128_vec_t lower_vec, upper_vec;
900
- lower_vec.v128 = sum_lower_u64x2;
901
- upper_vec.v128 = sum_upper_i64x2;
902
- nk_u64_t sum_lower = 0;
903
- nk_i64_t sum_upper = 0;
904
- nk_u64_t sum_before = sum_lower;
905
- sum_lower += lower_vec.u64s[0], sum_upper += (sum_lower < sum_before) + upper_vec.i64s[0];
906
- sum_before = sum_lower;
907
- sum_lower += lower_vec.u64s[1], sum_upper += (sum_lower < sum_before) + upper_vec.i64s[1];
900
+ lower_vec.v128 = sum_low_u64x2;
901
+ upper_vec.v128 = sum_high_i64x2;
902
+ nk_u64_t sum_low = 0;
903
+ nk_i64_t sum_high = 0;
904
+ nk_u64_t sum_before = sum_low;
905
+ sum_low += lower_vec.u64s[0], sum_high += (sum_low < sum_before) + upper_vec.i64s[0];
906
+ sum_before = sum_low;
907
+ sum_low += lower_vec.u64s[1], sum_high += (sum_low < sum_before) + upper_vec.i64s[1];
908
908
  for (; idx < count; ++idx) {
909
909
  nk_i64_t val = (nk_i64_t)data[idx];
910
- sum_before = sum_lower;
911
- sum_lower += (nk_u64_t)val;
912
- if (sum_lower < sum_before) sum_upper++;
913
- sum_upper += (val >> 63);
910
+ sum_before = sum_low;
911
+ sum_low += (nk_u64_t)val;
912
+ if (sum_low < sum_before) sum_high++;
913
+ sum_high += (val >> 63);
914
914
  nk_u64_t product = (nk_u64_t)(val * val);
915
915
  sumsq = nk_u64_saturating_add_serial(sumsq, product);
916
916
  }
917
- nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
918
- if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
919
- else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
917
+ nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
918
+ if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
919
+ else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
920
920
  else *sum_ptr = NK_I64_MIN;
921
921
  *sumsq_ptr = sumsq;
922
922
  }
@@ -981,8 +981,8 @@ NK_PUBLIC void nk_reduce_moments_u32_v128relaxed( //
981
981
  NK_INTERNAL void nk_reduce_moments_i64_v128relaxed_contiguous_( //
982
982
  nk_i64_t const *data, nk_size_t count, //
983
983
  nk_i64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
984
- v128_t sum_lower_u64x2 = wasm_i64x2_splat(0);
985
- v128_t sum_upper_i64x2 = wasm_i64x2_splat(0);
984
+ v128_t sum_low_u64x2 = wasm_i64x2_splat(0);
985
+ v128_t sum_high_i64x2 = wasm_i64x2_splat(0);
986
986
  v128_t sumsq_u64x2 = wasm_i64x2_splat(0);
987
987
  v128_t sumsq_overflow_u64x2 = wasm_i64x2_splat(0);
988
988
  v128_t sign_bit_i64x2 = wasm_i64x2_splat((nk_i64_t)0x8000000000000000LL);
@@ -995,36 +995,36 @@ NK_INTERNAL void nk_reduce_moments_i64_v128relaxed_contiguous_( //
995
995
  sumsq_overflow_u64x2 = wasm_v128_or(
996
996
  sumsq_overflow_u64x2,
997
997
  wasm_i64x2_gt(wasm_v128_xor(sq_before_u64x2, sign_bit_i64x2), wasm_v128_xor(sumsq_u64x2, sign_bit_i64x2)));
998
- v128_t before_u64x2 = sum_lower_u64x2;
999
- sum_lower_u64x2 = wasm_i64x2_add(sum_lower_u64x2, data_i64x2);
998
+ v128_t before_u64x2 = sum_low_u64x2;
999
+ sum_low_u64x2 = wasm_i64x2_add(sum_low_u64x2, data_i64x2);
1000
1000
  v128_t carry_u64x2 = wasm_i64x2_gt(wasm_v128_xor(before_u64x2, sign_bit_i64x2),
1001
- wasm_v128_xor(sum_lower_u64x2, sign_bit_i64x2));
1002
- sum_upper_i64x2 = wasm_i64x2_sub(sum_upper_i64x2, carry_u64x2);
1003
- sum_upper_i64x2 = wasm_i64x2_add(sum_upper_i64x2, wasm_i64x2_shr(data_i64x2, 63));
1001
+ wasm_v128_xor(sum_low_u64x2, sign_bit_i64x2));
1002
+ sum_high_i64x2 = wasm_i64x2_sub(sum_high_i64x2, carry_u64x2);
1003
+ sum_high_i64x2 = wasm_i64x2_add(sum_high_i64x2, wasm_i64x2_shr(data_i64x2, 63));
1004
1004
  }
1005
1005
  int sumsq_overflow = (int)(wasm_i64x2_extract_lane(sumsq_overflow_u64x2, 0) |
1006
1006
  wasm_i64x2_extract_lane(sumsq_overflow_u64x2, 1));
1007
1007
  nk_u64_t sumsq = sumsq_overflow ? NK_U64_MAX : nk_reduce_sadd_u64x2_v128relaxed_(sumsq_u64x2);
1008
- nk_u64_t sum_lower = (nk_u64_t)wasm_i64x2_extract_lane(sum_lower_u64x2, 0);
1009
- nk_i64_t sum_upper = wasm_i64x2_extract_lane(sum_upper_i64x2, 0);
1008
+ nk_u64_t sum_low = (nk_u64_t)wasm_i64x2_extract_lane(sum_low_u64x2, 0);
1009
+ nk_i64_t sum_high = wasm_i64x2_extract_lane(sum_high_i64x2, 0);
1010
1010
  {
1011
- nk_u64_t sum_before = sum_lower;
1012
- sum_lower += (nk_u64_t)wasm_i64x2_extract_lane(sum_lower_u64x2, 1);
1013
- if (sum_lower < sum_before) sum_upper++;
1014
- sum_upper += wasm_i64x2_extract_lane(sum_upper_i64x2, 1);
1011
+ nk_u64_t sum_before = sum_low;
1012
+ sum_low += (nk_u64_t)wasm_i64x2_extract_lane(sum_low_u64x2, 1);
1013
+ if (sum_low < sum_before) sum_high++;
1014
+ sum_high += wasm_i64x2_extract_lane(sum_high_i64x2, 1);
1015
1015
  }
1016
1016
  for (; idx < count; ++idx) {
1017
1017
  nk_i64_t val = data[idx];
1018
1018
  nk_u64_t unsigned_product = (nk_u64_t)nk_i64_saturating_mul_serial(val, val);
1019
1019
  sumsq = nk_u64_saturating_add_serial(sumsq, unsigned_product);
1020
- nk_u64_t sum_before = sum_lower;
1021
- sum_lower += (nk_u64_t)val;
1022
- if (sum_lower < sum_before) sum_upper++;
1023
- sum_upper += (val >> 63);
1024
- }
1025
- nk_i64_t sum_lower_signed = (nk_i64_t)sum_lower;
1026
- if (sum_upper == (sum_lower_signed >> 63)) *sum_ptr = sum_lower_signed;
1027
- else if (sum_upper >= 0) *sum_ptr = NK_I64_MAX;
1020
+ nk_u64_t sum_before = sum_low;
1021
+ sum_low += (nk_u64_t)val;
1022
+ if (sum_low < sum_before) sum_high++;
1023
+ sum_high += (val >> 63);
1024
+ }
1025
+ nk_i64_t sum_low_signed = (nk_i64_t)sum_low;
1026
+ if (sum_high == (sum_low_signed >> 63)) *sum_ptr = sum_low_signed;
1027
+ else if (sum_high >= 0) *sum_ptr = NK_I64_MAX;
1028
1028
  else *sum_ptr = NK_I64_MIN;
1029
1029
  *sumsq_ptr = sumsq;
1030
1030
  }
@@ -446,19 +446,13 @@ NK_PUBLIC void nk_reduce_minmax_e4m3_neon(nk_e4m3_t const *, nk_size_t, nk_size_
446
446
  /** @copydoc nk_reduce_minmax_f64 */
447
447
  NK_PUBLIC void nk_reduce_minmax_e5m2_neon(nk_e5m2_t const *, nk_size_t, nk_size_t, nk_e5m2_t *, nk_size_t *,
448
448
  nk_e5m2_t *, nk_size_t *);
449
- #endif // NK_TARGET_NEON
450
-
451
- #if NK_TARGET_NEONHALF
452
449
  /** @copydoc nk_reduce_moments_f64 */
453
- NK_PUBLIC void nk_reduce_moments_f16_neonhalf(nk_f16_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
454
- #endif // NK_TARGET_NEONHALF
450
+ NK_PUBLIC void nk_reduce_moments_f16_neon(nk_f16_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
451
+ #endif // NK_TARGET_NEON
455
452
 
456
453
  #if NK_TARGET_NEONBFDOT
457
454
  /** @copydoc nk_reduce_moments_f64 */
458
455
  NK_PUBLIC void nk_reduce_moments_bf16_neonbfdot(nk_bf16_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
459
- /** @copydoc nk_reduce_minmax_f64 */
460
- NK_PUBLIC void nk_reduce_minmax_bf16_neonbfdot(nk_bf16_t const *, nk_size_t, nk_size_t, nk_bf16_t *, nk_size_t *,
461
- nk_bf16_t *, nk_size_t *);
462
456
  #endif // NK_TARGET_NEONBFDOT
463
457
 
464
458
  #if NK_TARGET_NEONSDOT
@@ -475,12 +469,6 @@ NK_PUBLIC void nk_reduce_moments_e2m3_neonsdot(nk_e2m3_t const *, nk_size_t, nk_
475
469
  NK_PUBLIC void nk_reduce_moments_e4m3_neonfhm(nk_e4m3_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
476
470
  /** @copydoc nk_reduce_moments_f64 */
477
471
  NK_PUBLIC void nk_reduce_moments_e5m2_neonfhm(nk_e5m2_t const *, nk_size_t, nk_size_t, nk_f32_t *, nk_f32_t *);
478
- /** @copydoc nk_reduce_minmax_f64 */
479
- NK_PUBLIC void nk_reduce_minmax_e4m3_neonfhm(nk_e4m3_t const *, nk_size_t, nk_size_t, nk_e4m3_t *, nk_size_t *,
480
- nk_e4m3_t *, nk_size_t *);
481
- /** @copydoc nk_reduce_minmax_f64 */
482
- NK_PUBLIC void nk_reduce_minmax_e5m2_neonfhm(nk_e5m2_t const *, nk_size_t, nk_size_t, nk_e5m2_t *, nk_size_t *,
483
- nk_e5m2_t *, nk_size_t *);
484
472
  #endif // NK_TARGET_NEONFHM
485
473
 
486
474
  #if NK_TARGET_HASWELL
@@ -950,7 +938,6 @@ NK_INTERNAL nk_dtype_t nk_reduce_minmax_value_dtype(nk_dtype_t dtype) {
950
938
 
951
939
  #include "numkong/reduce/serial.h"
952
940
  #include "numkong/reduce/neon.h"
953
- #include "numkong/reduce/neonhalf.h"
954
941
  #include "numkong/reduce/neonbfdot.h"
955
942
  #include "numkong/reduce/neonsdot.h"
956
943
  #include "numkong/reduce/neonfhm.h"
@@ -1324,8 +1311,8 @@ NK_PUBLIC void nk_reduce_moments_f16(nk_f16_t const *d, nk_size_t n, nk_size_t s
1324
1311
  nk_reduce_moments_f16_skylake(d, n, s, sum, sumsq);
1325
1312
  #elif NK_TARGET_HASWELL
1326
1313
  nk_reduce_moments_f16_haswell(d, n, s, sum, sumsq);
1327
- #elif NK_TARGET_NEONHALF
1328
- nk_reduce_moments_f16_neonhalf(d, n, s, sum, sumsq);
1314
+ #elif NK_TARGET_NEON
1315
+ nk_reduce_moments_f16_neon(d, n, s, sum, sumsq);
1329
1316
  #elif NK_TARGET_RVV
1330
1317
  nk_reduce_moments_f16_rvv(d, n, s, sum, sumsq);
1331
1318
  #elif NK_TARGET_V128RELAXED
@@ -1376,8 +1363,6 @@ NK_PUBLIC void nk_reduce_minmax_bf16(nk_bf16_t const *d, nk_size_t n, nk_size_t
1376
1363
  nk_reduce_minmax_bf16_skylake(d, n, s, mn, mi, mx, xi);
1377
1364
  #elif NK_TARGET_HASWELL
1378
1365
  nk_reduce_minmax_bf16_haswell(d, n, s, mn, mi, mx, xi);
1379
- #elif NK_TARGET_NEONBFDOT
1380
- nk_reduce_minmax_bf16_neonbfdot(d, n, s, mn, mi, mx, xi);
1381
1366
  #elif NK_TARGET_RVV
1382
1367
  nk_reduce_minmax_bf16_rvv(d, n, s, mn, mi, mx, xi);
1383
1368
  #elif NK_TARGET_V128RELAXED
@@ -1413,8 +1398,6 @@ NK_PUBLIC void nk_reduce_minmax_e4m3(nk_e4m3_t const *d, nk_size_t n, nk_size_t
1413
1398
  nk_reduce_minmax_e4m3_skylake(d, n, s, mn, mi, mx, xi);
1414
1399
  #elif NK_TARGET_HASWELL
1415
1400
  nk_reduce_minmax_e4m3_haswell(d, n, s, mn, mi, mx, xi);
1416
- #elif NK_TARGET_NEONFHM
1417
- nk_reduce_minmax_e4m3_neonfhm(d, n, s, mn, mi, mx, xi);
1418
1401
  #elif NK_TARGET_NEON
1419
1402
  nk_reduce_minmax_e4m3_neon(d, n, s, mn, mi, mx, xi);
1420
1403
  #elif NK_TARGET_RVV
@@ -1452,8 +1435,6 @@ NK_PUBLIC void nk_reduce_minmax_e5m2(nk_e5m2_t const *d, nk_size_t n, nk_size_t
1452
1435
  nk_reduce_minmax_e5m2_skylake(d, n, s, mn, mi, mx, xi);
1453
1436
  #elif NK_TARGET_HASWELL
1454
1437
  nk_reduce_minmax_e5m2_haswell(d, n, s, mn, mi, mx, xi);
1455
- #elif NK_TARGET_NEONFHM
1456
- nk_reduce_minmax_e5m2_neonfhm(d, n, s, mn, mi, mx, xi);
1457
1438
  #elif NK_TARGET_NEON
1458
1439
  nk_reduce_minmax_e5m2_neon(d, n, s, mn, mi, mx, xi);
1459
1440
  #elif NK_TARGET_RVV
@@ -192,13 +192,95 @@ void reduce_minmax(in_type_ const *data, std::size_t count, std::size_t stride_b
192
192
  if (max_index) *max_index = static_cast<std::size_t>(max_offset);
193
193
  }
194
194
 
195
+ /** @brief Compute sum and sum-of-squares over a vector view by forwarding `input.data()`, `input.size()`, and `input.stride_bytes()` (cast to `std::size_t`) to the pointer-based `reduce_moments` overload; results are written through `sum` and `sumsq`. */
196
+ template <numeric_dtype in_type_, numeric_dtype sum_type_ = typename in_type_::reduce_moments_sum_t,
197
+ numeric_dtype sumsq_type_ = typename in_type_::reduce_moments_sumsq_t,
198
+ allow_simd_t allow_simd_ = prefer_simd_k>
199
+ void reduce_moments(vector_view<in_type_> input, sum_type_ *sum, sumsq_type_ *sumsq) noexcept {
200
+ reduce_moments<in_type_, sum_type_, sumsq_type_, allow_simd_>(
201
+ input.data(), input.size(), static_cast<std::size_t>(input.stride_bytes()), sum, sumsq);
202
+ }
203
+
204
+ /** @brief Find minimum and maximum elements with their indices over a vector view by forwarding `input.data()`, `input.size()`, and `input.stride_bytes()` (cast to `std::size_t`) to the pointer-based `reduce_minmax` overload; the four output pointers are passed through unchanged. */
205
+ template <numeric_dtype in_type_, numeric_dtype minmax_type_ = typename in_type_::reduce_minmax_value_t,
206
+ allow_simd_t allow_simd_ = prefer_simd_k>
207
+ void reduce_minmax(vector_view<in_type_> input, minmax_type_ *min_value, std::size_t *min_index,
208
+ minmax_type_ *max_value, std::size_t *max_index) noexcept {
209
+ reduce_minmax<in_type_, minmax_type_, allow_simd_>(input.data(), input.size(),
210
+ static_cast<std::size_t>(input.stride_bytes()), min_value,
211
+ min_index, max_value, max_index);
212
+ }
213
+
195
214
  } // namespace ashvardanian::numkong
196
215
 
197
216
  #include "numkong/tensor.hpp"
198
217
 
199
218
  namespace ashvardanian::numkong {
200
219
 
201
- #pragma region - Tensor Reduction Helpers
220
+ #pragma region Tensor Reduction Helpers
221
+
222
+ /** @brief Result of detecting how many trailing dimensions form a single arithmetic progression. Filled in by `uniform_stride_tail_`. */
223
+ struct uniform_stride_tail_result_t_ {
224
+ std::size_t tail_dims; ///< Number of collapsible trailing dimensions; 0 for rank-0 views and for types that pack several dimensions per value.
225
+ std::size_t element_count; ///< Product of collapsed extents (1 for rank-0 views, 0 for packed types).
226
+ std::size_t stride_bytes; ///< Absolute stride of the innermost collapsed dimension; the sign of a negative stride is dropped.
227
+ };
228
+
229
+ /** @brief Detect trailing dimensions where stride[i] == stride[i+1] * extent[i+1].
230
+ * When this holds, the tail is a single strided sequence and can be passed to a SIMD
231
+ * kernel in one call with (element_count, stride_bytes). Returns {0, 0, 0} when a value packs more than one dimension and {0, 1, sizeof(value_type_)} for rank-0 views. */
232
+ template <typename value_type_, std::size_t max_rank_>
233
+ uniform_stride_tail_result_t_ uniform_stride_tail_(tensor_view<value_type_, max_rank_> input) noexcept {
234
+ if constexpr (dimensions_per_value<value_type_>() > 1) return {0, 0, 0};
235
+ auto rank = input.rank();
236
+ if (rank == 0) return {0, 1, sizeof(value_type_)};
237
+ std::size_t tail = 1; // the innermost dimension always counts as part of the tail
238
+ auto innermost_stride = input.stride_bytes(rank - 1);
239
+ auto expected_stride = innermost_stride;
240
+ for (std::size_t i = rank - 1; i > 0; --i) {
241
+ expected_stride *= static_cast<std::ptrdiff_t>(input.extent(i)); // stride the next-outer dimension needs to continue the progression
242
+ if (input.stride_bytes(i - 1) != expected_stride) break; // first mismatch ends the collapsible tail
243
+ ++tail;
244
+ }
245
+ std::size_t count = 1;
246
+ for (std::size_t i = rank - tail; i < rank; ++i) count *= input.extent(i);
247
+ return {tail, count, static_cast<std::size_t>(innermost_stride < 0 ? -innermost_stride : innermost_stride)}; // only the stride magnitude is reported
248
+ }
249
+
250
+ /** @brief Collapse the trailing `tail.tail_dims` dimensions into one, preserving outer dims and strides. The merged dimension takes `tail.element_count` as its extent and the input's innermost stride (sign included); the byte pointer is passed through unchanged. NOTE(review): both visible callers pass `tail.tail_dims >= 2`; with 0 the resulting rank would exceed the input rank — confirm no caller does that. */
251
+ template <typename value_type_, std::size_t max_rank_>
252
+ tensor_view<value_type_, max_rank_> collapse_uniform_tail_(tensor_view<value_type_, max_rank_> input,
253
+ uniform_stride_tail_result_t_ const &tail) noexcept {
254
+ shape_storage_<max_rank_> s;
255
+ s.rank = input.rank() - tail.tail_dims + 1; // outer dimensions plus the single merged one
256
+ for (std::size_t i = 0; i + tail.tail_dims < input.rank(); ++i) { // copy only the dimensions in front of the tail
257
+ s.extents[i] = input.extent(i);
258
+ s.strides[i] = input.stride_bytes(i);
259
+ }
260
+ s.extents[s.rank - 1] = tail.element_count;
261
+ s.strides[s.rank - 1] = input.stride_bytes(input.rank() - 1);
262
+ return {input.byte_data(), s};
263
+ }
264
+
265
+ /** @brief Normalize a fully-collapsed tail for SIMD kernel consumption, handling negative strides. For a negative innermost stride the lane starts at the last logical element (the lowest address) and is flagged `reversed`, so it can be walked with the positive `tail.stride_bytes`. Rank-0 views, for which `uniform_stride_tail_` reports one element, are treated as a single forward element instead of querying dimension `rank - 1`. */
266
+ template <typename value_type_, std::size_t max_rank_>
267
+ normalized_rank1_lane_<value_type_, max_rank_> normalize_rank1_lane_from_tail_(
268
+ tensor_view<value_type_, max_rank_> input, uniform_stride_tail_result_t_ const &tail) noexcept {
269
+ normalized_rank1_lane_<value_type_, max_rank_> lane;
270
+ lane.count = tail.element_count;
271
+ lane.stride_bytes = tail.stride_bytes;
272
+ std::ptrdiff_t innermost_stride = input.rank() == 0 ? static_cast<std::ptrdiff_t>(tail.stride_bytes) : static_cast<std::ptrdiff_t>(input.stride_bytes(input.rank() - 1)); // rank 0 has no dimension to query: `rank() - 1` would wrap around
273
+ if (innermost_stride >= 0) {
274
+ lane.data = input.data();
275
+ lane.reversed = false;
276
+ }
277
+ else {
278
+ lane.data = reinterpret_cast<value_type_ const *>(
279
+ input.byte_data() + static_cast<std::ptrdiff_t>(lane.count - 1) * innermost_stride); // step back to the lowest-addressed element
280
+ lane.reversed = true;
281
+ }
282
+ return lane;
283
+ }
202
284
 
203
285
  template <numeric_dtype value_type_, std::size_t max_rank_>
204
286
  bool reduce_rank1_moments_(tensor_view<value_type_, max_rank_> input, typename value_type_::reduce_moments_sum_t &sum,
@@ -391,9 +473,9 @@ bool reduce_minmax_axis_packed_(tensor_view<value_type_, max_rank_> input, std::
391
473
  return true;
392
474
  }
393
475
 
394
- #pragma endregion - Tensor Reduction Helpers
476
+ #pragma endregion Tensor Reduction Helpers
395
477
 
396
- #pragma region - Scalar Reductions
478
+ #pragma region Scalar Reductions
397
479
 
398
480
  /** @brief Compute Σxᵢ and Σxᵢ² in a single pass. Returns zeroed result for empty tensors. */
399
481
  template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
@@ -403,11 +485,14 @@ moments_result<typename value_type_::reduce_moments_sum_t, typename value_type_:
403
485
  using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
404
486
  moments_result<sum_t, sumsq_t> result {};
405
487
  if (input.empty() || input.numel() == 0 || !tensor_layout_supported_(input)) return result;
406
- if (input.is_contiguous()) {
407
- numkong::reduce_moments<value_type_>(input.data(), input.numel(), sizeof(value_type_), &result.sum,
408
- &result.sumsq);
488
+ auto tail = uniform_stride_tail_(input);
489
+ if (tail.tail_dims == input.rank()) {
490
+ auto lane = normalize_rank1_lane_from_tail_<value_type_, max_rank_>(input, tail);
491
+ numkong::reduce_moments<value_type_>(lane.data, lane.count, lane.stride_bytes, &result.sum, &result.sumsq);
409
492
  return result;
410
493
  }
494
+ if (tail.tail_dims >= 2) return moments<value_type_, max_rank_>(collapse_uniform_tail_(input, tail));
495
+ // Sub-byte rank-1 fallback: uniform_stride_tail_ returns {0,0,0} for packed types.
411
496
  if (input.rank() == 1) {
412
497
  reduce_rank1_moments_(input, result.sum, result.sumsq);
413
498
  return result;
@@ -426,11 +511,19 @@ minmax_result<typename value_type_::reduce_minmax_value_t> minmax(tensor_view<va
426
511
  using minmax_t = typename value_type_::reduce_minmax_value_t;
427
512
  minmax_result<minmax_t> result {};
428
513
  if (input.empty() || input.numel() == 0 || !tensor_layout_supported_(input)) return result;
429
- if (input.is_contiguous()) {
430
- numkong::reduce_minmax<value_type_>(input.data(), input.numel(), sizeof(value_type_), &result.min_value,
514
+ auto tail = uniform_stride_tail_(input);
515
+ if (tail.tail_dims == input.rank()) {
516
+ auto lane = normalize_rank1_lane_from_tail_<value_type_, max_rank_>(input, tail);
517
+ numkong::reduce_minmax<value_type_>(lane.data, lane.count, lane.stride_bytes, &result.min_value,
431
518
  &result.min_index, &result.max_value, &result.max_index);
519
+ if (lane.reversed) {
520
+ result.min_index = tail.element_count - 1 - result.min_index;
521
+ result.max_index = tail.element_count - 1 - result.max_index;
522
+ }
432
523
  return result;
433
524
  }
525
+ if (tail.tail_dims >= 2) return minmax<value_type_, max_rank_>(collapse_uniform_tail_(input, tail));
526
+ // Sub-byte rank-1 fallback.
434
527
  if (input.rank() == 1) {
435
528
  reduce_rank1_minmax_(input, result);
436
529
  return result;
@@ -484,9 +577,61 @@ std::size_t argmax(tensor_view<value_type_, max_rank_> input) noexcept {
484
577
  return minmax(input).max_index;
485
578
  }
486
579
 
487
- #pragma endregion - Scalar Reductions
580
+ /** @brief Compute Σxᵢ and Σxᵢ² over a vector view. Returns a value-initialized result when the view has no elements. */
581
+ template <numeric_dtype value_type_>
582
+ moments_result<typename value_type_::reduce_moments_sum_t, typename value_type_::reduce_moments_sumsq_t> moments(
583
+ vector_view<value_type_> input) noexcept {
584
+ using sum_t = typename value_type_::reduce_moments_sum_t;
585
+ using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
586
+ moments_result<sum_t, sumsq_t> result {};
587
+ if (input.size() == 0) return result; // empty view: skip the kernel call entirely
588
+ reduce_moments<value_type_>(input, &result.sum, &result.sumsq);
589
+ return result;
590
+ }
591
+
592
+ /** @brief Find min and max values with their indices over a vector view. Returns a value-initialized result when the view has no elements. */
593
+ template <numeric_dtype value_type_>
594
+ minmax_result<typename value_type_::reduce_minmax_value_t> minmax(vector_view<value_type_> input) noexcept {
595
+ using minmax_t = typename value_type_::reduce_minmax_value_t;
596
+ minmax_result<minmax_t> result {};
597
+ if (input.size() == 0) return result; // empty view: skip the kernel call entirely
598
+ reduce_minmax<value_type_>(input, &result.min_value, &result.min_index, &result.max_value, &result.max_index);
599
+ return result;
600
+ }
601
+
602
+ /** @brief Σ of all elements in a vector view: the `sum` field of `moments(input)`, which is value-initialized for an empty view. */
603
+ template <numeric_dtype value_type_>
604
+ typename value_type_::reduce_moments_sum_t sum(vector_view<value_type_> input) noexcept {
605
+ return moments(input).sum;
606
+ }
607
+
608
+ /** @brief Find the minimum element value in a vector view: the `min_value` field of `minmax(input)`, which is value-initialized for an empty view. */
609
+ template <numeric_dtype value_type_>
610
+ typename value_type_::reduce_minmax_value_t min(vector_view<value_type_> input) noexcept {
611
+ return minmax(input).min_value;
612
+ }
613
+
614
+ /** @brief Find the maximum element value in a vector view: the `max_value` field of `minmax(input)`, which is value-initialized for an empty view. */
615
+ template <numeric_dtype value_type_>
616
+ typename value_type_::reduce_minmax_value_t max(vector_view<value_type_> input) noexcept {
617
+ return minmax(input).max_value;
618
+ }
619
+
620
+ /** @brief Index of the minimum element in a vector view: the `min_index` field of `minmax(input)`, which is value-initialized for an empty view. */
621
+ template <numeric_dtype value_type_>
622
+ std::size_t argmin(vector_view<value_type_> input) noexcept {
623
+ return minmax(input).min_index;
624
+ }
625
+
626
+ /** @brief Index of the maximum element in a vector view: the `max_index` field of `minmax(input)`, which is value-initialized for an empty view. */
627
+ template <numeric_dtype value_type_>
628
+ std::size_t argmax(vector_view<value_type_> input) noexcept {
629
+ return minmax(input).max_index;
630
+ }
631
+
632
+ #pragma endregion Scalar Reductions
488
633
 
489
- #pragma region - Axis Reductions
634
+ #pragma region Axis Reductions
490
635
 
491
636
  /** @brief Σ along a single axis. Returns empty tensor on failure. */
492
637
  template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
@@ -626,7 +771,7 @@ tensor<typename value_type_::reduce_minmax_value_t, allocator_type_, max_rank_>
626
771
  return try_minmax<value_type_, max_rank_, allocator_type_>(input, axis, keep_dims).max_value;
627
772
  }
628
773
 
629
- #pragma endregion - Axis Reductions
774
+ #pragma endregion Axis Reductions
630
775
 
631
776
  } // namespace ashvardanian::numkong
632
777
 
@@ -6,21 +6,21 @@ Ordering functions (`nk_f16_order`, `nk_bf16_order`, `nk_e4m3_order`) convert fl
6
6
 
7
7
  Reciprocal square root:
8
8
 
9
- ```math
9
+ $$
10
10
  \text{rsqrt}(x) = \frac{1}{\sqrt{x}}
11
- ```
11
+ $$
12
12
 
13
13
  Fused multiply-add:
14
14
 
15
- ```math
15
+ $$
16
16
  \text{fma}(a, b, c) = a \cdot b + c
17
- ```
17
+ $$
18
18
 
19
19
  Saturating addition:
20
20
 
21
- ```math
21
+ $$
22
22
  \text{sat\_add}(a, b) = \text{clamp}(a + b, \text{T\_MIN}, \text{T\_MAX})
23
- ```
23
+ $$
24
24
 
25
25
  Reformulating as Python pseudocode:
26
26
 
@@ -8,13 +8,13 @@
8
8
  *
9
9
  * @section scalars_haswell_instructions Key AVX2/FMA Scalar Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput Ports
12
- * _mm_sqrt_ps VSQRTPS (XMM, XMM) 11cy 7cy p0
13
- * _mm_sqrt_pd VSQRTPD (XMM, XMM) 16cy 12cy p0
14
- * _mm_fmadd_ss VFMADD (XMM, XMM, XMM) 5cy 0.5/cy p01
15
- * _mm_fmadd_sd VFMADD (XMM, XMM, XMM) 5cy 0.5/cy p01
16
- * _mm_cvtps_ph VCVTPS2PH (XMM, XMM, I8) 4cy 1/cy p01+p5
17
- * _mm_cvtph_ps VCVTPH2PS (XMM, XMM) 5cy 1/cy p01
11
+ * Intrinsic Instruction Haswell Genoa
12
+ * _mm_sqrt_ps VSQRTPS (XMM, XMM) 11cy @ p0 15cy @ p01
13
+ * _mm_sqrt_pd VSQRTPD (XMM, XMM) 16cy @ p0 15cy @ p01
14
+ * _mm_fmadd_ss VFMADD (XMM, XMM, XMM) 5cy @ p01 4cy @ p01
15
+ * _mm_fmadd_sd VFMADD (XMM, XMM, XMM) 5cy @ p01 4cy @ p01
16
+ * _mm_cvtps_ph VCVTPS2PH (XMM, XMM, I8) 5cy @ p01 4cy @ p12+p23
17
+ * _mm_cvtph_ps VCVTPH2PS (XMM, XMM) 5cy @ p01 4cy @ p12+p23
18
18
  */
19
19
  #ifndef NK_SCALAR_HASWELL_H
20
20
  #define NK_SCALAR_HASWELL_H
@@ -52,23 +52,32 @@ NK_PUBLIC nk_f64_t nk_f64_fma_haswell(nk_f64_t a, nk_f64_t b, nk_f64_t c) {
52
52
  return _mm_cvtsd_f64(_mm_fmadd_sd(_mm_set_sd(a), _mm_set_sd(b), _mm_set_sd(c)));
53
53
  }
54
54
  NK_PUBLIC nk_f16_t nk_f16_sqrt_haswell(nk_f16_t x) {
55
- __m128 x_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(x));
56
- return (nk_f16_t)_mm_cvtsi128_si32(_mm_cvtps_ph(_mm_sqrt_ps(x_f32x4), _MM_FROUND_TO_NEAREST_INT));
55
+ nk_fui16_t x_fui, out_fui;
56
+ x_fui.f = x;
57
+ __m128 x_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(x_fui.u));
58
+ out_fui.u = (nk_u16_t)_mm_cvtsi128_si32(_mm_cvtps_ph(_mm_sqrt_ps(x_f32x4), _MM_FROUND_TO_NEAREST_INT));
59
+ return out_fui.f;
57
60
  }
58
61
  NK_PUBLIC nk_f16_t nk_f16_rsqrt_haswell(nk_f16_t x) {
59
- __m128 x_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(x));
62
+ nk_fui16_t x_fui, out_fui;
63
+ x_fui.f = x;
64
+ __m128 x_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(x_fui.u));
60
65
  __m128 estimate_f32x4 = _mm_rsqrt_ss(x_f32x4);
61
66
  __m128 refinement_f32x4 = _mm_mul_ss(_mm_mul_ss(x_f32x4, estimate_f32x4), estimate_f32x4);
62
67
  refinement_f32x4 = _mm_sub_ss(_mm_set_ss(3.0f), refinement_f32x4);
63
68
  estimate_f32x4 = _mm_mul_ss(_mm_mul_ss(_mm_set_ss(0.5f), estimate_f32x4), refinement_f32x4);
64
- return (nk_f16_t)_mm_cvtsi128_si32(_mm_cvtps_ph(estimate_f32x4, _MM_FROUND_TO_NEAREST_INT));
69
+ out_fui.u = (nk_u16_t)_mm_cvtsi128_si32(_mm_cvtps_ph(estimate_f32x4, _MM_FROUND_TO_NEAREST_INT));
70
+ return out_fui.f;
65
71
  }
66
72
  NK_PUBLIC nk_f16_t nk_f16_fma_haswell(nk_f16_t a, nk_f16_t b, nk_f16_t c) {
67
- __m128 a_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(a));
68
- __m128 b_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(b));
69
- __m128 c_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(c));
70
- return (nk_f16_t)_mm_cvtsi128_si32(
73
+ nk_fui16_t a_fui, b_fui, c_fui, out_fui;
74
+ a_fui.f = a, b_fui.f = b, c_fui.f = c;
75
+ __m128 a_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(a_fui.u));
76
+ __m128 b_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(b_fui.u));
77
+ __m128 c_f32x4 = _mm_cvtph_ps(_mm_cvtsi32_si128(c_fui.u));
78
+ out_fui.u = (nk_u16_t)_mm_cvtsi128_si32(
71
79
  _mm_cvtps_ph(_mm_fmadd_ss(a_f32x4, b_f32x4, c_f32x4), _MM_FROUND_TO_NEAREST_INT));
80
+ return out_fui.f;
72
81
  }
73
82
  NK_PUBLIC nk_u8_t nk_u8_saturating_add_haswell(nk_u8_t a, nk_u8_t b) {
74
83
  return (nk_u8_t)_mm_cvtsi128_si32(_mm_adds_epu8(_mm_cvtsi32_si128(a), _mm_cvtsi32_si128(b)));
@@ -89,8 +98,8 @@ NK_PUBLIC nk_u64_t nk_u64_saturating_mul_haswell(nk_u64_t a, nk_u64_t b) {
89
98
  }
90
99
  NK_PUBLIC nk_i64_t nk_i64_saturating_mul_haswell(nk_i64_t a, nk_i64_t b) {
91
100
  int sign = (a < 0) ^ (b < 0);
92
- nk_u64_t abs_a = a < 0 ? -(nk_u64_t)a : (nk_u64_t)a;
93
- nk_u64_t abs_b = b < 0 ? -(nk_u64_t)b : (nk_u64_t)b;
101
+ nk_u64_t abs_a = a < 0 ? (0u - (nk_u64_t)a) : (nk_u64_t)a;
102
+ nk_u64_t abs_b = b < 0 ? (0u - (nk_u64_t)b) : (nk_u64_t)b;
94
103
  unsigned long long high;
95
104
  unsigned long long low = _mulx_u64(abs_a, abs_b, &high);
96
105
  if (high || (sign && low > 9223372036854775808ull) || (!sign && low > 9223372036854775807ull))
@@ -0,0 +1,74 @@
1
+ /**
2
+ * @brief SIMD-accelerated Scalar Math Helpers for LoongArch LASX.
3
+ * @file include/numkong/scalar/loongsonasx.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/scalar.h
8
+ *
9
+ * LASX provides `xvfrsqrt` (full-precision reciprocal sqrt) and `xvfsqrt`
10
+ * (full-precision sqrt). No Newton-Raphson refinement needed.
11
+ * Full-precision sqrt uses the hardware `xvfsqrt` instruction.
12
+ * Broadcast via `xvreplgr2vr`, extract via `xvpickve2gr` — no memory round-trips.
13
+ */
14
+ #ifndef NK_SCALAR_LOONGSONASX_H
15
+ #define NK_SCALAR_LOONGSONASX_H
16
+
17
+ #if NK_TARGET_LOONGARCH_
18
+ #if NK_TARGET_LOONGSONASX
19
+
20
+ #include "numkong/types.h"
21
+
22
+ #if defined(__cplusplus)
23
+ extern "C" {
24
+ #endif
25
+
26
+ /** @brief Broadcast f32 scalar into all 4 lanes of a 128-bit register (GCC/Clang portable). The float's bit pattern is read back as an integer and replicated with the integer broadcast intrinsic, then the vector is cast to `__m128`. */
27
+ NK_INTERNAL __m128 nk_xvreplgr2vr_s_128_(float x) {
28
+ nk_fui32_t c;
29
+ c.f = x; // write as float so the raw bits can be read through `c.u`
30
+ return (__m128)__lsx_vreplgr2vr_w((int)c.u); // replicate the 32-bit pattern, then reinterpret the lanes as floats
31
+ }
32
+
33
+ /** @brief Broadcast f32 scalar into all 8 lanes of a 256-bit register (GCC/Clang portable). The float's bit pattern is read back as an integer and replicated with the integer broadcast intrinsic, then the vector is cast to `__m256`. */
34
+ NK_INTERNAL __m256 nk_xvfreplgr2vr_s_(float x) {
35
+ nk_fui32_t c;
36
+ c.f = x; // write as float so the raw bits can be read through `c.u`
37
+ return (__m256)__lasx_xvreplgr2vr_w((int)c.u); // replicate the 32-bit pattern, then reinterpret the lanes as floats
38
+ }
39
+
40
+ /** @brief Broadcast f64 scalar into all 4 lanes of a 256-bit register (GCC/Clang portable). The double's bit pattern is read back as an integer and replicated with the integer broadcast intrinsic, then the vector is cast to `__m256d`. */
41
+ NK_INTERNAL __m256d nk_xvfreplgr2vr_d_(double x) {
42
+ nk_fui64_t c;
43
+ c.f = x; // write as double so the raw bits can be read through `c.u`
44
+ return (__m256d)__lasx_xvreplgr2vr_d((long long)c.u); // replicate the 64-bit pattern, then reinterpret the lanes as doubles
45
+ }
46
+
47
+ NK_PUBLIC nk_f32_t nk_f32_rsqrt_loongsonasx(nk_f32_t x) { // reciprocal square root of one f32: broadcast, apply the vector op, read back lane 0
48
+ // xvfrsqrt.s is full precision — no Newton-Raphson needed
49
+ __m256 x_f32x8 = nk_xvfreplgr2vr_s_(x);
50
+ __m256 result_f32x8 = __lasx_xvfrsqrt_s(x_f32x8);
51
+ nk_fui32_t c;
52
+ c.u = (nk_u32_t)__lasx_xvpickve2gr_w((__m256i)result_f32x8, 0); // extract lane 0 as raw bits
53
+ return c.f; // reinterpret those bits as f32
54
+ }
55
+
56
+ /**
+ * @brief Square root of one f32 scalar using the hardware `xvfsqrt.s` instruction.
+ *
+ * The value is broadcast, rooted with `__lasx_xvfsqrt_s`, and lane 0 is read back, mirroring
+ * `nk_f64_sqrt_loongsonasx`. The earlier `x * rsqrt(x)` form produced NaN for +inf (inf * 0)
+ * and rounded twice; a single hardware sqrt avoids both.
+ *
+ * @param x Input value.
+ * @return sqrt(x) for x > 0; 0 for every input that is not greater than zero (including NaN).
+ */
+ NK_PUBLIC nk_f32_t nk_f32_sqrt_loongsonasx(nk_f32_t x) {
+ // NOTE(review): zero, negative, and NaN inputs keep returning 0 as before — confirm whether NaN is preferred.
+ if (!(x > 0)) return 0;
+ __m256 x_f32x8 = nk_xvfreplgr2vr_s_(x);
+ __m256 result_f32x8 = __lasx_xvfsqrt_s(x_f32x8);
+ nk_fui32_t c;
+ c.u = (nk_u32_t)__lasx_xvpickve2gr_w((__m256i)result_f32x8, 0); // extract lane 0 as raw bits
+ return c.f;
+ }
57
+
58
+ NK_PUBLIC nk_f64_t nk_f64_sqrt_loongsonasx(nk_f64_t x) { // square root of one f64: broadcast, apply the vector op, read back lane 0
59
+ __m256d x_f64x4 = nk_xvfreplgr2vr_d_(x);
60
+ __m256d result_f64x4 = __lasx_xvfsqrt_d(x_f64x4);
61
+ nk_fui64_t c;
62
+ c.u = (nk_u64_t)__lasx_xvpickve2gr_du((__m256i)result_f64x4, 0); // extract lane 0 as raw bits
63
+ return c.f; // reinterpret those bits as f64
64
+ }
65
+
66
+ NK_PUBLIC nk_f64_t nk_f64_rsqrt_loongsonasx(nk_f64_t x) { return 1.0 / nk_f64_sqrt_loongsonasx(x); } // reciprocal of the f64 square root; one sqrt followed by one scalar division
67
+
68
+ #if defined(__cplusplus)
69
+ } // extern "C"
70
+ #endif
71
+
72
+ #endif // NK_TARGET_LOONGSONASX
73
+ #endif // NK_TARGET_LOONGARCH_
74
+ #endif // NK_SCALAR_LOONGSONASX_H