numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315):
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,16 +8,16 @@
8
8
  *
9
9
  * @section dot_svehalf_instructions ARM SVE+FP16 Instructions
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy 2/cy
13
- * svld2_f16 LD2H (Z.H, P/Z, [Xn]) 6-8cy 1/cy
14
- * svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy 2/cy
15
- * svmls_f16_x FMLS (Z.H, P/M, Z.H, Z.H) 4cy 2/cy
16
- * svaddv_f16 FADDV (H, P, Z.H) 6cy 1/cy
17
- * svdup_f16 DUP (Z.H, #imm) 1cy 2/cy
18
- * svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy 1/cy
19
- * svptrue_b16 PTRUE (P.H, pattern) 1cy 2/cy
20
- * svcnth CNTH (Xd) 1cy 2/cy
11
+ * Intrinsic Instruction V1
12
+ * svld1_f16 LD1H (Z.H, P/Z, [Xn]) 4-6cy @ 2p
13
+ * svld2_f16 LD2H (Z.H, P/Z, [Xn]) 6-8cy @ 1p
14
+ * svmla_f16_x FMLA (Z.H, P/M, Z.H, Z.H) 4cy @ 2p
15
+ * svmls_f16_x FMLS (Z.H, P/M, Z.H, Z.H) 4cy @ 2p
16
+ * svaddv_f16 FADDV (H, P, Z.H) 6cy @ 1p
17
+ * svdup_f16 DUP (Z.H, #imm) 1cy @ 2p
18
+ * svwhilelt_b16 WHILELT (P.H, Xn, Xm) 2cy @ 1p
19
+ * svptrue_b16 PTRUE (P.H, pattern) 1cy @ 2p
20
+ * svcnth CNTH (Xd) 1cy @ 2p
21
21
  *
22
22
  * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
23
23
  * and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
@@ -51,13 +51,21 @@ NK_PUBLIC void nk_dot_f16_svehalf(nk_f16_t const *a_scalars, nk_f16_t const *b_s
51
51
  nk_size_t idx_scalars = 0;
52
52
  svfloat32_t ab_f32x = svdup_f32(0);
53
53
  do {
54
- svbool_t predicate_f32x = svwhilelt_b32_u64(idx_scalars, count_scalars);
55
- svfloat16_t a_f16x = svld1_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(a_scalars) + idx_scalars);
56
- svfloat16_t b_f16x = svld1_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(b_scalars) + idx_scalars);
57
- svfloat32_t a_f32x = svcvt_f32_f16_x(predicate_f32x, a_f16x);
58
- svfloat32_t b_f32x = svcvt_f32_f16_x(predicate_f32x, b_f16x);
59
- ab_f32x = svmla_f32_x(predicate_f32x, ab_f32x, a_f32x, b_f32x);
60
- idx_scalars += svcntw();
54
+ svbool_t predicate_b16x = svwhilelt_b16_u64(idx_scalars, count_scalars);
55
+ svfloat16_t a_f16x = svld1_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(a_scalars) + idx_scalars);
56
+ svfloat16_t b_f16x = svld1_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(b_scalars) + idx_scalars);
57
+ nk_size_t remaining = count_scalars - idx_scalars < svcnth() ? count_scalars - idx_scalars : svcnth();
58
+
59
+ // svcvt_f32_f16_x widens only even-indexed f16 elements; svext by 1 shifts odd into even.
60
+ svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
61
+ ab_f32x = svmla_f32_m(pred_even_b32x, ab_f32x, svcvt_f32_f16_x(pred_even_b32x, a_f16x),
62
+ svcvt_f32_f16_x(pred_even_b32x, b_f16x));
63
+
64
+ svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
65
+ ab_f32x = svmla_f32_m(pred_odd_b32x, ab_f32x, svcvt_f32_f16_x(pred_odd_b32x, svext_f16(a_f16x, a_f16x, 1)),
66
+ svcvt_f32_f16_x(pred_odd_b32x, svext_f16(b_f16x, b_f16x, 1)));
67
+
68
+ idx_scalars += svcnth();
61
69
  } while (idx_scalars < count_scalars);
62
70
  *result = svaddv_f32(svptrue_b32(), ab_f32x);
63
71
  }
@@ -68,18 +76,36 @@ NK_PUBLIC void nk_dot_f16c_svehalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b_
68
76
  svfloat32_t ab_real_f32x = svdup_f32(0);
69
77
  svfloat32_t ab_imag_f32x = svdup_f32(0);
70
78
  do {
71
- svbool_t predicate_f32x = svwhilelt_b32_u64(idx_scalars, count_pairs);
72
- svfloat16x2_t a_f16x2 = svld2_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(a_pairs) + idx_scalars * 2);
73
- svfloat16x2_t b_f16x2 = svld2_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(b_pairs) + idx_scalars * 2);
74
- svfloat32_t a_real_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(a_f16x2, 0));
75
- svfloat32_t a_imag_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(a_f16x2, 1));
76
- svfloat32_t b_real_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(b_f16x2, 0));
77
- svfloat32_t b_imag_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(b_f16x2, 1));
78
- ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_real_f32x, b_real_f32x);
79
- ab_real_f32x = svmls_f32_x(predicate_f32x, ab_real_f32x, a_imag_f32x, b_imag_f32x);
80
- ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_real_f32x, b_imag_f32x);
81
- ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_imag_f32x, b_real_f32x);
82
- idx_scalars += svcntw();
79
+ svbool_t predicate_b16x = svwhilelt_b16_u64(idx_scalars, count_pairs);
80
+ svfloat16x2_t a_f16x2x = svld2_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(a_pairs) + idx_scalars * 2);
81
+ svfloat16x2_t b_f16x2x = svld2_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(b_pairs) + idx_scalars * 2);
82
+ svfloat16_t ar_f16x = svget2_f16(a_f16x2x, 0), ai_f16x = svget2_f16(a_f16x2x, 1);
83
+ svfloat16_t br_f16x = svget2_f16(b_f16x2x, 0), bi_f16x = svget2_f16(b_f16x2x, 1);
84
+ nk_size_t remaining = count_pairs - idx_scalars < svcnth() ? count_pairs - idx_scalars : svcnth();
85
+
86
+ // Even-indexed elements of each deinterleaved component
87
+ svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
88
+ svfloat32_t ar_even_f32x = svcvt_f32_f16_x(pred_even_b32x, ar_f16x);
89
+ svfloat32_t ai_even_f32x = svcvt_f32_f16_x(pred_even_b32x, ai_f16x);
90
+ svfloat32_t br_even_f32x = svcvt_f32_f16_x(pred_even_b32x, br_f16x);
91
+ svfloat32_t bi_even_f32x = svcvt_f32_f16_x(pred_even_b32x, bi_f16x);
92
+ ab_real_f32x = svmla_f32_m(pred_even_b32x, ab_real_f32x, ar_even_f32x, br_even_f32x);
93
+ ab_real_f32x = svmls_f32_m(pred_even_b32x, ab_real_f32x, ai_even_f32x, bi_even_f32x);
94
+ ab_imag_f32x = svmla_f32_m(pred_even_b32x, ab_imag_f32x, ar_even_f32x, bi_even_f32x);
95
+ ab_imag_f32x = svmla_f32_m(pred_even_b32x, ab_imag_f32x, ai_even_f32x, br_even_f32x);
96
+
97
+ // Odd-indexed elements via svext shift-by-1
98
+ svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
99
+ svfloat32_t ar_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(ar_f16x, ar_f16x, 1));
100
+ svfloat32_t ai_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(ai_f16x, ai_f16x, 1));
101
+ svfloat32_t br_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(br_f16x, br_f16x, 1));
102
+ svfloat32_t bi_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(bi_f16x, bi_f16x, 1));
103
+ ab_real_f32x = svmla_f32_m(pred_odd_b32x, ab_real_f32x, ar_odd_f32x, br_odd_f32x);
104
+ ab_real_f32x = svmls_f32_m(pred_odd_b32x, ab_real_f32x, ai_odd_f32x, bi_odd_f32x);
105
+ ab_imag_f32x = svmla_f32_m(pred_odd_b32x, ab_imag_f32x, ar_odd_f32x, bi_odd_f32x);
106
+ ab_imag_f32x = svmla_f32_m(pred_odd_b32x, ab_imag_f32x, ai_odd_f32x, br_odd_f32x);
107
+
108
+ idx_scalars += svcnth();
83
109
  } while (idx_scalars < count_pairs);
84
110
  results->real = svaddv_f32(svptrue_b32(), ab_real_f32x);
85
111
  results->imag = svaddv_f32(svptrue_b32(), ab_imag_f32x);
@@ -91,18 +117,36 @@ NK_PUBLIC void nk_vdot_f16c_svehalf(nk_f16c_t const *a_pairs, nk_f16c_t const *b
91
117
  svfloat32_t ab_real_f32x = svdup_f32(0);
92
118
  svfloat32_t ab_imag_f32x = svdup_f32(0);
93
119
  do {
94
- svbool_t predicate_f32x = svwhilelt_b32_u64(idx_scalars, count_pairs);
95
- svfloat16x2_t a_f16x2 = svld2_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(a_pairs) + idx_scalars * 2);
96
- svfloat16x2_t b_f16x2 = svld2_f16(predicate_f32x, (nk_f16_for_arm_simd_t const *)(b_pairs) + idx_scalars * 2);
97
- svfloat32_t a_real_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(a_f16x2, 0));
98
- svfloat32_t a_imag_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(a_f16x2, 1));
99
- svfloat32_t b_real_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(b_f16x2, 0));
100
- svfloat32_t b_imag_f32x = svcvt_f32_f16_x(predicate_f32x, svget2_f16(b_f16x2, 1));
101
- ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_real_f32x, b_real_f32x);
102
- ab_real_f32x = svmla_f32_x(predicate_f32x, ab_real_f32x, a_imag_f32x, b_imag_f32x);
103
- ab_imag_f32x = svmla_f32_x(predicate_f32x, ab_imag_f32x, a_real_f32x, b_imag_f32x);
104
- ab_imag_f32x = svmls_f32_x(predicate_f32x, ab_imag_f32x, a_imag_f32x, b_real_f32x);
105
- idx_scalars += svcntw();
120
+ svbool_t predicate_b16x = svwhilelt_b16_u64(idx_scalars, count_pairs);
121
+ svfloat16x2_t a_f16x2x = svld2_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(a_pairs) + idx_scalars * 2);
122
+ svfloat16x2_t b_f16x2x = svld2_f16(predicate_b16x, (nk_f16_for_arm_simd_t const *)(b_pairs) + idx_scalars * 2);
123
+ svfloat16_t ar_f16x = svget2_f16(a_f16x2x, 0), ai_f16x = svget2_f16(a_f16x2x, 1);
124
+ svfloat16_t br_f16x = svget2_f16(b_f16x2x, 0), bi_f16x = svget2_f16(b_f16x2x, 1);
125
+ nk_size_t remaining = count_pairs - idx_scalars < svcnth() ? count_pairs - idx_scalars : svcnth();
126
+
127
+ // Even-indexed elements
128
+ svbool_t pred_even_b32x = svwhilelt_b32_u64(0u, (remaining + 1) / 2);
129
+ svfloat32_t ar_even_f32x = svcvt_f32_f16_x(pred_even_b32x, ar_f16x);
130
+ svfloat32_t ai_even_f32x = svcvt_f32_f16_x(pred_even_b32x, ai_f16x);
131
+ svfloat32_t br_even_f32x = svcvt_f32_f16_x(pred_even_b32x, br_f16x);
132
+ svfloat32_t bi_even_f32x = svcvt_f32_f16_x(pred_even_b32x, bi_f16x);
133
+ ab_real_f32x = svmla_f32_m(pred_even_b32x, ab_real_f32x, ar_even_f32x, br_even_f32x);
134
+ ab_real_f32x = svmla_f32_m(pred_even_b32x, ab_real_f32x, ai_even_f32x, bi_even_f32x);
135
+ ab_imag_f32x = svmla_f32_m(pred_even_b32x, ab_imag_f32x, ar_even_f32x, bi_even_f32x);
136
+ ab_imag_f32x = svmls_f32_m(pred_even_b32x, ab_imag_f32x, ai_even_f32x, br_even_f32x);
137
+
138
+ // Odd-indexed elements via svext shift-by-1
139
+ svbool_t pred_odd_b32x = svwhilelt_b32_u64(0u, remaining / 2);
140
+ svfloat32_t ar_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(ar_f16x, ar_f16x, 1));
141
+ svfloat32_t ai_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(ai_f16x, ai_f16x, 1));
142
+ svfloat32_t br_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(br_f16x, br_f16x, 1));
143
+ svfloat32_t bi_odd_f32x = svcvt_f32_f16_x(pred_odd_b32x, svext_f16(bi_f16x, bi_f16x, 1));
144
+ ab_real_f32x = svmla_f32_m(pred_odd_b32x, ab_real_f32x, ar_odd_f32x, br_odd_f32x);
145
+ ab_real_f32x = svmla_f32_m(pred_odd_b32x, ab_real_f32x, ai_odd_f32x, bi_odd_f32x);
146
+ ab_imag_f32x = svmla_f32_m(pred_odd_b32x, ab_imag_f32x, ar_odd_f32x, bi_odd_f32x);
147
+ ab_imag_f32x = svmls_f32_m(pred_odd_b32x, ab_imag_f32x, ai_odd_f32x, br_odd_f32x);
148
+
149
+ idx_scalars += svcnth();
106
150
  } while (idx_scalars < count_pairs);
107
151
  results->real = svaddv_f32(svptrue_b32(), ab_real_f32x);
108
152
  results->imag = svaddv_f32(svptrue_b32(), ab_imag_f32x);
@@ -0,0 +1,89 @@
1
+ /**
2
+ * @brief SIMD-accelerated Dot Products for SVE SDOT.
3
+ * @file include/numkong/dot/svesdot.h
4
+ * @author Ash Vardanian
5
+ * @date April 3, 2026
6
+ *
7
+ * @sa include/numkong/dot.h
8
+ *
9
+ * @section dot_svesdot_instructions ARM SVE+DotProd Instructions
10
+ *
11
+ * Intrinsic Instruction V1
12
+ * svld1_s8 LD1B (Z.B, P/Z, [Xn]) 4-6cy @ 2p
13
+ * svld1_u8 LD1B (Z.B, P/Z, [Xn]) 4-6cy @ 2p
14
+ * svdot_s32 SDOT (Z.S, Z.B, Z.B) 3cy @ 2p
15
+ * svdot_u32 UDOT (Z.S, Z.B, Z.B) 3cy @ 2p
16
+ * svaddv_s32 SADDV (D, P, Z.S) 6cy @ 1p
17
+ * svaddv_u32 UADDV (D, P, Z.S) 6cy @ 1p
18
+ * svdup_s32 DUP (Z.S, #imm) 1cy @ 2p
19
+ * svwhilelt_b8 WHILELT (P.B, Xn, Xm) 2cy @ 1p
20
+ * svcntb CNTB (Xd) 1cy @ 2p
21
+ *
22
+ * SVE vector widths vary across implementations: Graviton3 uses 256-bit, while Graviton4/5
23
+ * and Apple M4+ use 128-bit. Code using svcntb() adapts automatically, but wider vectors
24
+ * process more elements per iteration with identical latencies.
25
+ *
26
+ * The SDOT/UDOT instructions fuse four int8 multiplications with int32 accumulation per lane,
27
+ * providing the same 4-way dot product as NEON SDOT but with scalable vector widths.
28
+ * On 256-bit SVE, this processes 32 int8 elements per instruction vs NEON's fixed 16.
29
+ */
30
+ #ifndef NK_DOT_SVESDOT_H
31
+ #define NK_DOT_SVESDOT_H
32
+
33
+ #if NK_TARGET_ARM_
34
+ #if NK_TARGET_SVESDOT
35
+
36
+ #include "numkong/types.h"
37
+
38
+ #if defined(__cplusplus)
39
+ extern "C" {
40
+ #endif
41
+
42
+ #if defined(__clang__)
43
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+sve+dotprod"))), apply_to = function)
44
+ #elif defined(__GNUC__)
45
+ #pragma GCC push_options
46
+ #pragma GCC target("arch=armv8.2-a+sve+dotprod")
47
+ #endif
48
+
49
+ NK_PUBLIC void nk_dot_i8_svesdot(nk_i8_t const *a_scalars, nk_i8_t const *b_scalars, nk_size_t count_scalars,
50
+ nk_i32_t *result) {
51
+ nk_size_t idx_scalars = 0;
52
+ svint32_t sum_i32x = svdup_s32(0);
53
+ do {
54
+ svbool_t predicate_b8x = svwhilelt_b8_u64(idx_scalars, count_scalars);
55
+ svint8_t a_i8x = svld1_s8(predicate_b8x, a_scalars + idx_scalars);
56
+ svint8_t b_i8x = svld1_s8(predicate_b8x, b_scalars + idx_scalars);
57
+ sum_i32x = svdot_s32(sum_i32x, a_i8x, b_i8x);
58
+ idx_scalars += svcntb();
59
+ } while (idx_scalars < count_scalars);
60
+ *result = (nk_i32_t)svaddv_s32(svptrue_b32(), sum_i32x);
61
+ }
62
+
63
+ NK_PUBLIC void nk_dot_u8_svesdot(nk_u8_t const *a_scalars, nk_u8_t const *b_scalars, nk_size_t count_scalars,
64
+ nk_u32_t *result) {
65
+ nk_size_t idx_scalars = 0;
66
+ svuint32_t sum_u32x = svdup_u32(0);
67
+ do {
68
+ svbool_t predicate_b8x = svwhilelt_b8_u64(idx_scalars, count_scalars);
69
+ svuint8_t a_u8x = svld1_u8(predicate_b8x, a_scalars + idx_scalars);
70
+ svuint8_t b_u8x = svld1_u8(predicate_b8x, b_scalars + idx_scalars);
71
+ sum_u32x = svdot_u32(sum_u32x, a_u8x, b_u8x);
72
+ idx_scalars += svcntb();
73
+ } while (idx_scalars < count_scalars);
74
+ *result = (nk_u32_t)svaddv_u32(svptrue_b32(), sum_u32x);
75
+ }
76
+
77
+ #if defined(__clang__)
78
+ #pragma clang attribute pop
79
+ #elif defined(__GNUC__)
80
+ #pragma GCC pop_options
81
+ #endif
82
+
83
+ #if defined(__cplusplus)
84
+ } // extern "C"
85
+ #endif
86
+
87
+ #endif // NK_TARGET_SVESDOT
88
+ #endif // NK_TARGET_ARM_
89
+ #endif // NK_DOT_SVESDOT_H
@@ -73,8 +73,8 @@ nk_dot_f32_v128relaxed_cycle:
73
73
  nk_load_b64_serial_(b_scalars, &b_f32_vec);
74
74
  a_scalars += 2, b_scalars += 2, count_scalars -= 2;
75
75
  }
76
- v128_t a_f32x2 = wasm_v128_load64_zero(&a_f32_vec.u64);
77
- v128_t b_f32x2 = wasm_v128_load64_zero(&b_f32_vec.u64);
76
+ v128_t a_f32x2 = wasm_i64x2_splat(a_f32_vec.u64);
77
+ v128_t b_f32x2 = wasm_i64x2_splat(b_f32_vec.u64);
78
78
  v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
79
79
  v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
80
80
  sum_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_f64x2, sum_f64x2);
@@ -110,24 +110,28 @@ nk_dot_f16_v128relaxed_cycle:
110
110
 
111
111
  NK_PUBLIC void nk_dot_bf16_v128relaxed(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
112
112
  v128_t sum_f32x4 = wasm_f32x4_splat(0.0f);
113
+ v128_t mask_high_u32x4 = wasm_i32x4_splat((int)0xFFFF0000);
113
114
  nk_bf16_t const *a_scalars = a, *b_scalars = b;
114
115
  nk_size_t count_scalars = n;
115
- nk_b64_vec_t a_bf16_vec, b_bf16_vec;
116
+ nk_b128_vec_t a_bf16_vec, b_bf16_vec;
116
117
 
117
118
  nk_dot_bf16_v128relaxed_cycle:
118
- if (count_scalars < 4) {
119
- nk_partial_load_b16x4_serial_(a_scalars, &a_bf16_vec, count_scalars);
120
- nk_partial_load_b16x4_serial_(b_scalars, &b_bf16_vec, count_scalars);
119
+ if (count_scalars < 8) {
120
+ nk_partial_load_b16x8_serial_(a_scalars, &a_bf16_vec, count_scalars);
121
+ nk_partial_load_b16x8_serial_(b_scalars, &b_bf16_vec, count_scalars);
121
122
  count_scalars = 0;
122
123
  }
123
124
  else {
124
- nk_load_b64_serial_(a_scalars, &a_bf16_vec);
125
- nk_load_b64_serial_(b_scalars, &b_bf16_vec);
126
- a_scalars += 4, b_scalars += 4, count_scalars -= 4;
125
+ nk_load_b128_v128relaxed_(a_scalars, &a_bf16_vec);
126
+ nk_load_b128_v128relaxed_(b_scalars, &b_bf16_vec);
127
+ a_scalars += 8, b_scalars += 8, count_scalars -= 8;
127
128
  }
128
- nk_b128_vec_t a_f32_vec = nk_bf16x4_to_f32x4_v128relaxed_(a_bf16_vec);
129
- nk_b128_vec_t b_f32_vec = nk_bf16x4_to_f32x4_v128relaxed_(b_bf16_vec);
130
- sum_f32x4 = wasm_f32x4_relaxed_madd(a_f32_vec.v128, b_f32_vec.v128, sum_f32x4);
129
+ v128_t a_even_f32x4 = wasm_i32x4_shl(a_bf16_vec.v128, 16);
130
+ v128_t b_even_f32x4 = wasm_i32x4_shl(b_bf16_vec.v128, 16);
131
+ sum_f32x4 = wasm_f32x4_relaxed_madd(a_even_f32x4, b_even_f32x4, sum_f32x4);
132
+ v128_t a_odd_f32x4 = wasm_v128_and(a_bf16_vec.v128, mask_high_u32x4);
133
+ v128_t b_odd_f32x4 = wasm_v128_and(b_bf16_vec.v128, mask_high_u32x4);
134
+ sum_f32x4 = wasm_f32x4_relaxed_madd(a_odd_f32x4, b_odd_f32x4, sum_f32x4);
131
135
  if (count_scalars) goto nk_dot_bf16_v128relaxed_cycle;
132
136
 
133
137
  *result = nk_reduce_add_f32x4_v128relaxed_(sum_f32x4);
@@ -274,8 +278,8 @@ NK_PUBLIC void nk_dot_e2m3_v128relaxed(nk_e2m3_t const *a_scalars, nk_e2m3_t con
274
278
  // Result = i32_dot / 256.0f (exact, no rounding error).
275
279
  //
276
280
  // 32-entry LUT split into two 16-entry halves for wasm_i8x16_relaxed_swizzle (indexes 0-15).
277
- v128_t lut_lower_u8x16 = wasm_i8x16_const(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
278
- v128_t lut_upper_u8x16 = wasm_i8x16_const(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120);
281
+ v128_t lut_low_u8x16 = wasm_i8x16_const(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
282
+ v128_t lut_high_u8x16 = wasm_i8x16_const(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120);
279
283
  v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
280
284
  v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
281
285
  v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
@@ -304,17 +308,17 @@ nk_dot_e2m3_v128relaxed_cycle:
304
308
 
305
309
  // Dual swizzle + bitselect for 32-entry LUT (a)
306
310
  v128_t a_shuffle_index_u8x16 = wasm_v128_and(a_magnitude_u8x16, nibble_mask_u8x16);
307
- v128_t a_lower_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, a_shuffle_index_u8x16);
308
- v128_t a_upper_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, a_shuffle_index_u8x16);
309
- v128_t a_upper_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
310
- v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_upper_u8x16, a_lower_u8x16, a_upper_select_u8x16);
311
+ v128_t a_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, a_shuffle_index_u8x16);
312
+ v128_t a_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, a_shuffle_index_u8x16);
313
+ v128_t a_high_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
314
+ v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_high_u8x16, a_low_u8x16, a_high_select_u8x16);
311
315
 
312
316
  // Dual swizzle + bitselect for 32-entry LUT (b)
313
317
  v128_t b_shuffle_index_u8x16 = wasm_v128_and(b_magnitude_u8x16, nibble_mask_u8x16);
314
- v128_t b_lower_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, b_shuffle_index_u8x16);
315
- v128_t b_upper_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, b_shuffle_index_u8x16);
316
- v128_t b_upper_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
317
- v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_upper_u8x16, b_lower_u8x16, b_upper_select_u8x16);
318
+ v128_t b_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, b_shuffle_index_u8x16);
319
+ v128_t b_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, b_shuffle_index_u8x16);
320
+ v128_t b_high_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
321
+ v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_high_u8x16, b_low_u8x16, b_high_select_u8x16);
318
322
 
319
323
  // Combined sign: (a ^ b) & 0x20 — nonzero means negative product
320
324
  // Apply sign to a (relaxed_dot wants i8 × u7: a_signed, b_unsigned)
@@ -343,12 +347,13 @@ NK_PUBLIC void nk_dot_e3m2_v128relaxed(nk_e3m2_t const *a_scalars, nk_e3m2_t con
343
347
  // Low-byte LUT entries (magnitude[i] & 0xFF):
344
348
  // [0,1,2,3,4,5,6,7,8,10,12,14,16,20,24,28] lower half
345
349
  // [32,40,48,56,64,80,96,112,128,160,192,224,0,64,128,192] upper half
346
- v128_t lut_lo_lower_u8x16 = wasm_i8x16_const(0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28);
347
- v128_t lut_lo_upper_u8x16 = wasm_u8x16_const(32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 0, 64, 128, 192);
350
+ v128_t lut_low_byte_first_u8x16 = wasm_i8x16_const(0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28);
351
+ v128_t lut_low_byte_second_u8x16 = wasm_u8x16_const(32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 0, 64, 128,
352
+ 192);
348
353
  v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
349
354
  v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
350
355
  v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
351
- v128_t hi_threshold_u8x16 = wasm_u8x16_splat(28);
356
+ v128_t high_threshold_u8x16 = wasm_u8x16_splat(28);
352
357
  v128_t sign_mask_u8x16 = wasm_u8x16_splat(0x20);
353
358
  v128_t sum_i32x4 = wasm_i32x4_splat(0);
354
359
  v128_t a_e3m2_u8x16, b_e3m2_u8x16;
@@ -374,32 +379,34 @@ nk_dot_e3m2_v128relaxed_cycle:
374
379
 
375
380
  // Dual swizzle + bitselect for 32-entry low-byte LUT (a)
376
381
  v128_t a_shuffle_index_u8x16 = wasm_v128_and(a_magnitude_u8x16, nibble_mask_u8x16);
377
- v128_t a_lower_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lo_lower_u8x16, a_shuffle_index_u8x16);
378
- v128_t a_upper_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lo_upper_u8x16, a_shuffle_index_u8x16);
379
- v128_t a_upper_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
380
- v128_t a_lo_bytes_u8x16 = wasm_i8x16_relaxed_laneselect(a_upper_u8x16, a_lower_u8x16, a_upper_select_u8x16);
382
+ v128_t a_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_byte_first_u8x16, a_shuffle_index_u8x16);
383
+ v128_t a_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_byte_second_u8x16, a_shuffle_index_u8x16);
384
+ v128_t a_high_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
385
+ v128_t a_low_byte_u8x16 = wasm_i8x16_relaxed_laneselect(a_high_u8x16, a_low_u8x16, a_high_select_u8x16);
381
386
 
382
387
  // High byte is 1 iff magnitude index >= 28 (values 256, 320, 384, 448), else 0
383
- v128_t a_hi_bytes_u8x16 = wasm_v128_and(wasm_u8x16_ge(a_magnitude_u8x16, hi_threshold_u8x16), wasm_u8x16_splat(1));
388
+ v128_t a_high_byte_u8x16 = wasm_v128_and(wasm_u8x16_ge(a_magnitude_u8x16, high_threshold_u8x16),
389
+ wasm_u8x16_splat(1));
384
390
 
385
391
  // Dual swizzle + bitselect for 32-entry low-byte LUT (b)
386
392
  v128_t b_shuffle_index_u8x16 = wasm_v128_and(b_magnitude_u8x16, nibble_mask_u8x16);
387
- v128_t b_lower_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lo_lower_u8x16, b_shuffle_index_u8x16);
388
- v128_t b_upper_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lo_upper_u8x16, b_shuffle_index_u8x16);
389
- v128_t b_upper_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
390
- v128_t b_lo_bytes_u8x16 = wasm_i8x16_relaxed_laneselect(b_upper_u8x16, b_lower_u8x16, b_upper_select_u8x16);
393
+ v128_t b_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_byte_first_u8x16, b_shuffle_index_u8x16);
394
+ v128_t b_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_byte_second_u8x16, b_shuffle_index_u8x16);
395
+ v128_t b_high_select_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_magnitude_u8x16, half_select_u8x16), half_select_u8x16);
396
+ v128_t b_low_byte_u8x16 = wasm_i8x16_relaxed_laneselect(b_high_u8x16, b_low_u8x16, b_high_select_u8x16);
391
397
 
392
398
  // High byte is 1 iff magnitude index >= 28
393
- v128_t b_hi_bytes_u8x16 = wasm_v128_and(wasm_u8x16_ge(b_magnitude_u8x16, hi_threshold_u8x16), wasm_u8x16_splat(1));
399
+ v128_t b_high_byte_u8x16 = wasm_v128_and(wasm_u8x16_ge(b_magnitude_u8x16, high_threshold_u8x16),
400
+ wasm_u8x16_splat(1));
394
401
 
395
402
  // Combine low and high bytes into i16 via byte interleave shuffle (little-endian: low byte first)
396
- v128_t a_unsigned_low_i16x8 = wasm_i8x16_shuffle(a_lo_bytes_u8x16, a_hi_bytes_u8x16, 0, 16, 1, 17, 2, 18, 3, 19, 4,
403
+ v128_t a_unsigned_low_i16x8 = wasm_i8x16_shuffle(a_low_byte_u8x16, a_high_byte_u8x16, 0, 16, 1, 17, 2, 18, 3, 19, 4,
397
404
  20, 5, 21, 6, 22, 7, 23);
398
- v128_t a_unsigned_high_i16x8 = wasm_i8x16_shuffle(a_lo_bytes_u8x16, a_hi_bytes_u8x16, 8, 24, 9, 25, 10, 26, 11, 27,
405
+ v128_t a_unsigned_high_i16x8 = wasm_i8x16_shuffle(a_low_byte_u8x16, a_high_byte_u8x16, 8, 24, 9, 25, 10, 26, 11, 27,
399
406
  12, 28, 13, 29, 14, 30, 15, 31);
400
- v128_t b_unsigned_low_i16x8 = wasm_i8x16_shuffle(b_lo_bytes_u8x16, b_hi_bytes_u8x16, 0, 16, 1, 17, 2, 18, 3, 19, 4,
407
+ v128_t b_unsigned_low_i16x8 = wasm_i8x16_shuffle(b_low_byte_u8x16, b_high_byte_u8x16, 0, 16, 1, 17, 2, 18, 3, 19, 4,
401
408
  20, 5, 21, 6, 22, 7, 23);
402
- v128_t b_unsigned_high_i16x8 = wasm_i8x16_shuffle(b_lo_bytes_u8x16, b_hi_bytes_u8x16, 8, 24, 9, 25, 10, 26, 11, 27,
409
+ v128_t b_unsigned_high_i16x8 = wasm_i8x16_shuffle(b_low_byte_u8x16, b_high_byte_u8x16, 8, 24, 9, 25, 10, 26, 11, 27,
403
410
  12, 28, 13, 29, 14, 30, 15, 31);
404
411
 
405
412
  // Combined sign: XOR sign bits, negate only b (saves ~15 ops vs independent negation)
@@ -497,6 +504,33 @@ NK_INTERNAL void nk_dot_through_f32x4_finalize_v128relaxed_( //
497
504
  result->f32s[3] = nk_reduce_add_f32x4_v128relaxed_(state_d->sum_f32x4);
498
505
  }
499
506
 
507
+ typedef struct nk_dot_through_f32x4_state_v128relaxed_t_ nk_dot_bf16x8_state_v128relaxed_t;
508
+
509
+ NK_INTERNAL void nk_dot_bf16x8_init_v128relaxed(nk_dot_bf16x8_state_v128relaxed_t *state) {
510
+ nk_dot_through_f32x4_init_v128relaxed_(state);
511
+ }
512
+
513
+ NK_INTERNAL void nk_dot_bf16x8_update_v128relaxed(nk_dot_bf16x8_state_v128relaxed_t *state, nk_b128_vec_t a,
514
+ nk_b128_vec_t b, nk_size_t depth_offset,
515
+ nk_size_t active_dimensions) {
516
+ nk_unused_(depth_offset);
517
+ nk_unused_(active_dimensions);
518
+ v128_t mask_high_u32x4 = wasm_i32x4_splat((int)0xFFFF0000);
519
+ v128_t a_even_f32x4 = wasm_i32x4_shl(a.v128, 16);
520
+ v128_t b_even_f32x4 = wasm_i32x4_shl(b.v128, 16);
521
+ state->sum_f32x4 = wasm_f32x4_relaxed_madd(a_even_f32x4, b_even_f32x4, state->sum_f32x4);
522
+ v128_t a_odd_f32x4 = wasm_v128_and(a.v128, mask_high_u32x4);
523
+ v128_t b_odd_f32x4 = wasm_v128_and(b.v128, mask_high_u32x4);
524
+ state->sum_f32x4 = wasm_f32x4_relaxed_madd(a_odd_f32x4, b_odd_f32x4, state->sum_f32x4);
525
+ }
526
+
527
+ NK_INTERNAL void nk_dot_bf16x8_finalize_v128relaxed( //
528
+ nk_dot_bf16x8_state_v128relaxed_t const *state_a, nk_dot_bf16x8_state_v128relaxed_t const *state_b, //
529
+ nk_dot_bf16x8_state_v128relaxed_t const *state_c, nk_dot_bf16x8_state_v128relaxed_t const *state_d, //
530
+ nk_size_t total_dimensions, nk_b128_vec_t *result) {
531
+ nk_dot_through_f32x4_finalize_v128relaxed_(state_a, state_b, state_c, state_d, total_dimensions, result);
532
+ }
533
+
500
534
  typedef struct nk_dot_f32x2_state_v128relaxed_t {
501
535
  v128_t sum_f64x2;
502
536
  } nk_dot_f32x2_state_v128relaxed_t;
@@ -509,8 +543,8 @@ NK_INTERNAL void nk_dot_f32x2_update_v128relaxed(nk_dot_f32x2_state_v128relaxed_
509
543
  nk_b64_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
510
544
  nk_unused_(depth_offset);
511
545
  nk_unused_(active_dimensions);
512
- v128_t a_f32x2 = wasm_v128_load64_zero(&a.u64);
513
- v128_t b_f32x2 = wasm_v128_load64_zero(&b.u64);
546
+ v128_t a_f32x2 = wasm_i64x2_splat(a.u64);
547
+ v128_t b_f32x2 = wasm_i64x2_splat(b.u64);
514
548
  v128_t a_f64x2 = wasm_f64x2_promote_low_f32x4(a_f32x2);
515
549
  v128_t b_f64x2 = wasm_f64x2_promote_low_f32x4(b_f32x2);
516
550
  state->sum_f64x2 = wasm_f64x2_relaxed_madd(a_f64x2, b_f64x2, state->sum_f64x2);
@@ -603,12 +637,12 @@ NK_INTERNAL void nk_dot_i8x16_update_v128relaxed(nk_dot_i8x16_state_v128relaxed_
603
637
  nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
604
638
  nk_unused_(depth_offset);
605
639
  nk_unused_(active_dimensions);
606
- // Bit-split: b = b_lo + (-128)·b_hi where b_lo = b & 0x7F ∈ [0,127], b_hi = b >> 7 ∈ {0,1}
607
- // So a·b = a·b_lo − 128·a·b_hi, both operands fit i7 for relaxed_dot
608
- v128_t b_lo_u8x16 = wasm_v128_and(b.v128, wasm_i8x16_splat(0x7F));
609
- v128_t b_hi_u8x16 = wasm_u8x16_shr(b.v128, 7);
610
- state->product_sum_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128, b_lo_u8x16, state->product_sum_i32x4);
611
- state->negative_sum_a_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128, b_hi_u8x16,
640
+ // Bit-split: b = b_low + (-128)·b_high where b_low = b & 0x7F ∈ [0,127], b_high = b >> 7 ∈ {0,1}
641
+ // So a·b = a·b_low − 128·a·b_high, both operands fit i7 for relaxed_dot
642
+ v128_t b_low_u8x16 = wasm_v128_and(b.v128, wasm_i8x16_splat(0x7F));
643
+ v128_t b_high_u8x16 = wasm_u8x16_shr(b.v128, 7);
644
+ state->product_sum_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128, b_low_u8x16, state->product_sum_i32x4);
645
+ state->negative_sum_a_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a.v128, b_high_u8x16,
612
646
  state->negative_sum_a_i32x4);
613
647
  }
614
648
 
@@ -629,28 +663,29 @@ NK_INTERNAL void nk_dot_i8x16_finalize_v128relaxed(
629
663
  }
630
664
 
631
665
  typedef struct nk_dot_u8x16_state_v128relaxed_t {
632
- v128_t product_lo_i32x4; // relaxed_dot(a_signed, b_lo) accumulator
633
- v128_t product_hi_i32x4; // relaxed_dot(a_signed, b_hi) accumulator
666
+ v128_t product_low_i32x4; // relaxed_dot(a_signed, b_low) accumulator
667
+ v128_t product_high_i32x4; // relaxed_dot(a_signed, b_high) accumulator
634
668
  } nk_dot_u8x16_state_v128relaxed_t;
635
669
 
636
670
  NK_INTERNAL void nk_dot_u8x16_init_v128relaxed(nk_dot_u8x16_state_v128relaxed_t *state) {
637
- state->product_lo_i32x4 = wasm_i32x4_splat(0);
638
- state->product_hi_i32x4 = wasm_i32x4_splat(0);
671
+ state->product_low_i32x4 = wasm_i32x4_splat(0);
672
+ state->product_high_i32x4 = wasm_i32x4_splat(0);
639
673
  }
640
674
 
641
675
  NK_INTERNAL void nk_dot_u8x16_update_v128relaxed(nk_dot_u8x16_state_v128relaxed_t *state, nk_b128_vec_t a,
642
676
  nk_b128_vec_t b, nk_size_t depth_offset, nk_size_t active_dimensions) {
643
677
  nk_unused_(depth_offset);
644
678
  nk_unused_(active_dimensions);
645
- // Bit-split b: b = b_lo + 128·b_hi, with a_signed = a ^ 0x80 = a - 128 (reinterpret u8 as i8)
646
- // Σ a·b = Σ(a_signed+128)·(b_lo+128·b_hi) = relaxed_dot(a_signed,b_lo) + 128·relaxed_dot(a_signed,b_hi) + 128·Σb
679
+ // Bit-split b: b = b_low + 128·b_high, with a_signed = a ^ 0x80 = a - 128 (reinterpret u8 as i8)
680
+ // Σ a·b = Σ(a_signed+128)·(b_lo+128·b_high) = relaxed_dot(a_signed,b_low) + 128·relaxed_dot(a_signed,b_high) +
681
+ // 128·Σb
647
682
  v128_t a_signed_i8x16 = wasm_v128_xor(a.v128, wasm_i8x16_splat((signed char)0x80));
648
- v128_t b_lo_u8x16 = wasm_v128_and(b.v128, wasm_i8x16_splat(0x7F));
649
- v128_t b_hi_u8x16 = wasm_u8x16_shr(b.v128, 7);
650
- state->product_lo_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_lo_u8x16,
651
- state->product_lo_i32x4);
652
- state->product_hi_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_hi_u8x16,
653
- state->product_hi_i32x4);
683
+ v128_t b_low_u8x16 = wasm_v128_and(b.v128, wasm_i8x16_splat(0x7F));
684
+ v128_t b_high_u8x16 = wasm_u8x16_shr(b.v128, 7);
685
+ state->product_low_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_low_u8x16,
686
+ state->product_low_i32x4);
687
+ state->product_high_i32x4 = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(a_signed_i8x16, b_high_u8x16,
688
+ state->product_high_i32x4);
654
689
  }
655
690
 
656
691
  NK_INTERNAL void nk_dot_u8x16_finalize_v128relaxed( //
@@ -659,17 +694,17 @@ NK_INTERNAL void nk_dot_u8x16_finalize_v128relaxed(
659
694
  nk_size_t total_dimensions, nk_u32_t a_sum, nk_b128_vec_t b_sums, nk_b128_vec_t *result) {
660
695
  nk_unused_(a_sum);
661
696
  // Σ a·b = reduce(lo) + 128·reduce(hi) + 128·Σb
662
- result->u32s[0] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_a->product_lo_i32x4) +
663
- 128 * nk_reduce_add_i32x4_v128relaxed_(state_a->product_hi_i32x4) +
697
+ result->u32s[0] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_a->product_low_i32x4) +
698
+ 128 * nk_reduce_add_i32x4_v128relaxed_(state_a->product_high_i32x4) +
664
699
  128 * (nk_i32_t)b_sums.u32s[0]);
665
- result->u32s[1] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_b->product_lo_i32x4) +
666
- 128 * nk_reduce_add_i32x4_v128relaxed_(state_b->product_hi_i32x4) +
700
+ result->u32s[1] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_b->product_low_i32x4) +
701
+ 128 * nk_reduce_add_i32x4_v128relaxed_(state_b->product_high_i32x4) +
667
702
  128 * (nk_i32_t)b_sums.u32s[1]);
668
- result->u32s[2] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_c->product_lo_i32x4) +
669
- 128 * nk_reduce_add_i32x4_v128relaxed_(state_c->product_hi_i32x4) +
703
+ result->u32s[2] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_c->product_low_i32x4) +
704
+ 128 * nk_reduce_add_i32x4_v128relaxed_(state_c->product_high_i32x4) +
670
705
  128 * (nk_i32_t)b_sums.u32s[2]);
671
- result->u32s[3] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_d->product_lo_i32x4) +
672
- 128 * nk_reduce_add_i32x4_v128relaxed_(state_d->product_hi_i32x4) +
706
+ result->u32s[3] = (nk_u32_t)(nk_reduce_add_i32x4_v128relaxed_(state_d->product_low_i32x4) +
707
+ 128 * nk_reduce_add_i32x4_v128relaxed_(state_d->product_high_i32x4) +
673
708
  128 * (nk_i32_t)b_sums.u32s[3]);
674
709
  }
675
710
 
@@ -706,8 +741,8 @@ NK_INTERNAL void nk_dot_e2m3x16_update_v128relaxed(nk_dot_e2m3x16_state_v128rela
706
741
  nk_unused_(depth_offset);
707
742
  nk_unused_(active_dimensions);
708
743
  // Same LUT-based approach as 1:1 dot, accumulating into state
709
- v128_t lut_lower_u8x16 = wasm_i8x16_const(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
710
- v128_t lut_upper_u8x16 = wasm_i8x16_const(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120);
744
+ v128_t lut_low_u8x16 = wasm_i8x16_const(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
745
+ v128_t lut_high_u8x16 = wasm_i8x16_const(32, 36, 40, 44, 48, 52, 56, 60, 64, 72, 80, 88, 96, 104, 112, 120);
711
746
  v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
712
747
  v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
713
748
  v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
@@ -719,17 +754,17 @@ NK_INTERNAL void nk_dot_e2m3x16_update_v128relaxed(nk_dot_e2m3x16_state_v128rela
719
754
 
720
755
  // Dual swizzle LUT for a
721
756
  v128_t a_idx_u8x16 = wasm_v128_and(a_mag_u8x16, nibble_mask_u8x16);
722
- v128_t a_lo_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, a_idx_u8x16);
723
- v128_t a_hi_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, a_idx_u8x16);
757
+ v128_t a_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, a_idx_u8x16);
758
+ v128_t a_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, a_idx_u8x16);
724
759
  v128_t a_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_mag_u8x16, half_select_u8x16), half_select_u8x16);
725
- v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_hi_u8x16, a_lo_u8x16, a_sel_u8x16);
760
+ v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_high_u8x16, a_low_u8x16, a_sel_u8x16);
726
761
 
727
762
  // Dual swizzle LUT for b
728
763
  v128_t b_idx_u8x16 = wasm_v128_and(b_mag_u8x16, nibble_mask_u8x16);
729
- v128_t b_lo_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, b_idx_u8x16);
730
- v128_t b_hi_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, b_idx_u8x16);
764
+ v128_t b_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, b_idx_u8x16);
765
+ v128_t b_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, b_idx_u8x16);
731
766
  v128_t b_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_mag_u8x16, half_select_u8x16), half_select_u8x16);
732
- v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_hi_u8x16, b_lo_u8x16, b_sel_u8x16);
767
+ v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_high_u8x16, b_low_u8x16, b_sel_u8x16);
733
768
 
734
769
  // Combined sign → apply to a (relaxed_dot wants i8 × u7)
735
770
  v128_t sign_u8x16 = wasm_v128_and(wasm_v128_xor(a.v128, b.v128), sign_mask_u8x16);
@@ -770,8 +805,8 @@ NK_INTERNAL void nk_dot_e3m2x16_update_v128relaxed(nk_dot_e3m2x16_state_v128rela
770
805
  // ×4 scaled LUT — all values ≤ 112, fits u7 for relaxed_dot
771
806
  // Indices 0-11 rounded to nearest integer (max error ±0.5 in ×4 domain = ±0.125 in value)
772
807
  // Indices 12-31 exact
773
- v128_t lut_lower_u8x16 = wasm_i8x16_const(0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 6, 7);
774
- v128_t lut_upper_u8x16 = wasm_i8x16_const(8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80, 96, 112);
808
+ v128_t lut_low_u8x16 = wasm_i8x16_const(0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 6, 7);
809
+ v128_t lut_high_u8x16 = wasm_i8x16_const(8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, 64, 80, 96, 112);
775
810
  v128_t magnitude_mask_u8x16 = wasm_u8x16_splat(0x1F);
776
811
  v128_t nibble_mask_u8x16 = wasm_u8x16_splat(0x0F);
777
812
  v128_t half_select_u8x16 = wasm_u8x16_splat(0x10);
@@ -782,17 +817,17 @@ NK_INTERNAL void nk_dot_e3m2x16_update_v128relaxed(nk_dot_e3m2x16_state_v128rela
782
817
 
783
818
  // Dual swizzle LUT for a
784
819
  v128_t a_idx_u8x16 = wasm_v128_and(a_mag_u8x16, nibble_mask_u8x16);
785
- v128_t a_lo_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, a_idx_u8x16);
786
- v128_t a_hi_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, a_idx_u8x16);
820
+ v128_t a_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, a_idx_u8x16);
821
+ v128_t a_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, a_idx_u8x16);
787
822
  v128_t a_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(a_mag_u8x16, half_select_u8x16), half_select_u8x16);
788
- v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_hi_u8x16, a_lo_u8x16, a_sel_u8x16);
823
+ v128_t a_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(a_high_u8x16, a_low_u8x16, a_sel_u8x16);
789
824
 
790
825
  // Dual swizzle LUT for b
791
826
  v128_t b_idx_u8x16 = wasm_v128_and(b_mag_u8x16, nibble_mask_u8x16);
792
- v128_t b_lo_u8x16 = wasm_i8x16_relaxed_swizzle(lut_lower_u8x16, b_idx_u8x16);
793
- v128_t b_hi_u8x16 = wasm_i8x16_relaxed_swizzle(lut_upper_u8x16, b_idx_u8x16);
827
+ v128_t b_low_u8x16 = wasm_i8x16_relaxed_swizzle(lut_low_u8x16, b_idx_u8x16);
828
+ v128_t b_high_u8x16 = wasm_i8x16_relaxed_swizzle(lut_high_u8x16, b_idx_u8x16);
794
829
  v128_t b_sel_u8x16 = wasm_i8x16_eq(wasm_v128_and(b_mag_u8x16, half_select_u8x16), half_select_u8x16);
795
- v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_hi_u8x16, b_lo_u8x16, b_sel_u8x16);
830
+ v128_t b_unsigned_u8x16 = wasm_i8x16_relaxed_laneselect(b_high_u8x16, b_low_u8x16, b_sel_u8x16);
796
831
 
797
832
  // Combined sign → apply to a (relaxed_dot wants i8 × u7)
798
833
  v128_t sign_u8x16 = wasm_v128_and(wasm_v128_xor(a.v128, b.v128), sign_mask_u8x16);
@@ -1233,13 +1268,13 @@ NK_INTERNAL void nk_dot_u1x128_finalize_v128relaxed(
1233
1268
  v128_t a_u32x4 = state_a->dot_count_u32x4, b_u32x4 = state_b->dot_count_u32x4;
1234
1269
  v128_t c_u32x4 = state_c->dot_count_u32x4, d_u32x4 = state_d->dot_count_u32x4;
1235
1270
  // Step 1: interleave pairs
1236
- v128_t ab_lo_u32x4 = wasm_i32x4_shuffle(a_u32x4, b_u32x4, 0, 4, 1, 5); // a0 b0 a1 b1
1237
- v128_t ab_hi_u32x4 = wasm_i32x4_shuffle(a_u32x4, b_u32x4, 2, 6, 3, 7); // a2 b2 a3 b3
1238
- v128_t cd_lo_u32x4 = wasm_i32x4_shuffle(c_u32x4, d_u32x4, 0, 4, 1, 5); // c0 d0 c1 d1
1239
- v128_t cd_hi_u32x4 = wasm_i32x4_shuffle(c_u32x4, d_u32x4, 2, 6, 3, 7); // c2 d2 c3 d3
1271
+ v128_t ab_low_u32x4 = wasm_i32x4_shuffle(a_u32x4, b_u32x4, 0, 4, 1, 5); // a0 b0 a1 b1
1272
+ v128_t ab_high_u32x4 = wasm_i32x4_shuffle(a_u32x4, b_u32x4, 2, 6, 3, 7); // a2 b2 a3 b3
1273
+ v128_t cd_low_u32x4 = wasm_i32x4_shuffle(c_u32x4, d_u32x4, 0, 4, 1, 5); // c0 d0 c1 d1
1274
+ v128_t cd_high_u32x4 = wasm_i32x4_shuffle(c_u32x4, d_u32x4, 2, 6, 3, 7); // c2 d2 c3 d3
1240
1275
  // Step 2: pairwise add
1241
- v128_t sum_02_u32x4 = wasm_i32x4_add(ab_lo_u32x4, ab_hi_u32x4); // a02 b02 a13 b13
1242
- v128_t sum_13_u32x4 = wasm_i32x4_add(cd_lo_u32x4, cd_hi_u32x4); // c02 d02 c13 d13
1276
+ v128_t sum_02_u32x4 = wasm_i32x4_add(ab_low_u32x4, ab_high_u32x4); // a02 b02 a13 b13
1277
+ v128_t sum_13_u32x4 = wasm_i32x4_add(cd_low_u32x4, cd_high_u32x4); // c02 d02 c13 d13
1243
1278
  // Step 3: final interleave
1244
1279
  v128_t even_u32x4 = wasm_i32x4_shuffle(sum_02_u32x4, sum_13_u32x4, 0, 1, 4, 5);
1245
1280
  v128_t odd_u32x4 = wasm_i32x4_shuffle(sum_02_u32x4, sum_13_u32x4, 2, 3, 6, 7);