numkong 7.0.0 → 7.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +239 -122
  2. package/binding.gyp +25 -491
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -34,8 +34,8 @@ extern "C" {
34
34
  NK_PUBLIC void nk_sqeuclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
35
35
  nk_f32_t *result) {
36
36
  // Per-lane accumulator — deferred horizontal reduction
37
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
38
- vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
37
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
38
+ vfloat32m2_t sum_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
39
39
 
40
40
  for (nk_size_t vector_length; count_scalars > 0;
41
41
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -54,7 +54,7 @@ NK_PUBLIC void nk_sqeuclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t co
54
54
 
55
55
  // Single horizontal reduction after the loop
56
56
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
57
- *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, vlmax));
57
+ *result = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(sum_f32m2, zero_f32m1, max_vector_length));
58
58
  }
59
59
 
60
60
  NK_PUBLIC void nk_euclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
@@ -66,10 +66,10 @@ NK_PUBLIC void nk_euclidean_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t cons
66
66
  NK_PUBLIC void nk_angular_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const *b_scalars, nk_size_t count_scalars,
67
67
  nk_f32_t *result) {
68
68
  // Per-lane accumulators — deferred horizontal reduction
69
- nk_size_t vlmax = __riscv_vsetvlmax_e32m2();
70
- vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
71
- vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
72
- vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, vlmax);
69
+ nk_size_t max_vector_length = __riscv_vsetvlmax_e32m2();
70
+ vfloat32m2_t dot_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
71
+ vfloat32m2_t a_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
72
+ vfloat32m2_t b_sq_f32m2 = __riscv_vfmv_v_f_f32m2(0.0f, max_vector_length);
73
73
 
74
74
  for (nk_size_t vector_length; count_scalars > 0;
75
75
  count_scalars -= vector_length, a_scalars += vector_length, b_scalars += vector_length) {
@@ -89,9 +89,12 @@ NK_PUBLIC void nk_angular_f16_rvvhalf(nk_f16_t const *a_scalars, nk_f16_t const
89
89
 
90
90
  // Single horizontal reduction after the loop
91
91
  vfloat32m1_t zero_f32m1 = __riscv_vfmv_v_f_f32m1(0.0f, 1);
92
- nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, vlmax));
93
- nk_f32_t a_sq = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(a_sq_f32m2, zero_f32m1, vlmax));
94
- nk_f32_t b_sq = __riscv_vfmv_f_s_f32m1_f32(__riscv_vfredusum_vs_f32m2_f32m1(b_sq_f32m2, zero_f32m1, vlmax));
92
+ nk_f32_t dot = __riscv_vfmv_f_s_f32m1_f32(
93
+ __riscv_vfredusum_vs_f32m2_f32m1(dot_f32m2, zero_f32m1, max_vector_length));
94
+ nk_f32_t a_sq = __riscv_vfmv_f_s_f32m1_f32(
95
+ __riscv_vfredusum_vs_f32m2_f32m1(a_sq_f32m2, zero_f32m1, max_vector_length));
96
+ nk_f32_t b_sq = __riscv_vfmv_f_s_f32m1_f32(
97
+ __riscv_vfredusum_vs_f32m2_f32m1(b_sq_f32m2, zero_f32m1, max_vector_length));
95
98
 
96
99
  // Normalize: 1 − dot / sqrt(‖a‖² × ‖b‖²)
97
100
  if (a_sq == 0.0f && b_sq == 0.0f) { *result = 0.0f; }
@@ -40,11 +40,11 @@ extern "C" {
40
40
  #define nk_define_sqeuclidean_(input_type, accumulator_type, output_type, load_and_convert) \
41
41
  NK_PUBLIC void nk_sqeuclidean_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
42
42
  nk_size_t n, nk_##output_type##_t *result) { \
43
- nk_##accumulator_type##_t sum = 0, compensation = 0, a_element, b_element; \
43
+ nk_##accumulator_type##_t sum = 0, compensation = 0, a_value, b_value; \
44
44
  for (nk_size_t i = 0; i != n; ++i) { \
45
- load_and_convert(a + i, &a_element); \
46
- load_and_convert(b + i, &b_element); \
47
- nk_##accumulator_type##_t diff = a_element - b_element; \
45
+ load_and_convert(a + i, &a_value); \
46
+ load_and_convert(b + i, &b_value); \
47
+ nk_##accumulator_type##_t diff = a_value - b_value; \
48
48
  nk_##accumulator_type##_t term = diff * diff, t = sum + term; \
49
49
  compensation += (nk_##accumulator_type##_abs_(sum) >= nk_##accumulator_type##_abs_(term)) \
50
50
  ? ((sum - t) + term) \
@@ -74,14 +74,14 @@ extern "C" {
74
74
  #define nk_define_angular_(input_type, accumulator_type, output_type, load_and_convert, compute_rsqrt) \
75
75
  NK_PUBLIC void nk_angular_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
76
76
  nk_size_t n, nk_##output_type##_t *result) { \
77
- nk_##accumulator_type##_t dot_sum = 0, a_sum = 0, b_sum = 0, a_element, b_element; \
77
+ nk_##accumulator_type##_t dot_sum = 0, a_sum = 0, b_sum = 0, a_value, b_value; \
78
78
  nk_##accumulator_type##_t compensation_dot = 0, compensation_a = 0, compensation_b = 0; \
79
79
  for (nk_size_t i = 0; i != n; ++i) { \
80
- load_and_convert(a + i, &a_element); \
81
- load_and_convert(b + i, &b_element); \
82
- nk_##accumulator_type##_t term_dot = a_element * b_element, t_dot = dot_sum + term_dot; \
83
- nk_##accumulator_type##_t term_a = a_element * a_element, t_a = a_sum + term_a; \
84
- nk_##accumulator_type##_t term_b = b_element * b_element, t_b = b_sum + term_b; \
80
+ load_and_convert(a + i, &a_value); \
81
+ load_and_convert(b + i, &b_value); \
82
+ nk_##accumulator_type##_t term_dot = a_value * b_value, t_dot = dot_sum + term_dot; \
83
+ nk_##accumulator_type##_t term_a = a_value * a_value, t_a = a_sum + term_a; \
84
+ nk_##accumulator_type##_t term_b = b_value * b_value, t_b = b_sum + term_b; \
85
85
  compensation_dot += (nk_##accumulator_type##_abs_(dot_sum) >= nk_##accumulator_type##_abs_(term_dot)) \
86
86
  ? ((dot_sum - t_dot) + term_dot) \
87
87
  : ((term_dot - t_dot) + dot_sum); \
@@ -101,8 +101,9 @@ extern "C" {
101
101
  if (a_norm_sq == 0 && b_norm_sq == 0) { *result = 0; } \
102
102
  else if (dot_product == 0) { *result = 1; } \
103
103
  else { \
104
- nk_##output_type##_t unclipped_distance = 1 - dot_product * compute_rsqrt(a_norm_sq) * \
105
- compute_rsqrt(b_norm_sq); \
104
+ nk_##output_type##_t unclipped_distance = (nk_##output_type##_t)( \
105
+ 1 - (nk_##output_type##_t)dot_product * compute_rsqrt((nk_##output_type##_t)a_norm_sq) * \
106
+ compute_rsqrt((nk_##output_type##_t)b_norm_sq)); \
106
107
  *result = unclipped_distance > 0 ? unclipped_distance : 0; \
107
108
  } \
108
109
  }
@@ -8,12 +8,12 @@
8
8
  *
9
9
  * @section spatial_sierra_instructions AVXVNNIINT8 Instructions Performance
10
10
  *
11
- * Intrinsic Instruction Sierra Forest
12
- * _mm256_dpbssds_epi32 VPDPBSSDS (YMM, YMM, YMM) 4cy @ p05
13
- * _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) 4cy @ p05
14
- * _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) 4cy @ p05
15
- * _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0
16
- * _mm_sqrt_ss VSQRTSS (XMM, XMM, XMM) 12cy @ p0
11
+ * Intrinsic Instruction Sierra Forest
12
+ * _mm256_dpbssds_epi32 VPDPBSSDS (YMM, YMM, YMM) 4cy @ p05
13
+ * _mm256_dpbssd_epi32 VPDPBSSD (YMM, YMM, YMM) 4cy @ p05
14
+ * _mm256_dpbuud_epi32 VPDPBUUD (YMM, YMM, YMM) 4cy @ p05
15
+ * _mm_rsqrt_ps VRSQRTPS (XMM, XMM) 5cy @ p0
16
+ * _mm_sqrt_ss VSQRTSS (XMM, XMM, XMM) 12cy @ p0
17
17
  *
18
18
  * Sierra Forest (AVXVNNIINT8) provides native signed x signed and unsigned x unsigned
19
19
  * dot products, eliminating the need for algebraic corrections required on Alder Lake.
@@ -67,7 +67,8 @@ NK_PUBLIC void nk_angular_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_
67
67
  b_norm_sq_i32 += b_element_i32 * b_element_i32;
68
68
  }
69
69
 
70
- *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
70
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
71
+ (nk_f32_t)b_norm_sq_i32);
71
72
  }
72
73
 
73
74
  NK_PUBLIC void nk_sqeuclidean_i8_sierra(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
@@ -132,7 +133,8 @@ NK_PUBLIC void nk_angular_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_
132
133
  b_norm_sq_i32 += b_element_i32 * b_element_i32;
133
134
  }
134
135
 
135
- *result = nk_angular_normalize_f32_haswell_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
136
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
137
+ (nk_f32_t)b_norm_sq_i32);
136
138
  }
137
139
 
138
140
  NK_PUBLIC void nk_sqeuclidean_u8_sierra(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
@@ -177,15 +179,15 @@ NK_PUBLIC void nk_angular_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t cons
177
179
  // Every e2m3 value × 16 is an exact integer in [-120, +120].
178
180
  // DPBSSD(signed, signed) eliminates the need for unsigned conversion tricks.
179
181
  //
180
- __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
181
- 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
182
- __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
183
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
182
+ __m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
183
+ 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
184
+ __m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
185
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
184
186
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
185
187
  __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
186
188
  __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
187
189
  __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
188
- __m256i dot_i32x8 = _mm256_setzero_si256();
190
+ __m256i ab_i32x8 = _mm256_setzero_si256();
189
191
  __m256i a_norm_i32x8 = _mm256_setzero_si256();
190
192
  __m256i b_norm_i32x8 = _mm256_setzero_si256();
191
193
  __m256i a_e2m3_u8x32, b_e2m3_u8x32;
@@ -207,35 +209,39 @@ nk_angular_e2m3_sierra_cycle:
207
209
 
208
210
  // Decode a: extract magnitude, dual-VPSHUFB LUT, apply sign
209
211
  __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
210
- __m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
211
- __m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
212
- __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
213
- _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
212
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
213
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
214
+ half_select_u8x32);
215
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
216
+ _mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
217
+ a_high_select_u8x32);
214
218
  __m256i a_negate = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
215
219
  __m256i a_signed_i8x32 = _mm256_blendv_epi8(a_unsigned_u8x32,
216
220
  _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate);
217
221
 
218
222
  // Decode b: same LUT decode + sign
219
223
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
220
- __m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
221
- __m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
222
- __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
223
- _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
224
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
225
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
226
+ half_select_u8x32);
227
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
228
+ _mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
229
+ b_high_select_u8x32);
224
230
  __m256i b_negate = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
225
231
  __m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32,
226
232
  _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate);
227
233
 
228
234
  // VPDPBSSD: signed × signed → i32
229
- dot_i32x8 = _mm256_dpbssd_epi32(dot_i32x8, a_signed_i8x32, b_signed_i8x32);
235
+ ab_i32x8 = _mm256_dpbssd_epi32(ab_i32x8, a_signed_i8x32, b_signed_i8x32);
230
236
  a_norm_i32x8 = _mm256_dpbssd_epi32(a_norm_i32x8, a_signed_i8x32, a_signed_i8x32);
231
237
  b_norm_i32x8 = _mm256_dpbssd_epi32(b_norm_i32x8, b_signed_i8x32, b_signed_i8x32);
232
238
 
233
239
  if (count_scalars) goto nk_angular_e2m3_sierra_cycle;
234
240
 
235
- nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
241
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(ab_i32x8);
236
242
  nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
237
243
  nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
238
- *result = nk_angular_normalize_f32_haswell_(dot_i32, a_norm_i32, b_norm_i32);
244
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_i32, (nk_f32_t)a_norm_i32, (nk_f32_t)b_norm_i32);
239
245
  }
240
246
 
241
247
  NK_PUBLIC void nk_sqeuclidean_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t const *b_scalars,
@@ -243,15 +249,15 @@ NK_PUBLIC void nk_sqeuclidean_e2m3_sierra(nk_e2m3_t const *a_scalars, nk_e2m3_t
243
249
  // Squared Euclidean distance for e2m3 using norm decomposition + VPDPBSSD.
244
250
  // ||a-b||^2 = ||a||^2 + ||b||^2 - 2*dot(a,b)
245
251
  //
246
- __m256i const lut_lower_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28,
247
- 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
248
- __m256i const lut_upper_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
249
- 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
252
+ __m256i const lut_low_u8x32 = _mm256_set_epi8(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 30, 28, 26,
253
+ 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
254
+ __m256i const lut_high_u8x32 = _mm256_set_epi8(120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32,
255
+ 120, 112, 104, 96, 88, 80, 72, 64, 60, 56, 52, 48, 44, 40, 36, 32);
250
256
  __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
251
257
  __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
252
258
  __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
253
259
  __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
254
- __m256i dot_i32x8 = _mm256_setzero_si256();
260
+ __m256i ab_i32x8 = _mm256_setzero_si256();
255
261
  __m256i a_norm_i32x8 = _mm256_setzero_si256();
256
262
  __m256i b_norm_i32x8 = _mm256_setzero_si256();
257
263
  __m256i a_e2m3_u8x32, b_e2m3_u8x32;
@@ -273,31 +279,35 @@ nk_sqeuclidean_e2m3_sierra_cycle:
273
279
 
274
280
  // Decode a
275
281
  __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e2m3_u8x32, magnitude_mask_u8x32);
276
- __m256i a_shuffle_idx = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
277
- __m256i a_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
278
- __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, a_shuffle_idx),
279
- _mm256_shuffle_epi8(lut_upper_u8x32, a_shuffle_idx), a_upper_sel);
282
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
283
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
284
+ half_select_u8x32);
285
+ __m256i a_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, a_shuffle_index_u8x32),
286
+ _mm256_shuffle_epi8(lut_high_u8x32, a_shuffle_index_u8x32),
287
+ a_high_select_u8x32);
280
288
  __m256i a_negate = _mm256_cmpeq_epi8(_mm256_and_si256(a_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
281
289
  __m256i a_signed_i8x32 = _mm256_blendv_epi8(a_unsigned_u8x32,
282
290
  _mm256_sub_epi8(_mm256_setzero_si256(), a_unsigned_u8x32), a_negate);
283
291
 
284
292
  // Decode b
285
293
  __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e2m3_u8x32, magnitude_mask_u8x32);
286
- __m256i b_shuffle_idx = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
287
- __m256i b_upper_sel = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32), half_select_u8x32);
288
- __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_lower_u8x32, b_shuffle_idx),
289
- _mm256_shuffle_epi8(lut_upper_u8x32, b_shuffle_idx), b_upper_sel);
294
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
295
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
296
+ half_select_u8x32);
297
+ __m256i b_unsigned_u8x32 = _mm256_blendv_epi8(_mm256_shuffle_epi8(lut_low_u8x32, b_shuffle_index_u8x32),
298
+ _mm256_shuffle_epi8(lut_high_u8x32, b_shuffle_index_u8x32),
299
+ b_high_select_u8x32);
290
300
  __m256i b_negate = _mm256_cmpeq_epi8(_mm256_and_si256(b_e2m3_u8x32, sign_mask_u8x32), sign_mask_u8x32);
291
301
  __m256i b_signed_i8x32 = _mm256_blendv_epi8(b_unsigned_u8x32,
292
302
  _mm256_sub_epi8(_mm256_setzero_si256(), b_unsigned_u8x32), b_negate);
293
303
 
294
- dot_i32x8 = _mm256_dpbssd_epi32(dot_i32x8, a_signed_i8x32, b_signed_i8x32);
304
+ ab_i32x8 = _mm256_dpbssd_epi32(ab_i32x8, a_signed_i8x32, b_signed_i8x32);
295
305
  a_norm_i32x8 = _mm256_dpbssd_epi32(a_norm_i32x8, a_signed_i8x32, a_signed_i8x32);
296
306
  b_norm_i32x8 = _mm256_dpbssd_epi32(b_norm_i32x8, b_signed_i8x32, b_signed_i8x32);
297
307
 
298
308
  if (count_scalars) goto nk_sqeuclidean_e2m3_sierra_cycle;
299
309
 
300
- nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(dot_i32x8);
310
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(ab_i32x8);
301
311
  nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
302
312
  nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
303
313
  *result = (nk_f32_t)(a_norm_i32 + b_norm_i32 - 2 * dot_i32) / 256.0f;
@@ -308,6 +318,189 @@ NK_PUBLIC void nk_euclidean_e2m3_sierra(nk_e2m3_t const *a, nk_e2m3_t const *b,
308
318
  *result = nk_f32_sqrt_haswell(*result);
309
319
  }
310
320
 
321
+ NK_PUBLIC void nk_sqeuclidean_e3m2_sierra(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars,
322
+ nk_size_t count_scalars, nk_f32_t *result) {
323
+ // E3M2 squared Euclidean distance via direct difference squaring.
324
+ __m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
325
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
326
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
327
+ __m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
328
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
329
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
330
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
331
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
332
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
333
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
334
+ __m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
335
+ __m256i const ones_u8x32 = _mm256_set1_epi8(1);
336
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
337
+ __m256i sum_i32x8 = _mm256_setzero_si256();
338
+ __m256i a_e3m2_u8x32, b_e3m2_u8x32;
339
+
340
+ nk_sqeuclidean_e3m2_sierra_cycle:
341
+ if (count_scalars < 32) {
342
+ nk_b256_vec_t a_vec, b_vec;
343
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
344
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
345
+ a_e3m2_u8x32 = a_vec.ymm;
346
+ b_e3m2_u8x32 = b_vec.ymm;
347
+ count_scalars = 0;
348
+ }
349
+ else {
350
+ a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
351
+ b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
352
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
353
+ }
354
+
355
+ // Decode both to unsigned i16 via dual-VPSHUFB + interleave
356
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
357
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
358
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
359
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
360
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
361
+ half_select_u8x32);
362
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
363
+ half_select_u8x32);
364
+ __m256i a_low_bytes_u8x32 = _mm256_blendv_epi8(
365
+ _mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
366
+ _mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32), a_high_select_u8x32);
367
+ __m256i b_low_bytes_u8x32 = _mm256_blendv_epi8(
368
+ _mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
369
+ _mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32), b_high_select_u8x32);
370
+ __m256i a_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
371
+ ones_u8x32);
372
+ __m256i b_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
373
+ ones_u8x32);
374
+
375
+ // Interleave to i16 and apply signs
376
+ __m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
377
+ __m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
378
+ __m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
379
+ __m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
380
+
381
+ __m256i a_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
382
+ __m256i b_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
383
+ __m256i a_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
384
+ ones_i16x16);
385
+ __m256i a_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
386
+ ones_i16x16);
387
+ __m256i b_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
388
+ ones_i16x16);
389
+ __m256i b_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
390
+ ones_i16x16);
391
+ __m256i a_signed_low_i16x16 = _mm256_sign_epi16(a_low_i16x16, a_sign_low_i16x16);
392
+ __m256i a_signed_high_i16x16 = _mm256_sign_epi16(a_high_i16x16, a_sign_high_i16x16);
393
+ __m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, b_sign_low_i16x16);
394
+ __m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, b_sign_high_i16x16);
395
+
396
+ // Direct difference squaring: (a-b)² via VPMADDWD
397
+ __m256i diff_low_i16x16 = _mm256_sub_epi16(a_signed_low_i16x16, b_signed_low_i16x16);
398
+ __m256i diff_high_i16x16 = _mm256_sub_epi16(a_signed_high_i16x16, b_signed_high_i16x16);
399
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(diff_low_i16x16, diff_low_i16x16));
400
+ sum_i32x8 = _mm256_add_epi32(sum_i32x8, _mm256_madd_epi16(diff_high_i16x16, diff_high_i16x16));
401
+
402
+ if (count_scalars) goto nk_sqeuclidean_e3m2_sierra_cycle;
403
+ *result = (nk_f32_t)nk_reduce_add_i32x8_haswell_(sum_i32x8) / 256.0f;
404
+ }
405
+
406
+ NK_PUBLIC void nk_euclidean_e3m2_sierra(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
407
+ nk_sqeuclidean_e3m2_sierra(a, b, n, result);
408
+ *result = nk_f32_sqrt_haswell(*result);
409
+ }
410
+
411
+ NK_PUBLIC void nk_angular_e3m2_sierra(nk_e3m2_t const *a_scalars, nk_e3m2_t const *b_scalars, nk_size_t count_scalars,
412
+ nk_f32_t *result) {
413
+ // E3M2 angular distance via VPMADDWD integer MAC.
414
+ __m256i const lut_low_byte_first_u8x32 = _mm256_set_epi8( //
415
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0, //
416
+ 28, 24, 20, 16, 14, 12, 10, 8, 7, 6, 5, 4, 3, 2, 1, 0);
417
+ __m256i const lut_low_byte_second_u8x32 = _mm256_set_epi8( //
418
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32, //
419
+ (char)192, (char)128, 64, 0, (char)224, (char)192, (char)160, (char)128, 112, 96, 80, 64, 56, 48, 40, 32);
420
+ __m256i const nibble_mask_u8x32 = _mm256_set1_epi8(0x0F);
421
+ __m256i const magnitude_mask_u8x32 = _mm256_set1_epi8(0x1F);
422
+ __m256i const half_select_u8x32 = _mm256_set1_epi8(0x10);
423
+ __m256i const sign_mask_u8x32 = _mm256_set1_epi8(0x20);
424
+ __m256i const high_threshold_u8x32 = _mm256_set1_epi8(27);
425
+ __m256i const ones_u8x32 = _mm256_set1_epi8(1);
426
+ __m256i const ones_i16x16 = _mm256_set1_epi16(1);
427
+ __m256i ab_i32x8 = _mm256_setzero_si256();
428
+ __m256i a_norm_i32x8 = _mm256_setzero_si256();
429
+ __m256i b_norm_i32x8 = _mm256_setzero_si256();
430
+ __m256i a_e3m2_u8x32, b_e3m2_u8x32;
431
+
432
+ nk_angular_e3m2_sierra_cycle:
433
+ if (count_scalars < 32) {
434
+ nk_b256_vec_t a_vec, b_vec;
435
+ nk_partial_load_b8x32_serial_(a_scalars, &a_vec, count_scalars);
436
+ nk_partial_load_b8x32_serial_(b_scalars, &b_vec, count_scalars);
437
+ a_e3m2_u8x32 = a_vec.ymm;
438
+ b_e3m2_u8x32 = b_vec.ymm;
439
+ count_scalars = 0;
440
+ }
441
+ else {
442
+ a_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)a_scalars);
443
+ b_e3m2_u8x32 = _mm256_loadu_si256((__m256i const *)b_scalars);
444
+ a_scalars += 32, b_scalars += 32, count_scalars -= 32;
445
+ }
446
+
447
+ // Decode both to unsigned i16
448
+ __m256i a_magnitude_u8x32 = _mm256_and_si256(a_e3m2_u8x32, magnitude_mask_u8x32);
449
+ __m256i b_magnitude_u8x32 = _mm256_and_si256(b_e3m2_u8x32, magnitude_mask_u8x32);
450
+ __m256i a_shuffle_index_u8x32 = _mm256_and_si256(a_magnitude_u8x32, nibble_mask_u8x32);
451
+ __m256i b_shuffle_index_u8x32 = _mm256_and_si256(b_magnitude_u8x32, nibble_mask_u8x32);
452
+ __m256i a_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_magnitude_u8x32, half_select_u8x32),
453
+ half_select_u8x32);
454
+ __m256i b_high_select_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_magnitude_u8x32, half_select_u8x32),
455
+ half_select_u8x32);
456
+ __m256i a_low_bytes_u8x32 = _mm256_blendv_epi8(
457
+ _mm256_shuffle_epi8(lut_low_byte_first_u8x32, a_shuffle_index_u8x32),
458
+ _mm256_shuffle_epi8(lut_low_byte_second_u8x32, a_shuffle_index_u8x32), a_high_select_u8x32);
459
+ __m256i b_low_bytes_u8x32 = _mm256_blendv_epi8(
460
+ _mm256_shuffle_epi8(lut_low_byte_first_u8x32, b_shuffle_index_u8x32),
461
+ _mm256_shuffle_epi8(lut_low_byte_second_u8x32, b_shuffle_index_u8x32), b_high_select_u8x32);
462
+ __m256i a_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(a_magnitude_u8x32, high_threshold_u8x32),
463
+ ones_u8x32);
464
+ __m256i b_high_bytes_u8x32 = _mm256_and_si256(_mm256_cmpgt_epi8(b_magnitude_u8x32, high_threshold_u8x32),
465
+ ones_u8x32);
466
+ __m256i a_low_i16x16 = _mm256_unpacklo_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
467
+ __m256i a_high_i16x16 = _mm256_unpackhi_epi8(a_low_bytes_u8x32, a_high_bytes_u8x32);
468
+ __m256i b_low_i16x16 = _mm256_unpacklo_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
469
+ __m256i b_high_i16x16 = _mm256_unpackhi_epi8(b_low_bytes_u8x32, b_high_bytes_u8x32);
470
+
471
+ // Apply signs individually
472
+ __m256i a_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(a_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
473
+ __m256i b_negative_mask_u8x32 = _mm256_cmpeq_epi8(_mm256_and_si256(b_e3m2_u8x32, sign_mask_u8x32), sign_mask_u8x32);
474
+ __m256i a_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
475
+ ones_i16x16);
476
+ __m256i a_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(a_negative_mask_u8x32, a_negative_mask_u8x32),
477
+ ones_i16x16);
478
+ __m256i b_sign_low_i16x16 = _mm256_or_si256(_mm256_unpacklo_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
479
+ ones_i16x16);
480
+ __m256i b_sign_high_i16x16 = _mm256_or_si256(_mm256_unpackhi_epi8(b_negative_mask_u8x32, b_negative_mask_u8x32),
481
+ ones_i16x16);
482
+ __m256i a_signed_low_i16x16 = _mm256_sign_epi16(a_low_i16x16, a_sign_low_i16x16);
483
+ __m256i a_signed_high_i16x16 = _mm256_sign_epi16(a_high_i16x16, a_sign_high_i16x16);
484
+ __m256i b_signed_low_i16x16 = _mm256_sign_epi16(b_low_i16x16, b_sign_low_i16x16);
485
+ __m256i b_signed_high_i16x16 = _mm256_sign_epi16(b_high_i16x16, b_sign_high_i16x16);
486
+
487
+ // dot(a,b) + a² + b² via VPMADDWD
488
+ ab_i32x8 = _mm256_add_epi32(ab_i32x8, _mm256_madd_epi16(a_signed_low_i16x16, b_signed_low_i16x16));
489
+ ab_i32x8 = _mm256_add_epi32(ab_i32x8, _mm256_madd_epi16(a_signed_high_i16x16, b_signed_high_i16x16));
490
+ a_norm_i32x8 = _mm256_add_epi32(a_norm_i32x8, _mm256_madd_epi16(a_low_i16x16, a_low_i16x16));
491
+ a_norm_i32x8 = _mm256_add_epi32(a_norm_i32x8, _mm256_madd_epi16(a_high_i16x16, a_high_i16x16));
492
+ b_norm_i32x8 = _mm256_add_epi32(b_norm_i32x8, _mm256_madd_epi16(b_low_i16x16, b_low_i16x16));
493
+ b_norm_i32x8 = _mm256_add_epi32(b_norm_i32x8, _mm256_madd_epi16(b_high_i16x16, b_high_i16x16));
494
+
495
+ if (count_scalars) goto nk_angular_e3m2_sierra_cycle;
496
+
497
+ nk_i32_t dot_i32 = nk_reduce_add_i32x8_haswell_(ab_i32x8);
498
+ nk_i32_t a_norm_i32 = nk_reduce_add_i32x8_haswell_(a_norm_i32x8);
499
+ nk_i32_t b_norm_i32 = nk_reduce_add_i32x8_haswell_(b_norm_i32x8);
500
+ *result = nk_angular_normalize_f32_haswell_((nk_f32_t)dot_i32 / 256.0f, (nk_f32_t)a_norm_i32 / 256.0f,
501
+ (nk_f32_t)b_norm_i32 / 256.0f);
502
+ }
503
+
311
504
  #if defined(__clang__)
312
505
  #pragma clang attribute pop
313
506
  #elif defined(__GNUC__)