numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -10,15 +10,14 @@
10
10
  *
11
11
  * ARM NEON instructions for distance computations:
12
12
  *
13
- * Intrinsic Instruction Latency Throughput
14
- * A76 M4+/V1+/Oryon
15
- * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
16
- * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
17
- * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
18
- * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
19
- * vrsqrteq_f32 FRSQRTE (V.4S, V.4S) 2cy 2/cy 2/cy
20
- * vsqrtq_f32 FSQRT (V.4S, V.4S) 9-12cy 0.25/cy 0.25/cy
21
- * vrecpeq_f32 FRECPE (V.4S, V.4S) 2cy 2/cy 2/cy
13
+ * Intrinsic Instruction A76 M5
14
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy @ 2p 3cy @ 4p
15
+ * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy @ 2p 3cy @ 4p
16
+ * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
17
+ * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
18
+ * vrsqrteq_f32 FRSQRTE (V.4S, V.4S) 2cy @ 2p 3cy @ 1p
19
+ * vsqrtq_f32 FSQRT (V.4S, V.4S) 12cy @ 1p 9cy @ 1p
20
+ * vrecpeq_f32 FRECPE (V.4S, V.4S) 2cy @ 2p 3cy @ 1p
22
21
  *
23
22
  * FRSQRTE provides ~8-bit precision; two Newton-Raphson iterations via vrsqrtsq_f32 achieve
24
23
  * ~23-bit precision, sufficient for f32. This is much faster than FSQRT (0.25/cy).
@@ -55,10 +54,10 @@ extern "C" {
55
54
  * Much faster than `vsqrtq_f32` (2 cy vs 9-12 cy latency, 2/cy vs 0.25/cy throughput).
56
55
  */
57
56
  NK_INTERNAL float32x4_t nk_rsqrt_f32x4_neon_(float32x4_t x) {
58
- float32x4_t rsqrt = vrsqrteq_f32(x);
59
- rsqrt = vmulq_f32(rsqrt, vrsqrtsq_f32(vmulq_f32(x, rsqrt), rsqrt));
60
- rsqrt = vmulq_f32(rsqrt, vrsqrtsq_f32(vmulq_f32(x, rsqrt), rsqrt));
61
- return rsqrt;
57
+ float32x4_t rsqrt_f32x4 = vrsqrteq_f32(x);
58
+ rsqrt_f32x4 = vmulq_f32(rsqrt_f32x4, vrsqrtsq_f32(vmulq_f32(x, rsqrt_f32x4), rsqrt_f32x4));
59
+ rsqrt_f32x4 = vmulq_f32(rsqrt_f32x4, vrsqrtsq_f32(vmulq_f32(x, rsqrt_f32x4), rsqrt_f32x4));
60
+ return rsqrt_f32x4;
62
61
  }
63
62
 
64
63
  /**
@@ -70,29 +69,29 @@ NK_INTERNAL float32x4_t nk_rsqrt_f32x4_neon_(float32x4_t x) {
70
69
  * prefer `vsqrtq_f64` instead.
71
70
  */
72
71
  NK_INTERNAL float64x2_t nk_rsqrt_f64x2_neon_(float64x2_t x) {
73
- float64x2_t rsqrt = vrsqrteq_f64(x);
74
- rsqrt = vmulq_f64(rsqrt, vrsqrtsq_f64(vmulq_f64(x, rsqrt), rsqrt));
75
- rsqrt = vmulq_f64(rsqrt, vrsqrtsq_f64(vmulq_f64(x, rsqrt), rsqrt));
76
- rsqrt = vmulq_f64(rsqrt, vrsqrtsq_f64(vmulq_f64(x, rsqrt), rsqrt));
77
- return rsqrt;
72
+ float64x2_t rsqrt_f64x2 = vrsqrteq_f64(x);
73
+ rsqrt_f64x2 = vmulq_f64(rsqrt_f64x2, vrsqrtsq_f64(vmulq_f64(x, rsqrt_f64x2), rsqrt_f64x2));
74
+ rsqrt_f64x2 = vmulq_f64(rsqrt_f64x2, vrsqrtsq_f64(vmulq_f64(x, rsqrt_f64x2), rsqrt_f64x2));
75
+ rsqrt_f64x2 = vmulq_f64(rsqrt_f64x2, vrsqrtsq_f64(vmulq_f64(x, rsqrt_f64x2), rsqrt_f64x2));
76
+ return rsqrt_f64x2;
78
77
  }
79
78
 
80
79
  NK_INTERNAL nk_f32_t nk_angular_normalize_f32_neon_(nk_f32_t ab, nk_f32_t a2, nk_f32_t b2) {
81
80
  if (a2 == 0 && b2 == 0) return 0;
82
81
  if (ab == 0) return 1;
83
82
  nk_f32_t squares_arr[2] = {a2, b2};
84
- float32x2_t squares = vld1_f32(squares_arr);
83
+ float32x2_t squares_f32x2 = vld1_f32(squares_arr);
85
84
  // Unlike x86, Arm NEON manuals don't explicitly mention the accuracy of their `rsqrt` approximation.
86
85
  // Third-party research suggests that it's less accurate than SSE instructions, having an error of 1.5×2⁻¹².
87
86
  // One or two rounds of Newton-Raphson refinement are recommended to improve the accuracy.
88
87
  // https://github.com/lighttransport/embree-aarch64/issues/24
89
88
  // https://github.com/lighttransport/embree-aarch64/blob/3f75f8cb4e553d13dced941b5fefd4c826835a6b/common/math/math.h#L137-L145
90
- float32x2_t rsqrts = vrsqrte_f32(squares);
89
+ float32x2_t rsqrts_f32x2 = vrsqrte_f32(squares_f32x2);
91
90
  // Perform two rounds of Newton-Raphson refinement:
92
91
  // https://en.wikipedia.org/wiki/Newton%27s_method
93
- rsqrts = vmul_f32(rsqrts, vrsqrts_f32(vmul_f32(squares, rsqrts), rsqrts));
94
- rsqrts = vmul_f32(rsqrts, vrsqrts_f32(vmul_f32(squares, rsqrts), rsqrts));
95
- vst1_f32(squares_arr, rsqrts);
92
+ rsqrts_f32x2 = vmul_f32(rsqrts_f32x2, vrsqrts_f32(vmul_f32(squares_f32x2, rsqrts_f32x2), rsqrts_f32x2));
93
+ rsqrts_f32x2 = vmul_f32(rsqrts_f32x2, vrsqrts_f32(vmul_f32(squares_f32x2, rsqrts_f32x2), rsqrts_f32x2));
94
+ vst1_f32(squares_arr, rsqrts_f32x2);
96
95
  nk_f32_t result = 1 - ab * squares_arr[0] * squares_arr[1];
97
96
  return result > 0 ? result : 0;
98
97
  }
@@ -101,25 +100,25 @@ NK_INTERNAL nk_f64_t nk_angular_normalize_f64_neon_(nk_f64_t ab, nk_f64_t a2, nk
101
100
  if (a2 == 0 && b2 == 0) return 0;
102
101
  if (ab == 0) return 1;
103
102
  nk_f64_t squares_arr[2] = {a2, b2};
104
- float64x2_t squares = vld1q_f64(squares_arr);
103
+ float64x2_t squares_f64x2 = vld1q_f64(squares_arr);
105
104
 
106
105
  // Unlike x86, Arm NEON manuals don't explicitly mention the accuracy of their `rsqrt` approximation.
107
106
  // Third-party research suggests that it's less accurate than SSE instructions, having an error of 1.5×2⁻¹².
108
107
  // One or two rounds of Newton-Raphson refinement are recommended to improve the accuracy.
109
108
  // https://github.com/lighttransport/embree-aarch64/issues/24
110
109
  // https://github.com/lighttransport/embree-aarch64/blob/3f75f8cb4e553d13dced941b5fefd4c826835a6b/common/math/math.h#L137-L145
111
- float64x2_t rsqrts_f64x2 = vrsqrteq_f64(squares);
110
+ float64x2_t rsqrts_f64x2 = vrsqrteq_f64(squares_f64x2);
112
111
  // Perform three rounds of Newton-Raphson refinement for f64 precision (~48 bits):
113
112
  // https://en.wikipedia.org/wiki/Newton%27s_method
114
- rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares, rsqrts_f64x2), rsqrts_f64x2));
115
- rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares, rsqrts_f64x2), rsqrts_f64x2));
116
- rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares, rsqrts_f64x2), rsqrts_f64x2));
113
+ rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares_f64x2, rsqrts_f64x2), rsqrts_f64x2));
114
+ rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares_f64x2, rsqrts_f64x2), rsqrts_f64x2));
115
+ rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares_f64x2, rsqrts_f64x2), rsqrts_f64x2));
117
116
  vst1q_f64(squares_arr, rsqrts_f64x2);
118
117
  nk_f64_t result = 1 - ab * squares_arr[0] * squares_arr[1];
119
118
  return result > 0 ? result : 0;
120
119
  }
121
120
 
122
- #pragma region - Traditional Floats
121
+ #pragma region F32 and F64 Floats
123
122
 
124
123
  NK_PUBLIC void nk_sqeuclidean_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
125
124
  // Accumulate in f64 for numerical stability (2 f32s per iteration, avoids slow vget_low/high)
@@ -243,8 +242,8 @@ nk_angular_f64_neon_cycle:
243
242
  nk_dot_stable_sum_f64x2_neon_(ab_sum_f64x2, ab_compensation_f64x2), vaddvq_f64(a2_f64x2), vaddvq_f64(b2_f64x2));
244
243
  }
245
244
 
246
- #pragma endregion - Traditional Floats
247
- #pragma region - Smaller Floats
245
+ #pragma endregion F32 and F64 Floats
246
+ #pragma region F16 and BF16 Floats
248
247
 
249
248
  NK_PUBLIC void nk_sqeuclidean_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
250
249
  uint16x8_t a_u16x8, b_u16x8;
@@ -264,9 +263,9 @@ nk_sqeuclidean_bf16_neon_cycle:
264
263
  a += 8, b += 8, n -= 8;
265
264
  }
266
265
  float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
267
- float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a_u16x8), 16));
266
+ float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(a_u16x8, 16));
268
267
  float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
269
- float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b_u16x8), 16));
268
+ float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(b_u16x8, 16));
270
269
  float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
271
270
  float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
272
271
  sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
@@ -300,9 +299,9 @@ nk_angular_bf16_neon_cycle:
300
299
  a += 8, b += 8, n -= 8;
301
300
  }
302
301
  float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
303
- float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a_u16x8), 16));
302
+ float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(a_u16x8, 16));
304
303
  float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
305
- float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b_u16x8), 16));
304
+ float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_high_n_u16(b_u16x8, 16));
306
305
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
307
306
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
308
307
  a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
@@ -316,6 +315,80 @@ nk_angular_bf16_neon_cycle:
316
315
  *result = nk_angular_normalize_f32_neon_(ab, a2, b2);
317
316
  }
318
317
 
318
+ NK_PUBLIC void nk_sqeuclidean_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
319
+ uint16x8_t a_u16x8, b_u16x8;
320
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
321
+ nk_sqeuclidean_f16_neon_cycle:
322
+ if (n < 8) {
323
+ nk_b128_vec_t a_vec, b_vec;
324
+ nk_partial_load_b16x8_serial_(a, &a_vec, n);
325
+ nk_partial_load_b16x8_serial_(b, &b_vec, n);
326
+ a_u16x8 = a_vec.u16x8;
327
+ b_u16x8 = b_vec.u16x8;
328
+ n = 0;
329
+ }
330
+ else {
331
+ a_u16x8 = vld1q_u16((nk_u16_t const *)a);
332
+ b_u16x8 = vld1q_u16((nk_u16_t const *)b);
333
+ a += 8, b += 8, n -= 8;
334
+ }
335
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_u16x8);
336
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_u16x8);
337
+ float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
338
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
339
+ float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
340
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
341
+ float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
342
+ float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
343
+ sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
344
+ sum_f32x4 = vfmaq_f32(sum_f32x4, diff_high_f32x4, diff_high_f32x4);
345
+ if (n) goto nk_sqeuclidean_f16_neon_cycle;
346
+ *result = vaddvq_f32(sum_f32x4);
347
+ }
348
+
349
+ NK_PUBLIC void nk_euclidean_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
350
+ nk_sqeuclidean_f16_neon(a, b, n, result);
351
+ *result = nk_f32_sqrt_neon(*result);
352
+ }
353
+
354
+ NK_PUBLIC void nk_angular_f16_neon(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
355
+ uint16x8_t a_u16x8, b_u16x8;
356
+ float32x4_t ab_f32x4 = vdupq_n_f32(0);
357
+ float32x4_t a2_f32x4 = vdupq_n_f32(0);
358
+ float32x4_t b2_f32x4 = vdupq_n_f32(0);
359
+ nk_angular_f16_neon_cycle:
360
+ if (n < 8) {
361
+ nk_b128_vec_t a_vec, b_vec;
362
+ nk_partial_load_b16x8_serial_(a, &a_vec, n);
363
+ nk_partial_load_b16x8_serial_(b, &b_vec, n);
364
+ a_u16x8 = a_vec.u16x8;
365
+ b_u16x8 = b_vec.u16x8;
366
+ n = 0;
367
+ }
368
+ else {
369
+ a_u16x8 = vld1q_u16((nk_u16_t const *)a);
370
+ b_u16x8 = vld1q_u16((nk_u16_t const *)b);
371
+ a += 8, b += 8, n -= 8;
372
+ }
373
+ float16x8_t a_f16x8 = vreinterpretq_f16_u16(a_u16x8);
374
+ float16x8_t b_f16x8 = vreinterpretq_f16_u16(b_u16x8);
375
+ float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
376
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
377
+ float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
378
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
379
+ ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
380
+ ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
381
+ a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
382
+ a2_f32x4 = vfmaq_f32(a2_f32x4, a_high_f32x4, a_high_f32x4);
383
+ b2_f32x4 = vfmaq_f32(b2_f32x4, b_low_f32x4, b_low_f32x4);
384
+ b2_f32x4 = vfmaq_f32(b2_f32x4, b_high_f32x4, b_high_f32x4);
385
+ if (n) goto nk_angular_f16_neon_cycle;
386
+ nk_f32_t ab = vaddvq_f32(ab_f32x4);
387
+ nk_f32_t a2 = vaddvq_f32(a2_f32x4);
388
+ nk_f32_t b2 = vaddvq_f32(b2_f32x4);
389
+ *result = nk_angular_normalize_f32_neon_(ab, a2, b2);
390
+ }
391
+
319
392
  NK_PUBLIC void nk_sqeuclidean_e2m3_neon(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
320
393
  float16x8_t a_f16x8, b_f16x8;
321
394
  float32x4_t sum_f32x4 = vdupq_n_f32(0);
@@ -334,9 +407,9 @@ nk_sqeuclidean_e2m3_neon_cycle:
334
407
  a += 8, b += 8, n -= 8;
335
408
  }
336
409
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
337
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
410
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
338
411
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
339
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
412
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
340
413
  float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
341
414
  float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
342
415
  sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
@@ -370,9 +443,9 @@ nk_angular_e2m3_neon_cycle:
370
443
  a += 8, b += 8, n -= 8;
371
444
  }
372
445
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
373
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
446
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
374
447
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
375
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
448
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
376
449
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
377
450
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
378
451
  a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
@@ -404,9 +477,9 @@ nk_sqeuclidean_e3m2_neon_cycle:
404
477
  a += 8, b += 8, n -= 8;
405
478
  }
406
479
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
407
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
480
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
408
481
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
409
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
482
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
410
483
  float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
411
484
  float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
412
485
  sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
@@ -440,9 +513,9 @@ nk_angular_e3m2_neon_cycle:
440
513
  a += 8, b += 8, n -= 8;
441
514
  }
442
515
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
443
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
516
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
444
517
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
445
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
518
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
446
519
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
447
520
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
448
521
  a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
@@ -474,9 +547,9 @@ nk_sqeuclidean_e4m3_neon_cycle:
474
547
  a += 8, b += 8, n -= 8;
475
548
  }
476
549
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
477
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
550
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
478
551
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
479
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
552
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
480
553
  float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
481
554
  float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
482
555
  sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
@@ -510,9 +583,9 @@ nk_angular_e4m3_neon_cycle:
510
583
  a += 8, b += 8, n -= 8;
511
584
  }
512
585
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
513
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
586
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
514
587
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
515
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
588
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
516
589
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
517
590
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
518
591
  a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
@@ -544,9 +617,9 @@ nk_sqeuclidean_e5m2_neon_cycle:
544
617
  a += 8, b += 8, n -= 8;
545
618
  }
546
619
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
547
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
620
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
548
621
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
549
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
622
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
550
623
  float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
551
624
  float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
552
625
  sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
@@ -580,9 +653,9 @@ nk_angular_e5m2_neon_cycle:
580
653
  a += 8, b += 8, n -= 8;
581
654
  }
582
655
  float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
583
- float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
656
+ float32x4_t a_high_f32x4 = vcvt_high_f32_f16(a_f16x8);
584
657
  float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
585
- float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
658
+ float32x4_t b_high_f32x4 = vcvt_high_f32_f16(b_f16x8);
586
659
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
587
660
  ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
588
661
  a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
@@ -767,7 +840,7 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_neon_(nk_b128_vec_t dots, nk_
767
840
  } // extern "C"
768
841
  #endif
769
842
 
770
- #pragma endregion - Smaller Floats
843
+ #pragma endregion F16 and BF16 Floats
771
844
  #endif // NK_TARGET_NEON
772
845
  #endif // NK_TARGET_ARM_
773
846
  #endif // NK_SPATIAL_NEON_H
@@ -8,15 +8,14 @@
8
8
  *
9
9
  * @section spatial_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy 2/cy 4/cy
14
- * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy 2/cy 4/cy
15
- * vld1q_bf16 LD1 (V.8H) 4cy 2/cy 3/cy
16
- * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
17
- * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy 2/cy 4/cy
18
- * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
19
- * vaddvq_f64 FADDP (V.2D) 3cy 1/cy 2/cy
11
+ * Intrinsic Instruction A76 M5
12
+ * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy @ 2p 2cy @ 1p
13
+ * vcvt_f32_bf16 BFCVTN (V.4H, V.4S) 3cy @ 2p 3cy @ 4p
14
+ * vld1q_bf16 LD1 (V.8H) 4cy @ 2p 4cy @ 3p
15
+ * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy @ 2p 2cy @ 4p
16
+ * vfmaq_f64 FMLA (V.2D, V.2D, V.2D) 4cy @ 2p 3cy @ 4p
17
+ * vaddvq_f32 FADDP+FADDP (V.4S) 5cy @ 1p 8cy @ 1p
18
+ * vaddvq_f64 FADDP (V.2D) 3cy @ 1p 3cy @ 2p
20
19
  *
21
20
  * The ARMv8.6-BF16 extension provides BFDOT for accelerated dot products on BF16 data, useful for
22
21
  * angular distance (cosine similarity) computations. BF16's larger exponent range (matching FP32)
@@ -0,0 +1,258 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for NEON FP8DOT4.
3
+ * @file include/numkong/spatial/neonfp8.h
4
+ * @author Ash Vardanian
5
+ * @date March 23, 2026
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * For L2 distance, we use the identity: (a−b)² = a² + b² − 2 × a × b,
10
+ * computing all three terms via FP8DOT4 without FP8 subtraction.
11
+ * Angular distance uses three DOT4 accumulators (a·b, ‖a‖², ‖b‖²) in parallel.
12
+ */
13
+ #ifndef NK_SPATIAL_NEONFP8_H
14
+ #define NK_SPATIAL_NEONFP8_H
15
+
16
+ #if NK_TARGET_ARM_
17
+ #if NK_TARGET_NEONFP8
18
+
19
+ #include "numkong/types.h"
20
+ #include "numkong/dot/neonfp8.h" // `nk_e2m3x16_to_e4m3x16_neonfp8_`, `nk_e3m2x16_to_e5m2x16_neonfp8_`
21
+ #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`, `nk_angular_normalize_f32_neon_`
22
+
23
+ #if defined(__cplusplus)
24
+ extern "C" {
25
+ #endif
26
+
27
+ #if defined(__clang__)
28
+ #pragma clang attribute push(__attribute__((target("arch=armv8-a+simd+fp8dot4"))), apply_to = function)
29
+ #elif defined(__GNUC__)
30
+ #pragma GCC push_options
31
+ #pragma GCC target("arch=armv8-a+simd+fp8dot4")
32
+ #endif
33
+
34
+ NK_PUBLIC void nk_sqeuclidean_e4m3_neonfp8(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
35
+ mfloat8x16_t a_mf8x16, b_mf8x16;
36
+ float32x4_t a2_f32x4 = vdupq_n_f32(0), ab_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
37
+ nk_sqeuclidean_e4m3_neonfp8_cycle:
38
+ if (n < 16) {
39
+ nk_b128_vec_t a_vec, b_vec;
40
+ nk_partial_load_b8x16_serial_(a, &a_vec, n);
41
+ nk_partial_load_b8x16_serial_(b, &b_vec, n);
42
+ a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
43
+ b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
44
+ n = 0;
45
+ }
46
+ else {
47
+ a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a));
48
+ b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b));
49
+ a += 16, b += 16, n -= 16;
50
+ }
51
+ a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E4M3_);
52
+ ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
53
+ b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E4M3_);
54
+ if (n) goto nk_sqeuclidean_e4m3_neonfp8_cycle;
55
+ *result = vaddvq_f32(a2_f32x4) - 2 * vaddvq_f32(ab_f32x4) + vaddvq_f32(b2_f32x4);
56
+ }
57
+
58
+ NK_PUBLIC void nk_euclidean_e4m3_neonfp8(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
59
+ nk_sqeuclidean_e4m3_neonfp8(a, b, n, result);
60
+ *result = nk_f32_sqrt_neon(*result);
61
+ }
62
+
63
+ NK_PUBLIC void nk_angular_e4m3_neonfp8(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
64
+ mfloat8x16_t a_mf8x16, b_mf8x16;
65
+ float32x4_t ab_f32x4 = vdupq_n_f32(0), a2_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
66
+ nk_angular_e4m3_neonfp8_cycle:
67
+ if (n < 16) {
68
+ nk_b128_vec_t a_vec, b_vec;
69
+ nk_partial_load_b8x16_serial_(a, &a_vec, n);
70
+ nk_partial_load_b8x16_serial_(b, &b_vec, n);
71
+ a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
72
+ b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
73
+ n = 0;
74
+ }
75
+ else {
76
+ a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a));
77
+ b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b));
78
+ a += 16, b += 16, n -= 16;
79
+ }
80
+ ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
81
+ a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E4M3_);
82
+ b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E4M3_);
83
+ if (n) goto nk_angular_e4m3_neonfp8_cycle;
84
+ *result = nk_angular_normalize_f32_neon_(vaddvq_f32(ab_f32x4), vaddvq_f32(a2_f32x4), vaddvq_f32(b2_f32x4));
85
+ }
86
+
87
+ NK_PUBLIC void nk_sqeuclidean_e5m2_neonfp8(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
88
+ mfloat8x16_t a_mf8x16, b_mf8x16;
89
+ float32x4_t a2_f32x4 = vdupq_n_f32(0), ab_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
90
+ nk_sqeuclidean_e5m2_neonfp8_cycle:
91
+ if (n < 16) {
92
+ nk_b128_vec_t a_vec, b_vec;
93
+ nk_partial_load_b8x16_serial_(a, &a_vec, n);
94
+ nk_partial_load_b8x16_serial_(b, &b_vec, n);
95
+ a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
96
+ b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
97
+ n = 0;
98
+ }
99
+ else {
100
+ a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a));
101
+ b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b));
102
+ a += 16, b += 16, n -= 16;
103
+ }
104
+ a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E5M2_);
105
+ ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
106
+ b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E5M2_);
107
+ if (n) goto nk_sqeuclidean_e5m2_neonfp8_cycle;
108
+ *result = vaddvq_f32(a2_f32x4) - 2 * vaddvq_f32(ab_f32x4) + vaddvq_f32(b2_f32x4);
109
+ }
110
+
111
+ NK_PUBLIC void nk_euclidean_e5m2_neonfp8(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
112
+ nk_sqeuclidean_e5m2_neonfp8(a, b, n, result);
113
+ *result = nk_f32_sqrt_neon(*result);
114
+ }
115
+
116
+ NK_PUBLIC void nk_angular_e5m2_neonfp8(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
117
+ mfloat8x16_t a_mf8x16, b_mf8x16;
118
+ float32x4_t ab_f32x4 = vdupq_n_f32(0), a2_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
119
+ nk_angular_e5m2_neonfp8_cycle:
120
+ if (n < 16) {
121
+ nk_b128_vec_t a_vec, b_vec;
122
+ nk_partial_load_b8x16_serial_(a, &a_vec, n);
123
+ nk_partial_load_b8x16_serial_(b, &b_vec, n);
124
+ a_mf8x16 = vreinterpretq_mf8_u8(a_vec.u8x16);
125
+ b_mf8x16 = vreinterpretq_mf8_u8(b_vec.u8x16);
126
+ n = 0;
127
+ }
128
+ else {
129
+ a_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)a));
130
+ b_mf8x16 = vreinterpretq_mf8_u8(vld1q_u8((nk_u8_t const *)b));
131
+ a += 16, b += 16, n -= 16;
132
+ }
133
+ ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
134
+ a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E5M2_);
135
+ b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E5M2_);
136
+ if (n) goto nk_angular_e5m2_neonfp8_cycle;
137
+ *result = nk_angular_normalize_f32_neon_(vaddvq_f32(ab_f32x4), vaddvq_f32(a2_f32x4), vaddvq_f32(b2_f32x4));
138
+ }
139
+
140
+ NK_PUBLIC void nk_sqeuclidean_e2m3_neonfp8(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
141
+ mfloat8x16_t a_mf8x16, b_mf8x16;
142
+ float32x4_t a2_f32x4 = vdupq_n_f32(0), ab_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
143
+ nk_sqeuclidean_e2m3_neonfp8_cycle:
144
+ if (n < 16) {
145
+ nk_b128_vec_t a_vec, b_vec;
146
+ nk_partial_load_b8x16_serial_(a, &a_vec, n);
147
+ nk_partial_load_b8x16_serial_(b, &b_vec, n);
148
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(a_vec.u8x16));
149
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(b_vec.u8x16));
150
+ n = 0;
151
+ }
152
+ else {
153
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)a)));
154
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)b)));
155
+ a += 16, b += 16, n -= 16;
156
+ }
157
+ a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E4M3_);
158
+ ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
159
+ b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E4M3_);
160
+ if (n) goto nk_sqeuclidean_e2m3_neonfp8_cycle;
161
+ *result = vaddvq_f32(a2_f32x4) - 2 * vaddvq_f32(ab_f32x4) + vaddvq_f32(b2_f32x4);
162
+ }
163
+
164
+ NK_PUBLIC void nk_euclidean_e2m3_neonfp8(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
165
+ nk_sqeuclidean_e2m3_neonfp8(a, b, n, result);
166
+ *result = nk_f32_sqrt_neon(*result);
167
+ }
168
+
169
+ NK_PUBLIC void nk_angular_e2m3_neonfp8(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
170
+ mfloat8x16_t a_mf8x16, b_mf8x16;
171
+ float32x4_t ab_f32x4 = vdupq_n_f32(0), a2_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
172
+ nk_angular_e2m3_neonfp8_cycle:
173
+ if (n < 16) {
174
+ nk_b128_vec_t a_vec, b_vec;
175
+ nk_partial_load_b8x16_serial_(a, &a_vec, n);
176
+ nk_partial_load_b8x16_serial_(b, &b_vec, n);
177
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(a_vec.u8x16));
178
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(b_vec.u8x16));
179
+ n = 0;
180
+ }
181
+ else {
182
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)a)));
183
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e2m3x16_to_e4m3x16_neonfp8_(vld1q_u8((nk_u8_t const *)b)));
184
+ a += 16, b += 16, n -= 16;
185
+ }
186
+ ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E4M3_);
187
+ a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E4M3_);
188
+ b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E4M3_);
189
+ if (n) goto nk_angular_e2m3_neonfp8_cycle;
190
+ *result = nk_angular_normalize_f32_neon_(vaddvq_f32(ab_f32x4), vaddvq_f32(a2_f32x4), vaddvq_f32(b2_f32x4));
191
+ }
192
+
193
+ NK_PUBLIC void nk_sqeuclidean_e3m2_neonfp8(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
194
+ mfloat8x16_t a_mf8x16, b_mf8x16;
195
+ float32x4_t a2_f32x4 = vdupq_n_f32(0), ab_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
196
+ nk_sqeuclidean_e3m2_neonfp8_cycle:
197
+ if (n < 16) {
198
+ nk_b128_vec_t a_vec, b_vec;
199
+ nk_partial_load_b8x16_serial_(a, &a_vec, n);
200
+ nk_partial_load_b8x16_serial_(b, &b_vec, n);
201
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(a_vec.u8x16));
202
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(b_vec.u8x16));
203
+ n = 0;
204
+ }
205
+ else {
206
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)a)));
207
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)b)));
208
+ a += 16, b += 16, n -= 16;
209
+ }
210
+ a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E5M2_);
211
+ ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
212
+ b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E5M2_);
213
+ if (n) goto nk_sqeuclidean_e3m2_neonfp8_cycle;
214
+ *result = vaddvq_f32(a2_f32x4) - 2 * vaddvq_f32(ab_f32x4) + vaddvq_f32(b2_f32x4);
215
+ }
216
+
217
+ NK_PUBLIC void nk_euclidean_e3m2_neonfp8(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
218
+ nk_sqeuclidean_e3m2_neonfp8(a, b, n, result);
219
+ *result = nk_f32_sqrt_neon(*result);
220
+ }
221
+
222
+ NK_PUBLIC void nk_angular_e3m2_neonfp8(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
223
+ mfloat8x16_t a_mf8x16, b_mf8x16;
224
+ float32x4_t ab_f32x4 = vdupq_n_f32(0), a2_f32x4 = vdupq_n_f32(0), b2_f32x4 = vdupq_n_f32(0);
225
+ nk_angular_e3m2_neonfp8_cycle:
226
+ if (n < 16) {
227
+ nk_b128_vec_t a_vec, b_vec;
228
+ nk_partial_load_b8x16_serial_(a, &a_vec, n);
229
+ nk_partial_load_b8x16_serial_(b, &b_vec, n);
230
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(a_vec.u8x16));
231
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(b_vec.u8x16));
232
+ n = 0;
233
+ }
234
+ else {
235
+ a_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)a)));
236
+ b_mf8x16 = vreinterpretq_mf8_u8(nk_e3m2x16_to_e5m2x16_neonfp8_(vld1q_u8((nk_u8_t const *)b)));
237
+ a += 16, b += 16, n -= 16;
238
+ }
239
+ ab_f32x4 = vdotq_f32_mf8_fpm(ab_f32x4, a_mf8x16, b_mf8x16, NK_FPM_E5M2_);
240
+ a2_f32x4 = vdotq_f32_mf8_fpm(a2_f32x4, a_mf8x16, a_mf8x16, NK_FPM_E5M2_);
241
+ b2_f32x4 = vdotq_f32_mf8_fpm(b2_f32x4, b_mf8x16, b_mf8x16, NK_FPM_E5M2_);
242
+ if (n) goto nk_angular_e3m2_neonfp8_cycle;
243
+ *result = nk_angular_normalize_f32_neon_(vaddvq_f32(ab_f32x4), vaddvq_f32(a2_f32x4), vaddvq_f32(b2_f32x4));
244
+ }
245
+
246
+ #if defined(__clang__)
247
+ #pragma clang attribute pop
248
+ #elif defined(__GNUC__)
249
+ #pragma GCC pop_options
250
+ #endif
251
+
252
+ #if defined(__cplusplus)
253
+ } // extern "C"
254
+ #endif
255
+
256
+ #endif // NK_TARGET_NEONFP8
257
+ #endif // NK_TARGET_ARM_
258
+ #endif // NK_SPATIAL_NEONFP8_H