numkong 7.0.0 → 7.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (315)
  1. package/README.md +197 -124
  2. package/binding.gyp +34 -484
  3. package/c/dispatch_bf16.c +59 -1
  4. package/c/dispatch_e2m3.c +41 -8
  5. package/c/dispatch_e3m2.c +49 -8
  6. package/c/dispatch_e4m3.c +51 -9
  7. package/c/dispatch_e5m2.c +45 -1
  8. package/c/dispatch_f16.c +79 -26
  9. package/c/dispatch_f16c.c +5 -5
  10. package/c/dispatch_f32.c +56 -0
  11. package/c/dispatch_f64.c +52 -0
  12. package/c/dispatch_i4.c +3 -0
  13. package/c/dispatch_i8.c +62 -3
  14. package/c/dispatch_other.c +18 -0
  15. package/c/dispatch_u1.c +54 -9
  16. package/c/dispatch_u4.c +3 -0
  17. package/c/dispatch_u8.c +64 -3
  18. package/c/numkong.c +3 -0
  19. package/include/README.md +79 -9
  20. package/include/numkong/attention/sapphireamx.h +278 -276
  21. package/include/numkong/attention/sme.h +983 -977
  22. package/include/numkong/attention.h +1 -1
  23. package/include/numkong/capabilities.h +289 -94
  24. package/include/numkong/cast/README.md +40 -40
  25. package/include/numkong/cast/diamond.h +64 -0
  26. package/include/numkong/cast/haswell.h +42 -194
  27. package/include/numkong/cast/icelake.h +42 -37
  28. package/include/numkong/cast/loongsonasx.h +252 -0
  29. package/include/numkong/cast/neon.h +216 -249
  30. package/include/numkong/cast/powervsx.h +449 -0
  31. package/include/numkong/cast/rvv.h +223 -274
  32. package/include/numkong/cast/sapphire.h +18 -18
  33. package/include/numkong/cast/serial.h +1018 -944
  34. package/include/numkong/cast/skylake.h +82 -23
  35. package/include/numkong/cast/v128relaxed.h +462 -105
  36. package/include/numkong/cast.h +24 -0
  37. package/include/numkong/cast.hpp +44 -0
  38. package/include/numkong/curved/README.md +17 -17
  39. package/include/numkong/curved/neon.h +131 -7
  40. package/include/numkong/curved/neonbfdot.h +6 -7
  41. package/include/numkong/curved/rvv.h +26 -26
  42. package/include/numkong/curved/smef64.h +186 -182
  43. package/include/numkong/curved.h +14 -18
  44. package/include/numkong/dot/README.md +154 -137
  45. package/include/numkong/dot/alder.h +43 -43
  46. package/include/numkong/dot/diamond.h +158 -0
  47. package/include/numkong/dot/genoa.h +4 -30
  48. package/include/numkong/dot/haswell.h +215 -180
  49. package/include/numkong/dot/icelake.h +190 -76
  50. package/include/numkong/dot/loongsonasx.h +671 -0
  51. package/include/numkong/dot/neon.h +124 -73
  52. package/include/numkong/dot/neonbfdot.h +11 -12
  53. package/include/numkong/dot/neonfhm.h +44 -46
  54. package/include/numkong/dot/neonfp8.h +323 -0
  55. package/include/numkong/dot/neonsdot.h +190 -76
  56. package/include/numkong/dot/powervsx.h +752 -0
  57. package/include/numkong/dot/rvv.h +92 -84
  58. package/include/numkong/dot/rvvbf16.h +12 -12
  59. package/include/numkong/dot/rvvhalf.h +12 -12
  60. package/include/numkong/dot/sapphire.h +4 -4
  61. package/include/numkong/dot/serial.h +66 -30
  62. package/include/numkong/dot/sierra.h +31 -31
  63. package/include/numkong/dot/skylake.h +142 -110
  64. package/include/numkong/dot/sve.h +217 -177
  65. package/include/numkong/dot/svebfdot.h +10 -10
  66. package/include/numkong/dot/svehalf.h +85 -41
  67. package/include/numkong/dot/svesdot.h +89 -0
  68. package/include/numkong/dot/v128relaxed.h +124 -89
  69. package/include/numkong/dot.h +114 -48
  70. package/include/numkong/dots/README.md +203 -203
  71. package/include/numkong/dots/alder.h +12 -9
  72. package/include/numkong/dots/diamond.h +86 -0
  73. package/include/numkong/dots/genoa.h +10 -4
  74. package/include/numkong/dots/haswell.h +63 -48
  75. package/include/numkong/dots/icelake.h +27 -18
  76. package/include/numkong/dots/loongsonasx.h +176 -0
  77. package/include/numkong/dots/neon.h +14 -11
  78. package/include/numkong/dots/neonbfdot.h +4 -3
  79. package/include/numkong/dots/neonfhm.h +11 -9
  80. package/include/numkong/dots/neonfp8.h +99 -0
  81. package/include/numkong/dots/neonsdot.h +48 -12
  82. package/include/numkong/dots/powervsx.h +194 -0
  83. package/include/numkong/dots/rvv.h +451 -344
  84. package/include/numkong/dots/sapphireamx.h +1028 -984
  85. package/include/numkong/dots/serial.h +213 -197
  86. package/include/numkong/dots/sierra.h +10 -7
  87. package/include/numkong/dots/skylake.h +47 -36
  88. package/include/numkong/dots/sme.h +2001 -2364
  89. package/include/numkong/dots/smebi32.h +175 -162
  90. package/include/numkong/dots/smef64.h +328 -323
  91. package/include/numkong/dots/v128relaxed.h +64 -41
  92. package/include/numkong/dots.h +573 -293
  93. package/include/numkong/dots.hpp +45 -43
  94. package/include/numkong/each/README.md +133 -137
  95. package/include/numkong/each/haswell.h +6 -6
  96. package/include/numkong/each/icelake.h +7 -7
  97. package/include/numkong/each/neon.h +76 -42
  98. package/include/numkong/each/neonbfdot.h +11 -12
  99. package/include/numkong/each/neonhalf.h +24 -116
  100. package/include/numkong/each/rvv.h +28 -28
  101. package/include/numkong/each/sapphire.h +27 -161
  102. package/include/numkong/each/serial.h +6 -6
  103. package/include/numkong/each/skylake.h +7 -7
  104. package/include/numkong/each/v128relaxed.h +562 -0
  105. package/include/numkong/each.h +148 -62
  106. package/include/numkong/each.hpp +2 -2
  107. package/include/numkong/geospatial/README.md +18 -18
  108. package/include/numkong/geospatial/haswell.h +365 -325
  109. package/include/numkong/geospatial/neon.h +350 -306
  110. package/include/numkong/geospatial/rvv.h +4 -4
  111. package/include/numkong/geospatial/skylake.h +376 -340
  112. package/include/numkong/geospatial/v128relaxed.h +366 -327
  113. package/include/numkong/geospatial.h +17 -17
  114. package/include/numkong/matrix.hpp +4 -4
  115. package/include/numkong/maxsim/README.md +14 -14
  116. package/include/numkong/maxsim/alder.h +6 -6
  117. package/include/numkong/maxsim/genoa.h +4 -4
  118. package/include/numkong/maxsim/haswell.h +6 -6
  119. package/include/numkong/maxsim/icelake.h +18 -18
  120. package/include/numkong/maxsim/neonsdot.h +21 -21
  121. package/include/numkong/maxsim/sapphireamx.h +14 -14
  122. package/include/numkong/maxsim/serial.h +6 -6
  123. package/include/numkong/maxsim/sme.h +221 -196
  124. package/include/numkong/maxsim/v128relaxed.h +6 -6
  125. package/include/numkong/mesh/README.md +62 -56
  126. package/include/numkong/mesh/haswell.h +339 -464
  127. package/include/numkong/mesh/neon.h +1100 -519
  128. package/include/numkong/mesh/neonbfdot.h +36 -68
  129. package/include/numkong/mesh/rvv.h +530 -435
  130. package/include/numkong/mesh/serial.h +75 -91
  131. package/include/numkong/mesh/skylake.h +1627 -302
  132. package/include/numkong/mesh/v128relaxed.h +443 -330
  133. package/include/numkong/mesh.h +63 -49
  134. package/include/numkong/mesh.hpp +4 -4
  135. package/include/numkong/numkong.h +3 -3
  136. package/include/numkong/numkong.hpp +1 -0
  137. package/include/numkong/probability/README.md +23 -19
  138. package/include/numkong/probability/neon.h +82 -52
  139. package/include/numkong/probability/rvv.h +28 -23
  140. package/include/numkong/probability/serial.h +51 -39
  141. package/include/numkong/probability.h +20 -23
  142. package/include/numkong/random.h +1 -1
  143. package/include/numkong/reduce/README.md +143 -138
  144. package/include/numkong/reduce/alder.h +81 -77
  145. package/include/numkong/reduce/haswell.h +222 -220
  146. package/include/numkong/reduce/neon.h +629 -519
  147. package/include/numkong/reduce/neonbfdot.h +7 -218
  148. package/include/numkong/reduce/neonfhm.h +9 -381
  149. package/include/numkong/reduce/neonsdot.h +9 -9
  150. package/include/numkong/reduce/rvv.h +928 -802
  151. package/include/numkong/reduce/serial.h +23 -27
  152. package/include/numkong/reduce/sierra.h +20 -20
  153. package/include/numkong/reduce/skylake.h +326 -324
  154. package/include/numkong/reduce/v128relaxed.h +52 -52
  155. package/include/numkong/reduce.h +4 -23
  156. package/include/numkong/reduce.hpp +156 -11
  157. package/include/numkong/scalar/README.md +6 -6
  158. package/include/numkong/scalar/haswell.h +26 -17
  159. package/include/numkong/scalar/loongsonasx.h +74 -0
  160. package/include/numkong/scalar/neon.h +9 -9
  161. package/include/numkong/scalar/powervsx.h +96 -0
  162. package/include/numkong/scalar/rvv.h +2 -2
  163. package/include/numkong/scalar/sapphire.h +21 -10
  164. package/include/numkong/scalar/serial.h +21 -21
  165. package/include/numkong/scalar.h +13 -0
  166. package/include/numkong/set/README.md +28 -28
  167. package/include/numkong/set/haswell.h +12 -12
  168. package/include/numkong/set/icelake.h +14 -14
  169. package/include/numkong/set/loongsonasx.h +181 -0
  170. package/include/numkong/set/neon.h +17 -18
  171. package/include/numkong/set/powervsx.h +326 -0
  172. package/include/numkong/set/rvv.h +4 -4
  173. package/include/numkong/set/serial.h +6 -6
  174. package/include/numkong/set/sve.h +60 -59
  175. package/include/numkong/set/v128relaxed.h +6 -6
  176. package/include/numkong/set.h +21 -7
  177. package/include/numkong/sets/README.md +26 -26
  178. package/include/numkong/sets/loongsonasx.h +52 -0
  179. package/include/numkong/sets/powervsx.h +65 -0
  180. package/include/numkong/sets/smebi32.h +395 -364
  181. package/include/numkong/sets.h +83 -40
  182. package/include/numkong/sparse/README.md +4 -4
  183. package/include/numkong/sparse/icelake.h +101 -101
  184. package/include/numkong/sparse/serial.h +1 -1
  185. package/include/numkong/sparse/sve2.h +137 -141
  186. package/include/numkong/sparse/turin.h +12 -12
  187. package/include/numkong/sparse.h +10 -10
  188. package/include/numkong/spatial/README.md +230 -226
  189. package/include/numkong/spatial/alder.h +113 -116
  190. package/include/numkong/spatial/diamond.h +240 -0
  191. package/include/numkong/spatial/genoa.h +0 -68
  192. package/include/numkong/spatial/haswell.h +74 -55
  193. package/include/numkong/spatial/icelake.h +539 -58
  194. package/include/numkong/spatial/loongsonasx.h +483 -0
  195. package/include/numkong/spatial/neon.h +125 -52
  196. package/include/numkong/spatial/neonbfdot.h +8 -9
  197. package/include/numkong/spatial/neonfp8.h +258 -0
  198. package/include/numkong/spatial/neonsdot.h +180 -12
  199. package/include/numkong/spatial/powervsx.h +738 -0
  200. package/include/numkong/spatial/rvv.h +146 -139
  201. package/include/numkong/spatial/rvvbf16.h +17 -12
  202. package/include/numkong/spatial/rvvhalf.h +13 -10
  203. package/include/numkong/spatial/serial.h +13 -12
  204. package/include/numkong/spatial/sierra.h +232 -39
  205. package/include/numkong/spatial/skylake.h +73 -74
  206. package/include/numkong/spatial/sve.h +93 -72
  207. package/include/numkong/spatial/svebfdot.h +29 -29
  208. package/include/numkong/spatial/svehalf.h +52 -26
  209. package/include/numkong/spatial/svesdot.h +142 -0
  210. package/include/numkong/spatial/v128relaxed.h +293 -41
  211. package/include/numkong/spatial.h +338 -82
  212. package/include/numkong/spatials/README.md +194 -194
  213. package/include/numkong/spatials/diamond.h +82 -0
  214. package/include/numkong/spatials/haswell.h +2 -2
  215. package/include/numkong/spatials/loongsonasx.h +153 -0
  216. package/include/numkong/spatials/neonfp8.h +111 -0
  217. package/include/numkong/spatials/neonsdot.h +34 -0
  218. package/include/numkong/spatials/powervsx.h +153 -0
  219. package/include/numkong/spatials/rvv.h +259 -243
  220. package/include/numkong/spatials/sapphireamx.h +173 -173
  221. package/include/numkong/spatials/serial.h +2 -2
  222. package/include/numkong/spatials/skylake.h +2 -2
  223. package/include/numkong/spatials/sme.h +590 -605
  224. package/include/numkong/spatials/smef64.h +139 -130
  225. package/include/numkong/spatials/v128relaxed.h +2 -2
  226. package/include/numkong/spatials.h +820 -500
  227. package/include/numkong/spatials.hpp +49 -48
  228. package/include/numkong/tensor.hpp +406 -17
  229. package/include/numkong/trigonometry/README.md +19 -19
  230. package/include/numkong/trigonometry/haswell.h +402 -401
  231. package/include/numkong/trigonometry/neon.h +386 -387
  232. package/include/numkong/trigonometry/rvv.h +52 -51
  233. package/include/numkong/trigonometry/serial.h +13 -13
  234. package/include/numkong/trigonometry/skylake.h +373 -369
  235. package/include/numkong/trigonometry/v128relaxed.h +375 -374
  236. package/include/numkong/trigonometry.h +13 -13
  237. package/include/numkong/trigonometry.hpp +2 -2
  238. package/include/numkong/types.h +287 -49
  239. package/include/numkong/types.hpp +436 -12
  240. package/include/numkong/vector.hpp +82 -14
  241. package/javascript/dist/cjs/numkong-wasm.js +6 -12
  242. package/javascript/dist/cjs/numkong.d.ts +7 -1
  243. package/javascript/dist/cjs/numkong.js +37 -11
  244. package/javascript/dist/cjs/types.d.ts +9 -0
  245. package/javascript/dist/cjs/types.js +96 -0
  246. package/javascript/dist/esm/numkong-browser.d.ts +14 -0
  247. package/javascript/dist/esm/numkong-browser.js +23 -0
  248. package/javascript/dist/esm/numkong-wasm.js +6 -12
  249. package/javascript/dist/esm/numkong.d.ts +7 -1
  250. package/javascript/dist/esm/numkong.js +37 -11
  251. package/javascript/dist/esm/types.d.ts +9 -0
  252. package/javascript/dist/esm/types.js +96 -0
  253. package/javascript/node-gyp-build.d.ts +4 -1
  254. package/javascript/numkong-browser.ts +40 -0
  255. package/javascript/numkong-wasm.ts +7 -13
  256. package/javascript/numkong.c +5 -26
  257. package/javascript/numkong.ts +36 -11
  258. package/javascript/tsconfig-base.json +1 -0
  259. package/javascript/tsconfig-cjs.json +6 -1
  260. package/javascript/types.ts +110 -0
  261. package/numkong.gypi +101 -0
  262. package/package.json +34 -13
  263. package/probes/arm_neon.c +8 -0
  264. package/probes/arm_neon_bfdot.c +9 -0
  265. package/probes/arm_neon_fhm.c +9 -0
  266. package/probes/arm_neon_half.c +8 -0
  267. package/probes/arm_neon_sdot.c +9 -0
  268. package/probes/arm_neonfp8.c +9 -0
  269. package/probes/arm_sme.c +16 -0
  270. package/probes/arm_sme2.c +16 -0
  271. package/probes/arm_sme2p1.c +16 -0
  272. package/probes/arm_sme_bf16.c +16 -0
  273. package/probes/arm_sme_bi32.c +16 -0
  274. package/probes/arm_sme_f64.c +16 -0
  275. package/probes/arm_sme_fa64.c +14 -0
  276. package/probes/arm_sme_half.c +16 -0
  277. package/probes/arm_sme_lut2.c +15 -0
  278. package/probes/arm_sve.c +18 -0
  279. package/probes/arm_sve2.c +20 -0
  280. package/probes/arm_sve2p1.c +18 -0
  281. package/probes/arm_sve_bfdot.c +20 -0
  282. package/probes/arm_sve_half.c +18 -0
  283. package/probes/arm_sve_sdot.c +21 -0
  284. package/probes/loongarch_lasx.c +12 -0
  285. package/probes/power_vsx.c +12 -0
  286. package/probes/probe.js +127 -0
  287. package/probes/riscv_rvv.c +14 -0
  288. package/probes/riscv_rvv_bb.c +15 -0
  289. package/probes/riscv_rvv_bf16.c +17 -0
  290. package/probes/riscv_rvv_half.c +14 -0
  291. package/probes/wasm_v128relaxed.c +11 -0
  292. package/probes/x86_alder.c +17 -0
  293. package/probes/x86_diamond.c +17 -0
  294. package/probes/x86_genoa.c +17 -0
  295. package/probes/x86_graniteamx.c +19 -0
  296. package/probes/x86_haswell.c +11 -0
  297. package/probes/x86_icelake.c +17 -0
  298. package/probes/x86_sapphire.c +16 -0
  299. package/probes/x86_sapphireamx.c +18 -0
  300. package/probes/x86_sierra.c +17 -0
  301. package/probes/x86_skylake.c +15 -0
  302. package/probes/x86_turin.c +17 -0
  303. package/wasm/numkong-emscripten.js +2 -0
  304. package/wasm/numkong.d.ts +14 -0
  305. package/wasm/numkong.js +1124 -0
  306. package/wasm/numkong.wasm +0 -0
  307. package/include/numkong/curved/neonhalf.h +0 -212
  308. package/include/numkong/dot/neonhalf.h +0 -198
  309. package/include/numkong/dots/neonhalf.h +0 -57
  310. package/include/numkong/mesh/neonhalf.h +0 -616
  311. package/include/numkong/reduce/neonhalf.h +0 -157
  312. package/include/numkong/spatial/neonhalf.h +0 -118
  313. package/include/numkong/spatial/sapphire.h +0 -343
  314. package/include/numkong/spatials/neonhalf.h +0 -58
  315. package/javascript/README.md +0 -246
@@ -8,16 +8,15 @@
8
8
  *
9
9
  * @section spatial_neonsdot_instructions ARM NEON SDOT/UDOT Instructions (ARMv8.4-DotProd)
10
10
  *
11
- * Intrinsic Instruction Latency Throughput
12
- * A76 M4+/V1+/Oryon
13
- * vdotq_s32 SDOT (V.4S, V.16B, V.16B) 3cy 2/cy 4/cy
14
- * vdotq_u32 UDOT (V.4S, V.16B, V.16B) 3cy 2/cy 4/cy
15
- * vabdq_s8 SABD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
16
- * vabdq_u8 UABD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
17
- * vld1q_s8 LD1 (V.16B) 4cy 2/cy 3/cy
18
- * vld1q_u8 LD1 (V.16B) 4cy 2/cy 3/cy
19
- * vaddvq_s32 ADDV (V.4S) 4cy 1/cy 2/cy
20
- * vaddvq_u32 ADDV (V.4S) 4cy 1/cy 2/cy
11
+ * Intrinsic Instruction A76 M5
12
+ * vdotq_s32 SDOT (V.4S, V.16B, V.16B) 3cy @ 2p 3cy @ 4p
13
+ * vdotq_u32 UDOT (V.4S, V.16B, V.16B) 3cy @ 2p 3cy @ 4p
14
+ * vabdq_s8 SABD (V.16B, V.16B, V.16B) 3cy @ 2p 3cy @ 2p
15
+ * vabdq_u8 UABD (V.16B, V.16B, V.16B) 3cy @ 2p 3cy @ 2p
16
+ * vld1q_s8 LD1 (V.16B) 4cy @ 2p 4cy @ 3p
17
+ * vld1q_u8 LD1 (V.16B) 4cy @ 2p 4cy @ 3p
18
+ * vaddvq_s32 ADDV (V.4S) 4cy @ 1p 5cy @ 1p
19
+ * vaddvq_u32 ADDV (V.4S) 4cy @ 1p 5cy @ 1p
21
20
  *
22
21
  * The ARMv8.4-DotProd extension provides SDOT/UDOT for int8 dot products and SABD/UABD for
23
22
  * absolute differences, enabling L2 and angular distance on quantized embeddings.
@@ -34,6 +33,7 @@
34
33
  #if NK_TARGET_NEONSDOT
35
34
 
36
35
  #include "numkong/types.h"
36
+ #include "numkong/cast/serial.h" // `nk_partial_load_b4x32_serial_`
37
37
  #include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
38
38
 
39
39
  #if defined(__cplusplus)
@@ -195,7 +195,8 @@ NK_PUBLIC void nk_angular_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_siz
195
195
  b_norm_sq_i32 += b_element_i32 * b_element_i32;
196
196
  }
197
197
 
198
- *result = nk_angular_normalize_f32_neon_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
198
+ *result = nk_angular_normalize_f32_neon_((nk_f32_t)dot_product_i32, (nk_f32_t)a_norm_sq_i32,
199
+ (nk_f32_t)b_norm_sq_i32);
199
200
  }
200
201
 
201
202
  NK_PUBLIC void nk_sqeuclidean_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
@@ -243,7 +244,174 @@ NK_PUBLIC void nk_angular_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_siz
243
244
  ab += ai * bi, a2 += ai * ai, b2 += bi * bi;
244
245
  }
245
246
 
246
- *result = nk_angular_normalize_f32_neon_(ab, a2, b2);
247
+ *result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
248
+ }
249
+
250
+ NK_PUBLIC void nk_sqeuclidean_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_u32_t *result) {
251
+ n = nk_size_round_up_to_multiple_(n, 2);
252
+ nk_size_t n_bytes = n / 2;
253
+ uint32x4_t d2_u32x4 = vdupq_n_u32(0);
254
+ uint8x16_t a_u8x16, b_u8x16;
255
+
256
+ nk_sqeuclidean_i4_neonsdot_cycle:
257
+ if (n_bytes < 16) {
258
+ nk_b128_vec_t a_vec, b_vec;
259
+ nk_partial_load_b4x32_serial_(a, &a_vec, n_bytes * 2);
260
+ nk_partial_load_b4x32_serial_(b, &b_vec, n_bytes * 2);
261
+ a_u8x16 = a_vec.u8x16;
262
+ b_u8x16 = b_vec.u8x16;
263
+ n_bytes = 0;
264
+ }
265
+ else {
266
+ a_u8x16 = vld1q_u8((nk_u8_t const *)a);
267
+ b_u8x16 = vld1q_u8((nk_u8_t const *)b);
268
+ a += 16, b += 16, n_bytes -= 16;
269
+ }
270
+
271
+ // Sign-extend low nibbles, compute |a-b|, reinterpret as unsigned for UDOT squaring
272
+ int8x16_t a_low_i8x16 = vshrq_n_s8(vshlq_n_s8(vreinterpretq_s8_u8(a_u8x16), 4), 4);
273
+ int8x16_t b_low_i8x16 = vshrq_n_s8(vshlq_n_s8(vreinterpretq_s8_u8(b_u8x16), 4), 4);
274
+ int8x16_t a_high_i8x16 = vshrq_n_s8(vreinterpretq_s8_u8(a_u8x16), 4);
275
+ int8x16_t b_high_i8x16 = vshrq_n_s8(vreinterpretq_s8_u8(b_u8x16), 4);
276
+
277
+ uint8x16_t diff_low_u8x16 = vreinterpretq_u8_s8(vabdq_s8(a_low_i8x16, b_low_i8x16));
278
+ uint8x16_t diff_high_u8x16 = vreinterpretq_u8_s8(vabdq_s8(a_high_i8x16, b_high_i8x16));
279
+ d2_u32x4 = vdotq_u32(d2_u32x4, diff_low_u8x16, diff_low_u8x16);
280
+ d2_u32x4 = vdotq_u32(d2_u32x4, diff_high_u8x16, diff_high_u8x16);
281
+
282
+ if (n_bytes) goto nk_sqeuclidean_i4_neonsdot_cycle;
283
+ *result = vaddvq_u32(d2_u32x4);
284
+ }
285
+
286
+ NK_PUBLIC void nk_euclidean_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
287
+ nk_u32_t d2;
288
+ nk_sqeuclidean_i4_neonsdot(a, b, n, &d2);
289
+ *result = nk_f32_sqrt_neon((nk_f32_t)d2);
290
+ }
291
+
292
+ NK_PUBLIC void nk_angular_i4_neonsdot(nk_i4x2_t const *a, nk_i4x2_t const *b, nk_size_t n, nk_f32_t *result) {
293
+ n = nk_size_round_up_to_multiple_(n, 2);
294
+ nk_size_t n_bytes = n / 2;
295
+ int32x4_t ab_i32x4 = vdupq_n_s32(0);
296
+ int32x4_t a2_i32x4 = vdupq_n_s32(0);
297
+ int32x4_t b2_i32x4 = vdupq_n_s32(0);
298
+ uint8x16_t a_u8x16, b_u8x16;
299
+
300
+ nk_angular_i4_neonsdot_cycle:
301
+ if (n_bytes < 16) {
302
+ nk_b128_vec_t a_vec, b_vec;
303
+ nk_partial_load_b4x32_serial_(a, &a_vec, n_bytes * 2);
304
+ nk_partial_load_b4x32_serial_(b, &b_vec, n_bytes * 2);
305
+ a_u8x16 = a_vec.u8x16;
306
+ b_u8x16 = b_vec.u8x16;
307
+ n_bytes = 0;
308
+ }
309
+ else {
310
+ a_u8x16 = vld1q_u8((nk_u8_t const *)a);
311
+ b_u8x16 = vld1q_u8((nk_u8_t const *)b);
312
+ a += 16, b += 16, n_bytes -= 16;
313
+ }
314
+
315
+ int8x16_t a_low_i8x16 = vshrq_n_s8(vshlq_n_s8(vreinterpretq_s8_u8(a_u8x16), 4), 4);
316
+ int8x16_t b_low_i8x16 = vshrq_n_s8(vshlq_n_s8(vreinterpretq_s8_u8(b_u8x16), 4), 4);
317
+ int8x16_t a_high_i8x16 = vshrq_n_s8(vreinterpretq_s8_u8(a_u8x16), 4);
318
+ int8x16_t b_high_i8x16 = vshrq_n_s8(vreinterpretq_s8_u8(b_u8x16), 4);
319
+
320
+ ab_i32x4 = vdotq_s32(ab_i32x4, a_low_i8x16, b_low_i8x16);
321
+ ab_i32x4 = vdotq_s32(ab_i32x4, a_high_i8x16, b_high_i8x16);
322
+ a2_i32x4 = vdotq_s32(a2_i32x4, a_low_i8x16, a_low_i8x16);
323
+ a2_i32x4 = vdotq_s32(a2_i32x4, a_high_i8x16, a_high_i8x16);
324
+ b2_i32x4 = vdotq_s32(b2_i32x4, b_low_i8x16, b_low_i8x16);
325
+ b2_i32x4 = vdotq_s32(b2_i32x4, b_high_i8x16, b_high_i8x16);
326
+
327
+ if (n_bytes) goto nk_angular_i4_neonsdot_cycle;
328
+
329
+ *result = nk_angular_normalize_f32_neon_((nk_f32_t)vaddvq_s32(ab_i32x4), (nk_f32_t)vaddvq_s32(a2_i32x4),
330
+ (nk_f32_t)vaddvq_s32(b2_i32x4));
331
+ }
332
+
333
+ NK_PUBLIC void nk_sqeuclidean_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_u32_t *result) {
334
+ n = nk_size_round_up_to_multiple_(n, 2);
335
+ nk_size_t n_bytes = n / 2;
336
+ uint8x16_t const nibble_mask_u8x16 = vdupq_n_u8(0x0F);
337
+ uint32x4_t d2_u32x4 = vdupq_n_u32(0);
338
+ uint8x16_t a_u8x16, b_u8x16;
339
+
340
+ nk_sqeuclidean_u4_neonsdot_cycle:
341
+ if (n_bytes < 16) {
342
+ nk_b128_vec_t a_vec, b_vec;
343
+ nk_partial_load_b4x32_serial_(a, &a_vec, n_bytes * 2);
344
+ nk_partial_load_b4x32_serial_(b, &b_vec, n_bytes * 2);
345
+ a_u8x16 = a_vec.u8x16;
346
+ b_u8x16 = b_vec.u8x16;
347
+ n_bytes = 0;
348
+ }
349
+ else {
350
+ a_u8x16 = vld1q_u8((nk_u8_t const *)a);
351
+ b_u8x16 = vld1q_u8((nk_u8_t const *)b);
352
+ a += 16, b += 16, n_bytes -= 16;
353
+ }
354
+
355
+ uint8x16_t a_low_u8x16 = vandq_u8(a_u8x16, nibble_mask_u8x16);
356
+ uint8x16_t a_high_u8x16 = vshrq_n_u8(a_u8x16, 4);
357
+ uint8x16_t b_low_u8x16 = vandq_u8(b_u8x16, nibble_mask_u8x16);
358
+ uint8x16_t b_high_u8x16 = vshrq_n_u8(b_u8x16, 4);
359
+
360
+ uint8x16_t diff_low_u8x16 = vabdq_u8(a_low_u8x16, b_low_u8x16);
361
+ uint8x16_t diff_high_u8x16 = vabdq_u8(a_high_u8x16, b_high_u8x16);
362
+ d2_u32x4 = vdotq_u32(d2_u32x4, diff_low_u8x16, diff_low_u8x16);
363
+ d2_u32x4 = vdotq_u32(d2_u32x4, diff_high_u8x16, diff_high_u8x16);
364
+
365
+ if (n_bytes) goto nk_sqeuclidean_u4_neonsdot_cycle;
366
+ *result = vaddvq_u32(d2_u32x4);
367
+ }
368
+
369
+ NK_PUBLIC void nk_euclidean_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
370
+ nk_u32_t d2;
371
+ nk_sqeuclidean_u4_neonsdot(a, b, n, &d2);
372
+ *result = nk_f32_sqrt_neon((nk_f32_t)d2);
373
+ }
374
+
375
+ NK_PUBLIC void nk_angular_u4_neonsdot(nk_u4x2_t const *a, nk_u4x2_t const *b, nk_size_t n, nk_f32_t *result) {
376
+ n = nk_size_round_up_to_multiple_(n, 2);
377
+ nk_size_t n_bytes = n / 2;
378
+ uint8x16_t const nibble_mask_u8x16 = vdupq_n_u8(0x0F);
379
+ uint32x4_t ab_u32x4 = vdupq_n_u32(0);
380
+ uint32x4_t a2_u32x4 = vdupq_n_u32(0);
381
+ uint32x4_t b2_u32x4 = vdupq_n_u32(0);
382
+ uint8x16_t a_u8x16, b_u8x16;
383
+
384
+ nk_angular_u4_neonsdot_cycle:
385
+ if (n_bytes < 16) {
386
+ nk_b128_vec_t a_vec, b_vec;
387
+ nk_partial_load_b4x32_serial_(a, &a_vec, n_bytes * 2);
388
+ nk_partial_load_b4x32_serial_(b, &b_vec, n_bytes * 2);
389
+ a_u8x16 = a_vec.u8x16;
390
+ b_u8x16 = b_vec.u8x16;
391
+ n_bytes = 0;
392
+ }
393
+ else {
394
+ a_u8x16 = vld1q_u8((nk_u8_t const *)a);
395
+ b_u8x16 = vld1q_u8((nk_u8_t const *)b);
396
+ a += 16, b += 16, n_bytes -= 16;
397
+ }
398
+
399
+ uint8x16_t a_low_u8x16 = vandq_u8(a_u8x16, nibble_mask_u8x16);
400
+ uint8x16_t a_high_u8x16 = vshrq_n_u8(a_u8x16, 4);
401
+ uint8x16_t b_low_u8x16 = vandq_u8(b_u8x16, nibble_mask_u8x16);
402
+ uint8x16_t b_high_u8x16 = vshrq_n_u8(b_u8x16, 4);
403
+
404
+ ab_u32x4 = vdotq_u32(ab_u32x4, a_low_u8x16, b_low_u8x16);
405
+ ab_u32x4 = vdotq_u32(ab_u32x4, a_high_u8x16, b_high_u8x16);
406
+ a2_u32x4 = vdotq_u32(a2_u32x4, a_low_u8x16, a_low_u8x16);
407
+ a2_u32x4 = vdotq_u32(a2_u32x4, a_high_u8x16, a_high_u8x16);
408
+ b2_u32x4 = vdotq_u32(b2_u32x4, b_low_u8x16, b_low_u8x16);
409
+ b2_u32x4 = vdotq_u32(b2_u32x4, b_high_u8x16, b_high_u8x16);
410
+
411
+ if (n_bytes) goto nk_angular_u4_neonsdot_cycle;
412
+
413
+ *result = nk_angular_normalize_f32_neon_((nk_f32_t)vaddvq_u32(ab_u32x4), (nk_f32_t)vaddvq_u32(a2_u32x4),
414
+ (nk_f32_t)vaddvq_u32(b2_u32x4));
247
415
  }
248
416
 
249
417
  #if defined(__clang__)