numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -82,17 +82,17 @@
82
82
  *
83
83
  * The SIMD kernels are dominated by FMA, permutes, and gathers:
84
84
  *
85
- * Intrinsic Instruction Notes
86
- * _mm256_fmadd_ps/pd VFMADD* FMA on FP ports (Haswell/Skylake: ports 0/1)
87
- * _mm256_i32gather_ps VGATHERDPS High-latency; memory-bound
88
- * _mm512_permutex2var_ps/pd VPERMT2* Shuffle-heavy; can bottleneck on shuffle ports
89
- * _mm512_reduce_add_ps/pd (sequence) Implemented via shuffles + adds
85
+ * Intrinsic Instruction Notes
86
+ * _mm256_fmadd_ps/pd VFMADD* FMA on FP ports (Haswell/Skylake: ports 0/1)
87
+ * _mm256_i32gather_ps VGATHERDPS High-latency; memory-bound
88
+ * _mm512_permutex2var_ps/pd VPERMT2* Shuffle-heavy; can bottleneck on shuffle ports
89
+ * _mm512_reduce_add_ps/pd (sequence) Implemented via shuffles + adds
90
90
  *
91
91
  * Gather-heavy tails are intentionally isolated to keep the steady-state loop on contiguous loads.
92
92
  *
93
93
  * @section references References
94
94
  *
95
- * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/
95
+ * - x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
96
96
  * - Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/
97
97
  *
98
98
  */
@@ -245,6 +245,25 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
245
245
  /** @copydoc nk_umeyama_f64 */
246
246
  NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
247
247
  nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result);
248
+
249
+ /** @copydoc nk_rmsd_f16 */
250
+ NK_PUBLIC void nk_rmsd_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
251
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
252
+ /** @copydoc nk_kabsch_f16 */
253
+ NK_PUBLIC void nk_kabsch_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
254
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
255
+ /** @copydoc nk_umeyama_f16 */
256
+ NK_PUBLIC void nk_umeyama_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
257
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
258
+ /** @copydoc nk_rmsd_bf16 */
259
+ NK_PUBLIC void nk_rmsd_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
260
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
261
+ /** @copydoc nk_kabsch_bf16 */
262
+ NK_PUBLIC void nk_kabsch_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
263
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
264
+ /** @copydoc nk_umeyama_bf16 */
265
+ NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
266
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
248
267
  #endif // NK_TARGET_SKYLAKE
249
268
 
250
269
  /* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer.
@@ -313,21 +332,16 @@ NK_PUBLIC void nk_kabsch_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_
313
332
  /** @copydoc nk_umeyama_f64 */
314
333
  NK_PUBLIC void nk_umeyama_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
315
334
  nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result);
316
- #endif // NK_TARGET_NEON
317
-
318
- /* SIMD-powered backends for Arm NEON FP16 CPUs.
319
- */
320
- #if NK_TARGET_NEONHALF
321
335
  /** @copydoc nk_rmsd_f16 */
322
- NK_PUBLIC void nk_rmsd_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
323
- nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
336
+ NK_PUBLIC void nk_rmsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
337
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
324
338
  /** @copydoc nk_kabsch_f16 */
325
- NK_PUBLIC void nk_kabsch_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
326
- nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
339
+ NK_PUBLIC void nk_kabsch_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
340
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
327
341
  /** @copydoc nk_umeyama_f16 */
328
- NK_PUBLIC void nk_umeyama_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
329
- nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
330
- #endif // NK_TARGET_NEONHALF
342
+ NK_PUBLIC void nk_umeyama_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
343
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
344
+ #endif // NK_TARGET_NEON
331
345
 
332
346
  /* SIMD-powered backends for Arm NEON BF16 CPUs.
333
347
  */
@@ -406,22 +420,10 @@ NK_PUBLIC void nk_umeyama_f64_v128relaxed(nk_f64_t const *a, nk_f64_t const *b,
406
420
  #endif // NK_TARGET_V128RELAXED
407
421
 
408
422
  /**
409
- * @brief Returns the output dtype for RMSD.
410
- */
411
- NK_INTERNAL nk_dtype_t nk_rmsd_output_dtype(nk_dtype_t dtype) {
412
- switch (dtype) {
413
- case nk_f64_k: return nk_f64_k;
414
- case nk_f32_k: return nk_f64_k;
415
- case nk_f16_k: return nk_f32_k;
416
- case nk_bf16_k: return nk_f32_k;
417
- default: return nk_dtype_unknown_k;
418
- }
419
- }
420
-
421
- /**
422
- * @brief Returns the output dtype for Kabsch alignment.
423
+ * @brief Returns the metric output dtype for mesh alignment operations.
424
+ * Matches the C++ `mesh_metric_t` alias in types.hpp.
423
425
  */
424
- NK_INTERNAL nk_dtype_t nk_kabsch_output_dtype(nk_dtype_t dtype) {
426
+ NK_INTERNAL nk_dtype_t nk_mesh_metric_dtype(nk_dtype_t dtype) {
425
427
  switch (dtype) {
426
428
  case nk_f64_k: return nk_f64_k;
427
429
  case nk_f32_k: return nk_f64_k;
@@ -432,12 +434,13 @@ NK_INTERNAL nk_dtype_t nk_kabsch_output_dtype(nk_dtype_t dtype) {
432
434
  }
433
435
 
434
436
  /**
435
- * @brief Returns the output dtype for Umeyama alignment.
437
+ * @brief Returns the transform output dtype for mesh alignment operations.
438
+ * Matches the C++ `mesh_transform_t` alias in types.hpp.
436
439
  */
437
- NK_INTERNAL nk_dtype_t nk_umeyama_output_dtype(nk_dtype_t dtype) {
440
+ NK_INTERNAL nk_dtype_t nk_mesh_transform_dtype(nk_dtype_t dtype) {
438
441
  switch (dtype) {
439
442
  case nk_f64_k: return nk_f64_k;
440
- case nk_f32_k: return nk_f64_k;
443
+ case nk_f32_k: return nk_f32_k;
441
444
  case nk_f16_k: return nk_f32_k;
442
445
  case nk_bf16_k: return nk_f32_k;
443
446
  default: return nk_dtype_unknown_k;
@@ -450,7 +453,6 @@ NK_INTERNAL nk_dtype_t nk_umeyama_output_dtype(nk_dtype_t dtype) {
450
453
 
451
454
  #include "numkong/mesh/serial.h"
452
455
  #include "numkong/mesh/neon.h"
453
- #include "numkong/mesh/neonhalf.h"
454
456
  #include "numkong/mesh/neonbfdot.h"
455
457
  #include "numkong/mesh/haswell.h"
456
458
  #include "numkong/mesh/skylake.h"
@@ -499,10 +501,12 @@ NK_PUBLIC void nk_rmsd_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk
499
501
 
500
502
  NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
501
503
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
502
- #if NK_TARGET_HASWELL
504
+ #if NK_TARGET_SKYLAKE
505
+ nk_rmsd_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
506
+ #elif NK_TARGET_HASWELL
503
507
  nk_rmsd_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
504
- #elif NK_TARGET_NEONHALF
505
- nk_rmsd_f16_neonhalf(a, b, n, a_centroid, b_centroid, rotation, scale, result);
508
+ #elif NK_TARGET_NEON
509
+ nk_rmsd_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
506
510
  #elif NK_TARGET_RVV
507
511
  nk_rmsd_f16_rvv(a, b, n, a_centroid, b_centroid, rotation, scale, result);
508
512
  #else
@@ -512,7 +516,9 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
512
516
 
513
517
  NK_PUBLIC void nk_rmsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
514
518
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
515
- #if NK_TARGET_HASWELL
519
+ #if NK_TARGET_SKYLAKE
520
+ nk_rmsd_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
521
+ #elif NK_TARGET_HASWELL
516
522
  nk_rmsd_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
517
523
  #elif NK_TARGET_NEONBFDOT
518
524
  nk_rmsd_bf16_neonbfdot(a, b, n, a_centroid, b_centroid, rotation, scale, result);
@@ -559,10 +565,12 @@ NK_PUBLIC void nk_kabsch_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
559
565
 
560
566
  NK_PUBLIC void nk_kabsch_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
561
567
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
562
- #if NK_TARGET_HASWELL
568
+ #if NK_TARGET_SKYLAKE
569
+ nk_kabsch_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
570
+ #elif NK_TARGET_HASWELL
563
571
  nk_kabsch_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
564
- #elif NK_TARGET_NEONHALF
565
- nk_kabsch_f16_neonhalf(a, b, n, a_centroid, b_centroid, rotation, scale, result);
572
+ #elif NK_TARGET_NEON
573
+ nk_kabsch_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
566
574
  #elif NK_TARGET_RVV
567
575
  nk_kabsch_f16_rvv(a, b, n, a_centroid, b_centroid, rotation, scale, result);
568
576
  #else
@@ -572,7 +580,9 @@ NK_PUBLIC void nk_kabsch_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
572
580
 
573
581
  NK_PUBLIC void nk_kabsch_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
574
582
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
575
- #if NK_TARGET_HASWELL
583
+ #if NK_TARGET_SKYLAKE
584
+ nk_kabsch_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
585
+ #elif NK_TARGET_HASWELL
576
586
  nk_kabsch_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
577
587
  #elif NK_TARGET_NEONBFDOT
578
588
  nk_kabsch_bf16_neonbfdot(a, b, n, a_centroid, b_centroid, rotation, scale, result);
@@ -619,10 +629,12 @@ NK_PUBLIC void nk_umeyama_f32(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
619
629
 
620
630
  NK_PUBLIC void nk_umeyama_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
621
631
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
622
- #if NK_TARGET_HASWELL
632
+ #if NK_TARGET_SKYLAKE
633
+ nk_umeyama_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
634
+ #elif NK_TARGET_HASWELL
623
635
  nk_umeyama_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
624
- #elif NK_TARGET_NEONHALF
625
- nk_umeyama_f16_neonhalf(a, b, n, a_centroid, b_centroid, rotation, scale, result);
636
+ #elif NK_TARGET_NEON
637
+ nk_umeyama_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
626
638
  #elif NK_TARGET_RVV
627
639
  nk_umeyama_f16_rvv(a, b, n, a_centroid, b_centroid, rotation, scale, result);
628
640
  #else
@@ -632,7 +644,9 @@ NK_PUBLIC void nk_umeyama_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
632
644
 
633
645
  NK_PUBLIC void nk_umeyama_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
634
646
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
635
- #if NK_TARGET_HASWELL
647
+ #if NK_TARGET_SKYLAKE
648
+ nk_umeyama_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
649
+ #elif NK_TARGET_HASWELL
636
650
  nk_umeyama_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
637
651
  #elif NK_TARGET_NEONBFDOT
638
652
  nk_umeyama_bf16_neonbfdot(a, b, n, a_centroid, b_centroid, rotation, scale, result);
@@ -17,7 +17,7 @@
17
17
 
18
18
  namespace ashvardanian::numkong {
19
19
 
20
- #pragma region - SVD Helpers for Scalar Fallbacks
20
+ #pragma region SVD Helpers for Scalar Fallbacks
21
21
 
22
22
  /** @brief 3x3 matrix determinant. */
23
23
  template <typename scalar_type_>
@@ -313,9 +313,9 @@ void svd3x3_(scalar_type_ const *a, scalar_type_ *svd_u, scalar_type_ *svd_s, sc
313
313
  svd_s[8] = s3_sq.sqrt();
314
314
  }
315
315
 
316
- #pragma endregion - SVD Helpers for Scalar Fallbacks
316
+ #pragma endregion SVD Helpers for Scalar Fallbacks
317
317
 
318
- #pragma region - Mesh Alignment Kernels
318
+ #pragma region Mesh Alignment Kernels
319
319
 
320
320
  /**
321
321
  * @brief Root Mean Square Deviation between two 3D point clouds (no alignment)
@@ -755,7 +755,7 @@ void umeyama(in_type_ const *a, in_type_ const *b, std::size_t n, transform_type
755
755
  }
756
756
  }
757
757
 
758
- #pragma endregion - Mesh Alignment Kernels
758
+ #pragma endregion Mesh Alignment Kernels
759
759
 
760
760
  } // namespace ashvardanian::numkong
761
761
 
@@ -62,9 +62,9 @@ NK_PUBLIC nk_dtype_t nk_kernel_output_dtype(nk_kernel_kind_t kind, nk_dtype_t in
62
62
  case nk_kernel_vincenty_k: return nk_vincenty_output_dtype(input);
63
63
  case nk_kernel_kld_k:
64
64
  case nk_kernel_jsd_k: return nk_probability_output_dtype(input);
65
- case nk_kernel_rmsd_k: return nk_rmsd_output_dtype(input);
66
- case nk_kernel_kabsch_k: return nk_kabsch_output_dtype(input);
67
- case nk_kernel_umeyama_k: return nk_umeyama_output_dtype(input);
65
+ case nk_kernel_rmsd_k:
66
+ case nk_kernel_kabsch_k:
67
+ case nk_kernel_umeyama_k: return nk_mesh_metric_dtype(input);
68
68
  case nk_kernel_sparse_dot_k: return nk_sparse_dot_output_dtype(input);
69
69
  case nk_kernel_maxsim_packed_k: return nk_maxsim_output_dtype(input);
70
70
  default: return nk_dtype_unknown_k;
@@ -37,6 +37,7 @@
37
37
  #define NK_NUMKONG_HPP
38
38
 
39
39
  #include "numkong/random.hpp"
40
+ #include "numkong/cast.hpp"
40
41
  #include "numkong/dot.hpp"
41
42
  #include "numkong/spatial.hpp"
42
43
  #include "numkong/spatials.hpp"
@@ -5,17 +5,21 @@ These are used in variational inference, topic modeling, and distribution compar
5
5
 
6
6
  Kullback-Leibler divergence from $P$ to $Q$:
7
7
 
8
- ```math
8
+ $$
9
9
  \text{KLD}(P \| Q) = \sum_{i=0}^{n-1} P(i) \log_2 \frac{P(i)}{Q(i)}
10
- ```
10
+ $$
11
11
 
12
12
  Jensen-Shannon distance is the square root of the symmetrized KLD through a mixture:
13
13
 
14
- $$\text{JSD}(P, Q) = \frac{1}{2} \text{KLD}(P \| M) + \frac{1}{2} \text{KLD}(Q \| M)$$
14
+ $$
15
+ \text{JSD}(P, Q) = \frac{1}{2} \text{KLD}(P \| M) + \frac{1}{2} \text{KLD}(Q \| M)
16
+ $$
15
17
 
16
18
  where $M = \frac{P + Q}{2}$, yielding the distance:
17
19
 
18
- $$d_{JS}(P, Q) = \sqrt{\text{JSD}(P, Q)}$$
20
+ $$
21
+ d_{JS}(P, Q) = \sqrt{\text{JSD}(P, Q)}
22
+ $$
19
23
 
20
24
  Unlike the raw divergence, $d_{JS}$ is a true metric satisfying the triangle inequality.
21
25
 
@@ -35,9 +39,9 @@ def jsd(p: np.ndarray, q: np.ndarray) -> float:
35
39
 
36
40
  ## Use Cases
37
41
 
38
- __Kullback-Leibler divergence__ is the workhorse of variational inference (ELBO objective), knowledge distillation between neural networks, information gain in decision trees, and measuring fit between a model and observed data.
42
+ __Kullback-Leibler divergence__ is widely used in variational inference (ELBO objective), knowledge distillation between neural networks, information gain in decision trees, and measuring fit between a model and observed data.
39
43
 
40
- __Jensen-Shannon distance__ sees primary use in microbiome community comparison (enterotyping), where its metric property enables clustering with standard algorithms. It also appears in distribution drift detection, topic model evaluation, and as the theoretical foundation of the original GAN objective — though in practice GAN training uses proxy losses rather than computing JSD directly.
44
+ __Jensen-Shannon distance__ is commonly used in microbiome community comparison (enterotyping), where its metric property enables clustering with standard algorithms. It also appears in distribution drift detection, topic model evaluation, and as the theoretical foundation of the original GAN objective — though in practice GAN training uses proxy losses rather than computing JSD directly.
41
45
 
42
46
  ## Input & Output Types
43
47
 
@@ -149,25 +153,25 @@ Measured with Wasmtime v42 (Cranelift backend).
149
153
  | `nk_kld_f16_serial` | 0.118 gb/s, 1.04K ulp | 0.127 gb/s, 4.53K ulp | 0.111 gb/s, 18.3K ulp |
150
154
  | `nk_jsd_f16_serial` | 0.0748 gb/s, 1.4 ulp | 0.0681 gb/s, 2.6 ulp | 0.0857 gb/s, 9.7 ulp |
151
155
 
152
- ### Apple M4
156
+ ### Apple M5
153
157
 
154
158
  #### Native
155
159
 
156
160
  | Kernel | 256 | 1024 | 4096 |
157
161
  | :-------------------- | -----------------------: | -----------------------: | -----------------------: |
158
162
  | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
159
- | `nk_kld_f64_serial` | 2.21 gb/s, 5.6K ulp | 2.22 gb/s, 25K ulp | 2.18 gb/s, 99K ulp |
160
- | `nk_jsd_f64_serial` | 1.40 gb/s, 0.4 ulp | 1.45 gb/s, 0.4 ulp | 1.45 gb/s, 0.5 ulp |
163
+ | `nk_kld_f64_serial` | 3.22 gb/s, 5.6K ulp | 3.36 gb/s, 25K ulp | 3.32 gb/s, 99K ulp |
164
+ | `nk_jsd_f64_serial` | 2.06 gb/s, 0.4 ulp | 2.17 gb/s, 0.4 ulp | 2.17 gb/s, 0.5 ulp |
161
165
  | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
162
- | `nk_kld_f32_serial` | 6.29 gb/s, 1.0K ulp | 6.35 gb/s, 4.5K ulp | 6.22 gb/s, 18K ulp |
163
- | `nk_jsd_f32_serial` | 1.21 gb/s, 0.4 ulp | 1.20 gb/s, 0.4 ulp | 1.20 gb/s, 4.6 ulp |
164
- | `nk_kld_f32_neon` | 14.5 gb/s, 1.0K ulp | 14.4 gb/s, 4.5K ulp | 12.8 gb/s, 18K ulp |
165
- | `nk_jsd_f32_neon` | 6.81 gb/s, 15 ulp | 7.04 gb/s, 14 ulp | 6.78 gb/s, 9.9 ulp |
166
+ | `nk_kld_f32_serial` | 9.26 gb/s, 1.0K ulp | 8.73 gb/s, 4.5K ulp | 9.10 gb/s, 18K ulp |
167
+ | `nk_jsd_f32_serial` | 2.08 gb/s, 0.4 ulp | 2.16 gb/s, 0.4 ulp | 2.13 gb/s, 4.6 ulp |
168
+ | `nk_kld_f32_neon` | 19.0 gb/s, 1.0K ulp | 17.4 gb/s, 4.5K ulp | 18.1 gb/s, 18K ulp |
169
+ | `nk_jsd_f32_neon` | 9.75 gb/s, 15 ulp | 9.32 gb/s, 14 ulp | 9.62 gb/s, 9.9 ulp |
166
170
  | __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
167
- | `nk_kld_bf16_serial` | 3.16 gb/s, 1.0K ulp | 2.96 gb/s, 4.5K ulp | 3.16 gb/s, 18K ulp |
168
- | `nk_jsd_bf16_serial` | 0.611 gb/s, 1.4 ulp | 0.595 gb/s, 2.9 ulp | 0.613 gb/s, 9.7 ulp |
171
+ | `nk_kld_bf16_serial` | 4.58 gb/s, 1.0K ulp | 4.47 gb/s, 4.5K ulp | 4.65 gb/s, 18K ulp |
172
+ | `nk_jsd_bf16_serial` | 1.08 gb/s, 1.4 ulp | 1.07 gb/s, 2.9 ulp | 1.09 gb/s, 9.7 ulp |
169
173
  | __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
170
- | `nk_kld_f16_serial` | 3.15 gb/s, 1.0K ulp | 3.14 gb/s, 4.5K ulp | 2.81 gb/s, 18K ulp |
171
- | `nk_jsd_f16_serial` | 0.610 gb/s, 1.4 ulp | 0.611 gb/s, 2.7 ulp | 0.602 gb/s, 8.7 ulp |
172
- | `nk_kld_f16_neonhalf` | 6.78 gb/s, 1.0K ulp | 6.72 gb/s, 4.5K ulp | 6.09 gb/s, 18K ulp |
173
- | `nk_jsd_f16_neonhalf` | 3.42 gb/s, 15 ulp | 3.40 gb/s, 14 ulp | 3.14 gb/s, 9.9 ulp |
174
+ | `nk_kld_f16_serial` | 4.63 gb/s, 1.0K ulp | 4.45 gb/s, 4.5K ulp | 4.55 gb/s, 18K ulp |
175
+ | `nk_jsd_f16_serial` | 1.03 gb/s, 1.4 ulp | 0.962 gb/s, 2.7 ulp | 0.976 gb/s, 8.7 ulp |
176
+ | `nk_kld_f16_neonhalf` | 10.2 gb/s, 1.0K ulp | 9.67 gb/s, 4.5K ulp | 9.99 gb/s, 18K ulp |
177
+ | `nk_jsd_f16_neonhalf` | 5.00 gb/s, 15 ulp | 4.79 gb/s, 14 ulp | 4.94 gb/s, 9.9 ulp |
@@ -57,8 +57,8 @@ NK_PUBLIC float32x4_t nk_log2_f32x4_neon_(float32x4_t x) {
57
57
  NK_PUBLIC void nk_kld_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
58
58
  nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
59
59
  float32x4_t epsilon_f32x4 = vdupq_n_f32(epsilon);
60
- float64x2_t sum_lower_f64x2 = vdupq_n_f64(0.0);
61
- float64x2_t sum_upper_f64x2 = vdupq_n_f64(0.0);
60
+ float64x2_t sum_low_f64x2 = vdupq_n_f64(0.0);
61
+ float64x2_t sum_high_f64x2 = vdupq_n_f64(0.0);
62
62
  float32x4_t a_f32x4, b_f32x4;
63
63
 
64
64
  nk_kld_f32_neon_cycle:
@@ -79,20 +79,20 @@ nk_kld_f32_neon_cycle:
79
79
  float32x4_t ratio_f32x4 = vdivq_f32(vaddq_f32(a_f32x4, epsilon_f32x4), vaddq_f32(b_f32x4, epsilon_f32x4));
80
80
  float32x4_t log_ratio_f32x4 = nk_log2_f32x4_neon_(ratio_f32x4);
81
81
  float32x4_t contribution_f32x4 = vmulq_f32(a_f32x4, log_ratio_f32x4);
82
- sum_lower_f64x2 = vaddq_f64(sum_lower_f64x2, vcvt_f64_f32(vget_low_f32(contribution_f32x4)));
83
- sum_upper_f64x2 = vaddq_f64(sum_upper_f64x2, vcvt_f64_f32(vget_high_f32(contribution_f32x4)));
82
+ sum_low_f64x2 = vaddq_f64(sum_low_f64x2, vcvt_f64_f32(vget_low_f32(contribution_f32x4)));
83
+ sum_high_f64x2 = vaddq_f64(sum_high_f64x2, vcvt_high_f64_f32(contribution_f32x4));
84
84
  if (n != 0) goto nk_kld_f32_neon_cycle;
85
85
 
86
86
  nk_f64_t log2_normalizer = 0.6931471805599453;
87
- nk_f64_t sum = vaddvq_f64(vaddq_f64(sum_lower_f64x2, sum_upper_f64x2)) * log2_normalizer;
87
+ nk_f64_t sum = vaddvq_f64(vaddq_f64(sum_low_f64x2, sum_high_f64x2)) * log2_normalizer;
88
88
  *result = sum;
89
89
  }
90
90
 
91
91
  NK_PUBLIC void nk_jsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
92
92
  nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
93
93
  float32x4_t epsilon_f32x4 = vdupq_n_f32(epsilon);
94
- float64x2_t sum_lower_f64x2 = vdupq_n_f64(0.0);
95
- float64x2_t sum_upper_f64x2 = vdupq_n_f64(0.0);
94
+ float64x2_t sum_low_f64x2 = vdupq_n_f64(0.0);
95
+ float64x2_t sum_high_f64x2 = vdupq_n_f64(0.0);
96
96
  float32x4_t a_f32x4, b_f32x4;
97
97
 
98
98
  nk_jsd_f32_neon_cycle:
@@ -118,12 +118,12 @@ nk_jsd_f32_neon_cycle:
118
118
  float32x4_t contribution_a_f32x4 = vmulq_f32(a_f32x4, log_ratio_a_f32x4);
119
119
  float32x4_t contribution_b_f32x4 = vmulq_f32(b_f32x4, log_ratio_b_f32x4);
120
120
  float32x4_t contribution_f32x4 = vaddq_f32(contribution_a_f32x4, contribution_b_f32x4);
121
- sum_lower_f64x2 = vaddq_f64(sum_lower_f64x2, vcvt_f64_f32(vget_low_f32(contribution_f32x4)));
122
- sum_upper_f64x2 = vaddq_f64(sum_upper_f64x2, vcvt_f64_f32(vget_high_f32(contribution_f32x4)));
121
+ sum_low_f64x2 = vaddq_f64(sum_low_f64x2, vcvt_f64_f32(vget_low_f32(contribution_f32x4)));
122
+ sum_high_f64x2 = vaddq_f64(sum_high_f64x2, vcvt_high_f64_f32(contribution_f32x4));
123
123
  if (n != 0) goto nk_jsd_f32_neon_cycle;
124
124
 
125
125
  nk_f64_t log2_normalizer = 0.6931471805599453;
126
- nk_f64_t sum = vaddvq_f64(vaddq_f64(sum_lower_f64x2, sum_upper_f64x2)) * log2_normalizer / 2.0;
126
+ nk_f64_t sum = vaddvq_f64(vaddq_f64(sum_low_f64x2, sum_high_f64x2)) * log2_normalizer / 2.0;
127
127
  *result = sum > 0 ? nk_f64_sqrt_neon(sum) : 0;
128
128
  }
129
129
 
@@ -134,76 +134,106 @@ nk_jsd_f32_neon_cycle:
134
134
  #endif
135
135
  #endif // NK_TARGET_NEON
136
136
 
137
- #if NK_TARGET_NEONHALF
137
+ #if NK_TARGET_NEON
138
138
  #if defined(__clang__)
139
- #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
139
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd"))), apply_to = function)
140
140
  #elif defined(__GNUC__)
141
141
  #pragma GCC push_options
142
- #pragma GCC target("arch=armv8.2-a+simd+fp16")
142
+ #pragma GCC target("arch=armv8.2-a+simd")
143
143
  #endif
144
144
 
145
- NK_PUBLIC void nk_kld_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
145
+ NK_PUBLIC void nk_kld_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
146
146
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
147
147
  nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
148
148
  float32x4_t epsilon_f32x4 = vdupq_n_f32(epsilon);
149
- float32x4_t a_f32x4, b_f32x4;
149
+ float32x4_t a_low_f32x4, a_high_f32x4, b_low_f32x4, b_high_f32x4;
150
150
 
151
- nk_kld_f16_neonhalf_cycle:
152
- if (n < 4) {
153
- nk_b64_vec_t a_vec, b_vec;
154
- nk_partial_load_b16x4_serial_(a, &a_vec, n);
155
- nk_partial_load_b16x4_serial_(b, &b_vec, n);
156
- a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
157
- b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
151
+ nk_kld_f16_neon_cycle:
152
+ if (n < 8) {
153
+ nk_b128_vec_t a_vec, b_vec;
154
+ nk_partial_load_b16x8_serial_(a, &a_vec, n);
155
+ nk_partial_load_b16x8_serial_(b, &b_vec, n);
156
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_vec.u16x8);
157
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_vec.u16x8);
158
+ a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
159
+ a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
160
+ b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
161
+ b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
158
162
  n = 0;
159
163
  }
160
164
  else {
161
- a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
162
- b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
163
- n -= 4, a += 4, b += 4;
165
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)a));
166
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)b));
167
+ a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
168
+ a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
169
+ b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
170
+ b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
171
+ n -= 8, a += 8, b += 8;
164
172
  }
165
173
 
166
- float32x4_t ratio_f32x4 = vdivq_f32(vaddq_f32(a_f32x4, epsilon_f32x4), vaddq_f32(b_f32x4, epsilon_f32x4));
167
- float32x4_t log_ratio_f32x4 = nk_log2_f32x4_neon_(ratio_f32x4);
168
- float32x4_t contribution_f32x4 = vmulq_f32(a_f32x4, log_ratio_f32x4);
169
- sum_f32x4 = vaddq_f32(sum_f32x4, contribution_f32x4);
170
- if (n) goto nk_kld_f16_neonhalf_cycle;
174
+ float32x4_t ratio_low_f32x4 = vdivq_f32(vaddq_f32(a_low_f32x4, epsilon_f32x4),
175
+ vaddq_f32(b_low_f32x4, epsilon_f32x4));
176
+ float32x4_t ratio_high_f32x4 = vdivq_f32(vaddq_f32(a_high_f32x4, epsilon_f32x4),
177
+ vaddq_f32(b_high_f32x4, epsilon_f32x4));
178
+ float32x4_t log_ratio_low_f32x4 = nk_log2_f32x4_neon_(ratio_low_f32x4);
179
+ float32x4_t log_ratio_high_f32x4 = nk_log2_f32x4_neon_(ratio_high_f32x4);
180
+ sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, log_ratio_low_f32x4);
181
+ sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, log_ratio_high_f32x4);
182
+ if (n) goto nk_kld_f16_neon_cycle;
171
183
 
172
184
  nk_f32_t log2_normalizer = 0.693147181f;
173
185
  nk_f32_t sum = vaddvq_f32(sum_f32x4) * log2_normalizer;
174
186
  *result = sum;
175
187
  }
176
188
 
177
- NK_PUBLIC void nk_jsd_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
189
+ NK_PUBLIC void nk_jsd_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
178
190
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
179
191
  nk_f32_t epsilon = NK_F32_DIVISION_EPSILON;
180
192
  float32x4_t epsilon_f32x4 = vdupq_n_f32(epsilon);
181
- float32x4_t a_f32x4, b_f32x4;
193
+ float32x4_t a_low_f32x4, a_high_f32x4, b_low_f32x4, b_high_f32x4;
182
194
 
183
- nk_jsd_f16_neonhalf_cycle:
184
- if (n < 4) {
185
- nk_b64_vec_t a_vec, b_vec;
186
- nk_partial_load_b16x4_serial_(a, &a_vec, n);
187
- nk_partial_load_b16x4_serial_(b, &b_vec, n);
188
- a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
189
- b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
195
+ nk_jsd_f16_neon_cycle:
196
+ if (n < 8) {
197
+ nk_b128_vec_t a_vec, b_vec;
198
+ nk_partial_load_b16x8_serial_(a, &a_vec, n);
199
+ nk_partial_load_b16x8_serial_(b, &b_vec, n);
200
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_vec.u16x8);
201
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_vec.u16x8);
202
+ a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
203
+ a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
204
+ b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
205
+ b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
190
206
  n = 0;
191
207
  }
192
208
  else {
193
- a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
194
- b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
195
- n -= 4, a += 4, b += 4;
209
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)a));
210
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(vld1q_u16((nk_u16_t const *)b));
211
+ a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
212
+ a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
213
+ b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
214
+ b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
215
+ n -= 8, a += 8, b += 8;
196
216
  }
197
217
 
198
- float32x4_t mean_f32x4 = vmulq_n_f32(vaddq_f32(a_f32x4, b_f32x4), 0.5f);
199
- float32x4_t ratio_a_f32x4 = vdivq_f32(vaddq_f32(a_f32x4, epsilon_f32x4), vaddq_f32(mean_f32x4, epsilon_f32x4));
200
- float32x4_t ratio_b_f32x4 = vdivq_f32(vaddq_f32(b_f32x4, epsilon_f32x4), vaddq_f32(mean_f32x4, epsilon_f32x4));
201
- float32x4_t log_ratio_a_f32x4 = nk_log2_f32x4_neon_(ratio_a_f32x4);
202
- float32x4_t log_ratio_b_f32x4 = nk_log2_f32x4_neon_(ratio_b_f32x4);
203
- float32x4_t contribution_a_f32x4 = vmulq_f32(a_f32x4, log_ratio_a_f32x4);
204
- float32x4_t contribution_b_f32x4 = vmulq_f32(b_f32x4, log_ratio_b_f32x4);
205
- sum_f32x4 = vaddq_f32(sum_f32x4, vaddq_f32(contribution_a_f32x4, contribution_b_f32x4));
206
- if (n) goto nk_jsd_f16_neonhalf_cycle;
218
+ float32x4_t mean_low_f32x4 = vmulq_n_f32(vaddq_f32(a_low_f32x4, b_low_f32x4), 0.5f);
219
+ float32x4_t mean_high_f32x4 = vmulq_n_f32(vaddq_f32(a_high_f32x4, b_high_f32x4), 0.5f);
220
+ float32x4_t ratio_a_low_f32x4 = vdivq_f32(vaddq_f32(a_low_f32x4, epsilon_f32x4),
221
+ vaddq_f32(mean_low_f32x4, epsilon_f32x4));
222
+ float32x4_t ratio_a_high_f32x4 = vdivq_f32(vaddq_f32(a_high_f32x4, epsilon_f32x4),
223
+ vaddq_f32(mean_high_f32x4, epsilon_f32x4));
224
+ float32x4_t ratio_b_low_f32x4 = vdivq_f32(vaddq_f32(b_low_f32x4, epsilon_f32x4),
225
+ vaddq_f32(mean_low_f32x4, epsilon_f32x4));
226
+ float32x4_t ratio_b_high_f32x4 = vdivq_f32(vaddq_f32(b_high_f32x4, epsilon_f32x4),
227
+ vaddq_f32(mean_high_f32x4, epsilon_f32x4));
228
+ float32x4_t log_ratio_a_low_f32x4 = nk_log2_f32x4_neon_(ratio_a_low_f32x4);
229
+ float32x4_t log_ratio_a_high_f32x4 = nk_log2_f32x4_neon_(ratio_a_high_f32x4);
230
+ float32x4_t log_ratio_b_low_f32x4 = nk_log2_f32x4_neon_(ratio_b_low_f32x4);
231
+ float32x4_t log_ratio_b_high_f32x4 = nk_log2_f32x4_neon_(ratio_b_high_f32x4);
232
+ sum_f32x4 = vfmaq_f32(sum_f32x4, a_low_f32x4, log_ratio_a_low_f32x4);
233
+ sum_f32x4 = vfmaq_f32(sum_f32x4, a_high_f32x4, log_ratio_a_high_f32x4);
234
+ sum_f32x4 = vfmaq_f32(sum_f32x4, b_low_f32x4, log_ratio_b_low_f32x4);
235
+ sum_f32x4 = vfmaq_f32(sum_f32x4, b_high_f32x4, log_ratio_b_high_f32x4);
236
+ if (n) goto nk_jsd_f16_neon_cycle;
207
237
 
208
238
  nk_f32_t log2_normalizer = 0.693147181f;
209
239
  nk_f32_t sum = vaddvq_f32(sum_f32x4) * log2_normalizer / 2;
@@ -215,7 +245,7 @@ nk_jsd_f16_neonhalf_cycle:
215
245
  #elif defined(__GNUC__)
216
246
  #pragma GCC pop_options
217
247
  #endif
218
- #endif // NK_TARGET_NEONHALF
248
+ #endif // NK_TARGET_NEON
219
249
 
220
250
  #if defined(__cplusplus)
221
251
  } // extern "C"