numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,165 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for NEON BF16.
3
+ * @file include/numkong/spatial/neonbfdot.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_neonbfdot_instructions ARM NEON BF16 Instructions (ARMv8.6-BF16)
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * A76 M4+/V1+/Oryon
13
+ * vbfdotq_f32 BFDOT (V.4S, V.8H, V.8H) 3cy 2/cy 4/cy
14
+ * vcvt_f32_bf16 SHLL (V.4S, V.4H, #16) 3cy 2/cy 4/cy
15
+ * vld1q_bf16 LD1 (V.8H) 4cy 2/cy 3/cy
16
+ * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
17
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
18
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
19
+ * vaddvq_f64 FADDP (V.2D) 3cy 1/cy 2/cy
20
+ *
21
+ * The ARMv8.6-BF16 extension provides BFDOT for accelerated dot products on BF16 data, useful for
22
+ * angular distance (cosine similarity) computations. BF16's larger exponent range (matching FP32)
23
+ * prevents overflow during norm accumulation compared to FP16.
24
+ *
25
+ * For L2 distance, inputs are widened to F32 for subtraction, and the squared differences are
26
+ * accumulated with F32 FMA. Angular distance leverages BFDOT directly since it only requires dot
27
+ * products, not element-wise differences.
28
+ */
29
+ #ifndef NK_SPATIAL_NEONBFDOT_H
30
+ #define NK_SPATIAL_NEONBFDOT_H
31
+
32
+ #if NK_TARGET_ARM_
33
+ #if NK_TARGET_NEONBFDOT
34
+
35
+ #include "numkong/types.h"
36
+ #include "numkong/cast/serial.h" // `nk_partial_load_b16x4_serial_`
37
+ #include "numkong/reduce/neon.h" // `nk_partial_load_b16x8_serial_`
38
+ #include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
39
+
40
+ #if defined(__cplusplus)
41
+ extern "C" {
42
+ #endif
43
+
44
+ #if defined(__clang__)
45
+ #pragma clang attribute push(__attribute__((target("arch=armv8.6-a+simd+bf16"))), apply_to = function)
46
+ #elif defined(__GNUC__)
47
+ #pragma GCC push_options
48
+ #pragma GCC target("arch=armv8.6-a+simd+bf16")
49
+ #endif
50
+
51
+ NK_PUBLIC void nk_angular_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
52
+
53
+ // Similar to `nk_angular_i8_neon`, we can use the `BFMMLA` instruction through
54
+ // the `vbfmmlaq_f32` intrinsic to compute matrix products and later drop 1/4 of values.
55
+ // The only difference is that `zip` isn't provided for `bf16` and we need to reinterpret back
56
+ // and forth before zipping. Same as with integers, on modern Arm CPUs, this "smart"
57
+ // approach is actually slower by around 25%.
58
+ //
59
+ // float32x4_t products_low_vec = vdupq_n_f32(0.0f);
60
+ // float32x4_t products_high_vec = vdupq_n_f32(0.0f);
61
+ // for (; i + 8 <= n; i += 8) {
62
+ // bfloat16x8_t a_vec = vld1q_bf16((nk_bf16_for_arm_simd_t const*)a + i);
63
+ // bfloat16x8_t b_vec = vld1q_bf16((nk_bf16_for_arm_simd_t const*)b + i);
64
+ // int16x8_t a_vec_s16 = vreinterpretq_s16_bf16(a_vec);
65
+ // int16x8_t b_vec_s16 = vreinterpretq_s16_bf16(b_vec);
66
+ // int16x8x2_t y_w_vecs_s16 = vzipq_s16(a_vec_s16, b_vec_s16);
67
+ // bfloat16x8_t y_vec = vreinterpretq_bf16_s16(y_w_vecs_s16.val[0]);
68
+ // bfloat16x8_t w_vec = vreinterpretq_bf16_s16(y_w_vecs_s16.val[1]);
69
+ // bfloat16x4_t a_low = vget_low_bf16(a_vec);
70
+ // bfloat16x4_t b_low = vget_low_bf16(b_vec);
71
+ // bfloat16x4_t a_high = vget_high_bf16(a_vec);
72
+ // bfloat16x4_t b_high = vget_high_bf16(b_vec);
73
+ // bfloat16x8_t x_vec = vcombine_bf16(a_low, b_low);
74
+ // bfloat16x8_t v_vec = vcombine_bf16(a_high, b_high);
75
+ // products_low_vec = vbfmmlaq_f32(products_low_vec, x_vec, y_vec);
76
+ // products_high_vec = vbfmmlaq_f32(products_high_vec, v_vec, w_vec);
77
+ // }
78
+ // float32x4_t products_vec = vaddq_f32(products_high_vec, products_low_vec);
79
+ // nk_f32_t a2 = products_vec[0], ab = products_vec[1], b2 = products_vec[3];
80
+ //
81
+ // Another way of accomplishing the same thing is to process the odd and even elements separately,
82
+ // using special `vbfmlaltq_f32` and `vbfmlalbq_f32` intrinsics:
83
+ //
84
+ // ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_vec, b_vec);
85
+ // ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_vec, b_vec);
86
+ // a2_high_vec = vbfmlaltq_f32(a2_high_vec, a_vec, a_vec);
87
+ // a2_low_vec = vbfmlalbq_f32(a2_low_vec, a_vec, a_vec);
88
+ // b2_high_vec = vbfmlaltq_f32(b2_high_vec, b_vec, b_vec);
89
+ // b2_low_vec = vbfmlalbq_f32(b2_low_vec, b_vec, b_vec);
90
+ //
91
+
92
+ float32x4_t dot_product_f32x4 = vdupq_n_f32(0);
93
+ float32x4_t a_norm_sq_f32x4 = vdupq_n_f32(0);
94
+ float32x4_t b_norm_sq_f32x4 = vdupq_n_f32(0);
95
+ bfloat16x8_t a_bf16x8, b_bf16x8;
96
+
97
+ nk_angular_bf16_neonbfdot_cycle:
98
+ if (n < 8) {
99
+ nk_b128_vec_t a_vec, b_vec;
100
+ nk_partial_load_b16x8_serial_(a, &a_vec, n);
101
+ nk_partial_load_b16x8_serial_(b, &b_vec, n);
102
+ a_bf16x8 = vreinterpretq_bf16_u16(a_vec.u16x8);
103
+ b_bf16x8 = vreinterpretq_bf16_u16(b_vec.u16x8);
104
+ n = 0;
105
+ }
106
+ else {
107
+ a_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)a);
108
+ b_bf16x8 = vld1q_bf16((nk_bf16_for_arm_simd_t const *)b);
109
+ n -= 8, a += 8, b += 8;
110
+ }
111
+ dot_product_f32x4 = vbfdotq_f32(dot_product_f32x4, a_bf16x8, b_bf16x8);
112
+ a_norm_sq_f32x4 = vbfdotq_f32(a_norm_sq_f32x4, a_bf16x8, a_bf16x8);
113
+ b_norm_sq_f32x4 = vbfdotq_f32(b_norm_sq_f32x4, b_bf16x8, b_bf16x8);
114
+ if (n) goto nk_angular_bf16_neonbfdot_cycle;
115
+
116
+ nk_f32_t dot_product_f32 = vaddvq_f32(dot_product_f32x4);
117
+ nk_f32_t a_norm_sq_f32 = vaddvq_f32(a_norm_sq_f32x4);
118
+ nk_f32_t b_norm_sq_f32 = vaddvq_f32(b_norm_sq_f32x4);
119
+ *result = nk_angular_normalize_f32_neon_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
120
+ }
121
+
122
+ NK_PUBLIC void nk_sqeuclidean_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
123
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
124
+ bfloat16x4_t a_bf16x4, b_bf16x4;
125
+
126
+ nk_sqeuclidean_bf16_neonbfdot_cycle:
127
+ if (n < 4) {
128
+ nk_b64_vec_t a_tail, b_tail;
129
+ nk_partial_load_b16x4_serial_(a, &a_tail, n);
130
+ nk_partial_load_b16x4_serial_(b, &b_tail, n);
131
+ a_bf16x4 = vreinterpret_bf16_u16(a_tail.u16x4);
132
+ b_bf16x4 = vreinterpret_bf16_u16(b_tail.u16x4);
133
+ n = 0;
134
+ }
135
+ else {
136
+ a_bf16x4 = vld1_bf16((nk_bf16_for_arm_simd_t const *)a);
137
+ b_bf16x4 = vld1_bf16((nk_bf16_for_arm_simd_t const *)b);
138
+ n -= 4, a += 4, b += 4;
139
+ }
140
+ float32x4_t a_f32x4 = vcvt_f32_bf16(a_bf16x4);
141
+ float32x4_t b_f32x4 = vcvt_f32_bf16(b_bf16x4);
142
+ float32x4_t diff_f32x4 = vsubq_f32(a_f32x4, b_f32x4);
143
+ sum_f32x4 = vfmaq_f32(sum_f32x4, diff_f32x4, diff_f32x4);
144
+ if (n) goto nk_sqeuclidean_bf16_neonbfdot_cycle;
145
+
146
+ *result = vaddvq_f32(sum_f32x4);
147
+ }
148
+ NK_PUBLIC void nk_euclidean_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
149
+ nk_sqeuclidean_bf16_neonbfdot(a, b, n, result);
150
+ *result = nk_f32_sqrt_neon(*result);
151
+ }
152
+
153
+ #if defined(__clang__)
154
+ #pragma clang attribute pop
155
+ #elif defined(__GNUC__)
156
+ #pragma GCC pop_options
157
+ #endif
158
+
159
+ #if defined(__cplusplus)
160
+ } // extern "C"
161
+ #endif
162
+
163
+ #endif // NK_TARGET_NEONBFDOT
164
+ #endif // NK_TARGET_ARM_
165
+ #endif // NK_SPATIAL_NEONBFDOT_H
@@ -0,0 +1,118 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for NEON FP16.
3
+ * @file include/numkong/spatial/neonhalf.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * A76 M4+/V1+/Oryon
13
+ * vfmaq_f16 FMLA (V.8H, V.8H, V.8H) 4cy 2/cy 4/cy
14
+ * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
15
+ * vld1q_f16 LD1 (V.8H) 4cy 2/cy 3/cy
16
+ * vsubq_f16 FSUB (V.8H, V.8H, V.8H) 2cy 2/cy 4/cy
17
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
18
+ *
19
+ * The ARMv8.2-FP16 extension enables native half-precision arithmetic, doubling the element count
20
+ * per vector register (8x F16 vs 4x F32). For spatial distance computations like L2 and angular
21
+ * distance, this halves memory bandwidth requirements.
22
+ *
23
+ * Inputs are widened from F16 to F32 for accumulation via FCVTL to preserve numerical precision
24
+ * during the squared difference summation. The subtraction and FMA operations use F32 precision
25
+ * in the accumulator to avoid catastrophic cancellation in distance computations.
26
+ */
27
+ #ifndef NK_SPATIAL_NEONHALF_H
28
+ #define NK_SPATIAL_NEONHALF_H
29
+
30
+ #if NK_TARGET_ARM_
31
+ #if NK_TARGET_NEONHALF
32
+
33
+ #include "numkong/types.h"
34
+ #include "numkong/cast/serial.h" // `nk_partial_load_b16x4_serial_`
35
+ #include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
36
+
37
+ #if defined(__cplusplus)
38
+ extern "C" {
39
+ #endif
40
+
41
+ #if defined(__clang__)
42
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
43
+ #elif defined(__GNUC__)
44
+ #pragma GCC push_options
45
+ #pragma GCC target("arch=armv8.2-a+simd+fp16")
46
+ #endif
47
+
48
+ NK_PUBLIC void nk_sqeuclidean_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
49
+ float32x4_t a_f32x4, b_f32x4;
50
+ float32x4_t distance_sq_f32x4 = vdupq_n_f32(0);
51
+
52
+ nk_sqeuclidean_f16_neonhalf_cycle:
53
+ if (n < 4) {
54
+ nk_b64_vec_t a_vec, b_vec;
55
+ nk_partial_load_b16x4_serial_(a, &a_vec, n);
56
+ nk_partial_load_b16x4_serial_(b, &b_vec, n);
57
+ a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
58
+ b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
59
+ n = 0;
60
+ }
61
+ else {
62
+ a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
63
+ b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
64
+ n -= 4, a += 4, b += 4;
65
+ }
66
+ float32x4_t diff_f32x4 = vsubq_f32(a_f32x4, b_f32x4);
67
+ distance_sq_f32x4 = vfmaq_f32(distance_sq_f32x4, diff_f32x4, diff_f32x4);
68
+ if (n) goto nk_sqeuclidean_f16_neonhalf_cycle;
69
+
70
+ *result = vaddvq_f32(distance_sq_f32x4);
71
+ }
72
+ NK_PUBLIC void nk_euclidean_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
73
+ nk_sqeuclidean_f16_neonhalf(a, b, n, result);
74
+ *result = nk_f32_sqrt_neon(*result);
75
+ }
76
+
77
+ NK_PUBLIC void nk_angular_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
78
+ float32x4_t dot_product_f32x4 = vdupq_n_f32(0), a_norm_sq_f32x4 = vdupq_n_f32(0), b_norm_sq_f32x4 = vdupq_n_f32(0);
79
+ float32x4_t a_f32x4, b_f32x4;
80
+
81
+ nk_angular_f16_neonhalf_cycle:
82
+ if (n < 4) {
83
+ nk_b64_vec_t a_vec, b_vec;
84
+ nk_partial_load_b16x4_serial_(a, &a_vec, n);
85
+ nk_partial_load_b16x4_serial_(b, &b_vec, n);
86
+ a_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(a_vec.u16x4));
87
+ b_f32x4 = vcvt_f32_f16(vreinterpret_f16_u16(b_vec.u16x4));
88
+ n = 0;
89
+ }
90
+ else {
91
+ a_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)a));
92
+ b_f32x4 = vcvt_f32_f16(vld1_f16((nk_f16_for_arm_simd_t const *)b));
93
+ n -= 4, a += 4, b += 4;
94
+ }
95
+ dot_product_f32x4 = vfmaq_f32(dot_product_f32x4, a_f32x4, b_f32x4);
96
+ a_norm_sq_f32x4 = vfmaq_f32(a_norm_sq_f32x4, a_f32x4, a_f32x4);
97
+ b_norm_sq_f32x4 = vfmaq_f32(b_norm_sq_f32x4, b_f32x4, b_f32x4);
98
+ if (n) goto nk_angular_f16_neonhalf_cycle;
99
+
100
+ nk_f32_t dot_product_f32 = vaddvq_f32(dot_product_f32x4);
101
+ nk_f32_t a_norm_sq_f32 = vaddvq_f32(a_norm_sq_f32x4);
102
+ nk_f32_t b_norm_sq_f32 = vaddvq_f32(b_norm_sq_f32x4);
103
+ *result = nk_angular_normalize_f32_neon_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
104
+ }
105
+
106
+ #if defined(__clang__)
107
+ #pragma clang attribute pop
108
+ #elif defined(__GNUC__)
109
+ #pragma GCC pop_options
110
+ #endif
111
+
112
+ #if defined(__cplusplus)
113
+ } // extern "C"
114
+ #endif
115
+
116
+ #endif // NK_TARGET_NEONHALF
117
+ #endif // NK_TARGET_ARM_
118
+ #endif // NK_SPATIAL_NEONHALF_H
@@ -0,0 +1,261 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for NEON SDOT.
3
+ * @file include/numkong/spatial/neonsdot.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_neonsdot_instructions ARM NEON SDOT/UDOT Instructions (FEAT_DotProd: optional in ARMv8.2, mandatory from ARMv8.4)
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * A76 M4+/V1+/Oryon
13
+ * vdotq_s32 SDOT (V.4S, V.16B, V.16B) 3cy 2/cy 4/cy
14
+ * vdotq_u32 UDOT (V.4S, V.16B, V.16B) 3cy 2/cy 4/cy
15
+ * vabdq_s8 SABD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
16
+ * vabdq_u8 UABD (V.16B, V.16B, V.16B) 2cy 2/cy 4/cy
17
+ * vld1q_s8 LD1 (V.16B) 4cy 2/cy 3/cy
18
+ * vld1q_u8 LD1 (V.16B) 4cy 2/cy 3/cy
19
+ * vaddvq_s32 ADDV (V.4S) 4cy 1/cy 2/cy
20
+ * vaddvq_u32 ADDV (V.4S) 4cy 1/cy 2/cy
21
+ *
22
+ * The DotProd extension (FEAT_DotProd) provides SDOT/UDOT for int8 dot products; together with the baseline NEON
23
+ * SABD/UABD absolute-difference instructions, this enables L2 and angular distance on quantized embeddings.
24
+ * For L2 distance, SABD computes |a-b| per byte, then UDOT squares and accumulates.
25
+ *
26
+ * Angular distance uses SDOT/UDOT directly for dot product and norm computations. This enables
27
+ * similarity search on int8-quantized embeddings, achieving 4x memory reduction vs FP32
28
+ * while maintaining reasonable precision for nearest-neighbor search applications.
29
+ */
30
+ #ifndef NK_SPATIAL_NEONSDOT_H
31
+ #define NK_SPATIAL_NEONSDOT_H
32
+
33
+ #if NK_TARGET_ARM_
34
+ #if NK_TARGET_NEONSDOT
35
+
36
+ #include "numkong/types.h"
37
+ #include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
38
+
39
+ #if defined(__cplusplus)
40
+ extern "C" {
41
+ #endif
42
+
43
+ #if defined(__clang__)
44
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod"))), apply_to = function)
45
+ #elif defined(__GNUC__)
46
+ #pragma GCC push_options
47
+ #pragma GCC target("arch=armv8.2-a+dotprod")
48
+ #endif
49
+
50
+ NK_PUBLIC void nk_sqeuclidean_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_u32_t *result) {
51
+
52
+ // The naive approach is to upcast 8-bit signed integers into 16-bit signed integers
53
+ // for subtraction, then multiply within 16-bit integers and accumulate the results
54
+ // into 32-bit integers. This approach is slow on modern Arm CPUs. On Graviton 4,
55
+ // that approach results in 17 GB/s of throughput, compared to 39 GB/s for `i8`
56
+ // dot-products.
57
+ //
58
+ // Luckily we can use the `vabdq_s8` which technically returns `i8` values, but it's a
59
+ // matter of reinterpret-casting! That approach boosts us to 33 GB/s of throughput.
60
+ uint32x4_t distance_sq_u32x4 = vdupq_n_u32(0);
61
+ nk_size_t i = 0;
62
+ for (; i + 16 <= n; i += 16) {
63
+ int8x16_t a_i8x16 = vld1q_s8(a + i);
64
+ int8x16_t b_i8x16 = vld1q_s8(b + i);
65
+ uint8x16_t diff_u8x16 = vreinterpretq_u8_s8(vabdq_s8(a_i8x16, b_i8x16));
66
+ distance_sq_u32x4 = vdotq_u32(distance_sq_u32x4, diff_u8x16, diff_u8x16);
67
+ }
68
+ nk_u32_t distance_sq_u32 = vaddvq_u32(distance_sq_u32x4);
69
+ for (; i < n; ++i) {
70
+ nk_i32_t diff_i32 = (nk_i32_t)a[i] - b[i];
71
+ distance_sq_u32 += (nk_u32_t)(diff_i32 * diff_i32);
72
+ }
73
+ *result = distance_sq_u32;
74
+ }
75
+ NK_PUBLIC void nk_euclidean_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
76
+ nk_u32_t distance_sq_u32;
77
+ nk_sqeuclidean_i8_neonsdot(a, b, n, &distance_sq_u32);
78
+ *result = nk_f32_sqrt_neon((nk_f32_t)distance_sq_u32);
79
+ }
80
+
81
+ NK_PUBLIC void nk_angular_i8_neonsdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
82
+
83
+ nk_size_t i = 0;
84
+
85
+ // Variant 1.
86
+ // If the 128-bit `vdot_s32` intrinsic is unavailable, we can use the 64-bit `vdot_s32`.
87
+ //
88
+ // int32x4_t ab_vec = vdupq_n_s32(0);
89
+ // int32x4_t a2_vec = vdupq_n_s32(0);
90
+ // int32x4_t b2_vec = vdupq_n_s32(0);
91
+ // for (nk_size_t i = 0; i != n; i += 8) {
92
+ // int16x8_t a_vec = vmovl_s8(vld1_s8(a + i));
93
+ // int16x8_t b_vec = vmovl_s8(vld1_s8(b + i));
94
+ // int16x8_t ab_part_vec = vmulq_s16(a_vec, b_vec);
95
+ // int16x8_t a2_part_vec = vmulq_s16(a_vec, a_vec);
96
+ // int16x8_t b2_part_vec = vmulq_s16(b_vec, b_vec);
97
+ // ab_vec = vaddq_s32(ab_vec, vaddq_s32(vmovl_s16(vget_high_s16(ab_part_vec)), //
98
+ // vmovl_s16(vget_low_s16(ab_part_vec))));
99
+ // a2_vec = vaddq_s32(a2_vec, vaddq_s32(vmovl_s16(vget_high_s16(a2_part_vec)), //
100
+ // vmovl_s16(vget_low_s16(a2_part_vec))));
101
+ // b2_vec = vaddq_s32(b2_vec, vaddq_s32(vmovl_s16(vget_high_s16(b2_part_vec)), //
102
+ // vmovl_s16(vget_low_s16(b2_part_vec))));
103
+ // }
104
+ //
105
+ // Variant 2.
106
+ // With the 128-bit `vdotq_s32` intrinsic, we can use the following code:
107
+ //
108
+ // for (; i + 16 <= n; i += 16) {
109
+ // int8x16_t a_vec = vld1q_s8(a + i);
110
+ // int8x16_t b_vec = vld1q_s8(b + i);
111
+ // ab_vec = vdotq_s32(ab_vec, a_vec, b_vec);
112
+ // a2_vec = vdotq_s32(a2_vec, a_vec, a_vec);
113
+ // b2_vec = vdotq_s32(b2_vec, b_vec, b_vec);
114
+ // }
115
+ //
116
+ // Variant 3.
117
+ // To use MMLA instructions, we need to reorganize the contents of the vectors.
118
+ // On input we have `a_vec` and `b_vec`:
119
+ //
120
+ // a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15]
121
+ // b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7], b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]
122
+ //
123
+ // We will be multiplying matrices of size 2x8 and 8x2. So we need to perform a few shuffles:
124
+ //
125
+ // X =
126
+ // a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7],
127
+ // b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]
128
+ // Y =
129
+ // a[0], b[0],
130
+ // a[1], b[1],
131
+ // a[2], b[2],
132
+ // a[3], b[3],
133
+ // a[4], b[4],
134
+ // a[5], b[5],
135
+ // a[6], b[6],
136
+ // a[7], b[7]
137
+ //
138
+ // V =
139
+ // a[8], a[9], a[10], a[11], a[12], a[13], a[14], a[15],
140
+ // b[8], b[9], b[10], b[11], b[12], b[13], b[14], b[15]
141
+ // W =
142
+ // a[8], b[8],
143
+ // a[9], b[9],
144
+ // a[10], b[10],
145
+ // a[11], b[11],
146
+ // a[12], b[12],
147
+ // a[13], b[13],
148
+ // a[14], b[14],
149
+ // a[15], b[15]
150
+ //
151
+ // Performing matrix multiplications we can aggregate into a matrix `products_low_vec` and `products_high_vec`:
152
+ //
153
+ // X * X, X * Y V * W, V * V
154
+ // Y * X, Y * Y W * W, W * V
155
+ //
156
+ // Of those values we need only 3/4, as the (X * Y) and (Y * X) are the same.
157
+ //
158
+ // int32x4_t products_low_vec = vdupq_n_s32(0), products_high_vec = vdupq_n_s32(0);
159
+ // int8x16_t a_low_b_low_vec, a_high_b_high_vec;
160
+ // for (; i + 16 <= n; i += 16) {
161
+ // int8x16_t a_vec = vld1q_s8(a + i);
162
+ // int8x16_t b_vec = vld1q_s8(b + i);
163
+ // int8x16x2_t y_w_vecs = vzipq_s8(a_vec, b_vec);
164
+ // int8x16_t x_vec = vcombine_s8(vget_low_s8(a_vec), vget_low_s8(b_vec));
165
+ // int8x16_t v_vec = vcombine_s8(vget_high_s8(a_vec), vget_high_s8(b_vec));
166
+ // products_low_vec = vmmlaq_s32(products_low_vec, x_vec, y_w_vecs.val[0]);
167
+ // products_high_vec = vmmlaq_s32(products_high_vec, v_vec, y_w_vecs.val[1]);
168
+ // }
169
+ // int32x4_t products_vec = vaddq_s32(products_high_vec, products_low_vec);
170
+ // nk_i32_t a2 = products_vec[0];
171
+ // nk_i32_t ab = products_vec[1];
172
+ // nk_i32_t b2 = products_vec[3];
173
+ //
174
+ // That solution is elegant, but it requires the additional `+i8mm` extension and is currently slower,
175
+ // at least on AWS Graviton 3.
176
+ int32x4_t dot_product_i32x4 = vdupq_n_s32(0);
177
+ int32x4_t a_norm_sq_i32x4 = vdupq_n_s32(0);
178
+ int32x4_t b_norm_sq_i32x4 = vdupq_n_s32(0);
179
+ for (; i + 16 <= n; i += 16) {
180
+ int8x16_t a_i8x16 = vld1q_s8(a + i);
181
+ int8x16_t b_i8x16 = vld1q_s8(b + i);
182
+ dot_product_i32x4 = vdotq_s32(dot_product_i32x4, a_i8x16, b_i8x16);
183
+ a_norm_sq_i32x4 = vdotq_s32(a_norm_sq_i32x4, a_i8x16, a_i8x16);
184
+ b_norm_sq_i32x4 = vdotq_s32(b_norm_sq_i32x4, b_i8x16, b_i8x16);
185
+ }
186
+ nk_i32_t dot_product_i32 = vaddvq_s32(dot_product_i32x4);
187
+ nk_i32_t a_norm_sq_i32 = vaddvq_s32(a_norm_sq_i32x4);
188
+ nk_i32_t b_norm_sq_i32 = vaddvq_s32(b_norm_sq_i32x4);
189
+
190
+ // Take care of the tail:
191
+ for (; i < n; ++i) {
192
+ nk_i32_t a_element_i32 = a[i], b_element_i32 = b[i];
193
+ dot_product_i32 += a_element_i32 * b_element_i32;
194
+ a_norm_sq_i32 += a_element_i32 * a_element_i32;
195
+ b_norm_sq_i32 += b_element_i32 * b_element_i32;
196
+ }
197
+
198
+ *result = nk_angular_normalize_f32_neon_(dot_product_i32, a_norm_sq_i32, b_norm_sq_i32);
199
+ }
200
+
201
+ NK_PUBLIC void nk_sqeuclidean_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_u32_t *result) {
202
+ uint32x4_t distance_sq_u32x4 = vdupq_n_u32(0);
203
+ nk_size_t i = 0;
204
+ for (; i + 16 <= n; i += 16) {
205
+ uint8x16_t a_u8x16 = vld1q_u8(a + i);
206
+ uint8x16_t b_u8x16 = vld1q_u8(b + i);
207
+ uint8x16_t diff_u8x16 = vabdq_u8(a_u8x16, b_u8x16);
208
+ distance_sq_u32x4 = vdotq_u32(distance_sq_u32x4, diff_u8x16, diff_u8x16);
209
+ }
210
+ nk_u32_t distance_sq_u32 = vaddvq_u32(distance_sq_u32x4);
211
+ for (; i < n; ++i) {
212
+ nk_i32_t diff_i32 = (nk_i32_t)a[i] - b[i];
213
+ distance_sq_u32 += (nk_u32_t)(diff_i32 * diff_i32);
214
+ }
215
+ *result = distance_sq_u32;
216
+ }
217
+ NK_PUBLIC void nk_euclidean_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
218
+ nk_u32_t d2;
219
+ nk_sqeuclidean_u8_neonsdot(a, b, n, &d2);
220
+ *result = nk_f32_sqrt_neon((nk_f32_t)d2);
221
+ }
222
+
223
+ NK_PUBLIC void nk_angular_u8_neonsdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
224
+
225
+ nk_size_t i = 0;
226
+ uint32x4_t ab_vec = vdupq_n_u32(0);
227
+ uint32x4_t a2_vec = vdupq_n_u32(0);
228
+ uint32x4_t b2_vec = vdupq_n_u32(0);
229
+ for (; i + 16 <= n; i += 16) {
230
+ uint8x16_t a_vec = vld1q_u8(a + i);
231
+ uint8x16_t b_vec = vld1q_u8(b + i);
232
+ ab_vec = vdotq_u32(ab_vec, a_vec, b_vec);
233
+ a2_vec = vdotq_u32(a2_vec, a_vec, a_vec);
234
+ b2_vec = vdotq_u32(b2_vec, b_vec, b_vec);
235
+ }
236
+ nk_u32_t ab = vaddvq_u32(ab_vec);
237
+ nk_u32_t a2 = vaddvq_u32(a2_vec);
238
+ nk_u32_t b2 = vaddvq_u32(b2_vec);
239
+
240
+ // Take care of the tail:
241
+ for (; i < n; ++i) {
242
+ nk_u32_t ai = a[i], bi = b[i];
243
+ ab += ai * bi, a2 += ai * ai, b2 += bi * bi;
244
+ }
245
+
246
+ *result = nk_angular_normalize_f32_neon_(ab, a2, b2);
247
+ }
248
+
249
+ #if defined(__clang__)
250
+ #pragma clang attribute pop
251
+ #elif defined(__GNUC__)
252
+ #pragma GCC pop_options
253
+ #endif
254
+
255
+ #if defined(__cplusplus)
256
+ } // extern "C"
257
+ #endif
258
+
259
+ #endif // NK_TARGET_NEONSDOT
260
+ #endif // NK_TARGET_ARM_
261
+ #endif // NK_SPATIAL_NEONSDOT_H