numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294) hide show
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,773 @@
1
+ /**
2
+ * @brief SIMD-accelerated Spatial Similarity Measures for NEON.
3
+ * @file include/numkong/spatial/neon.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/spatial.h
8
+ *
9
+ * @section spatial_neon_instructions Key NEON Spatial Instructions
10
+ *
11
+ * ARM NEON instructions for distance computations:
12
+ *
13
+ * Intrinsic Instruction Latency Throughput
14
+ * A76 M4+/V1+/Oryon
15
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
16
+ * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
17
+ * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
18
+ * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
19
+ * vrsqrteq_f32 FRSQRTE (V.4S, V.4S) 2cy 2/cy 2/cy
20
+ * vsqrtq_f32 FSQRT (V.4S, V.4S) 9-12cy 0.25/cy 0.25/cy
21
+ * vrecpeq_f32 FRECPE (V.4S, V.4S) 2cy 2/cy 2/cy
22
+ *
23
+ * FRSQRTE provides ~8-bit precision; two Newton-Raphson iterations via vrsqrtsq_f32 achieve
24
+ * ~23-bit precision, sufficient for f32. This is much faster than FSQRT (0.25/cy).
25
+ *
26
+ * Distance computations (L2, angular) benefit from 2x throughput on 4-pipe cores (Apple M4+,
27
+ * Graviton3+, Oryon), but FSQRT remains slow on all cores. Use rsqrt+NR when precision allows.
28
+ */
29
+ #ifndef NK_SPATIAL_NEON_H
30
+ #define NK_SPATIAL_NEON_H
31
+
32
+ #if NK_TARGET_ARM_
33
+ #if NK_TARGET_NEON
34
+
35
+ #include "numkong/types.h"
36
+ #include "numkong/scalar/neon.h" // `nk_f32_sqrt_neon`
37
+ #include "numkong/dot/neon.h" // `nk_dot_stable_sum_f64x2_neon_`
38
+
39
+ #if defined(__cplusplus)
40
+ extern "C" {
41
+ #endif
42
+
43
+ #if defined(__clang__)
44
+ #pragma clang attribute push(__attribute__((target("arch=armv8-a+simd"))), apply_to = function)
45
+ #elif defined(__GNUC__)
46
+ #pragma GCC push_options
47
+ #pragma GCC target("arch=armv8-a+simd")
48
+ #endif
49
+
50
+ /**
51
+ * @brief Reciprocal square root of 4 floats with Newton-Raphson refinement.
52
+ *
53
+ * Uses `vrsqrteq_f32` (~8-bit initial estimate) followed by two Newton-Raphson iterations
54
+ * via `vrsqrtsq_f32`, achieving ~23-bit precision — sufficient for f32.
55
+ * Much faster than `vsqrtq_f32` (2 cy vs 9-12 cy latency, 2/cy vs 0.25/cy throughput).
56
+ */
57
+ NK_INTERNAL float32x4_t nk_rsqrt_f32x4_neon_(float32x4_t x) {
58
+ float32x4_t rsqrt = vrsqrteq_f32(x);
59
+ rsqrt = vmulq_f32(rsqrt, vrsqrtsq_f32(vmulq_f32(x, rsqrt), rsqrt));
60
+ rsqrt = vmulq_f32(rsqrt, vrsqrtsq_f32(vmulq_f32(x, rsqrt), rsqrt));
61
+ return rsqrt;
62
+ }
63
+
64
+ /**
65
+ * @brief Reciprocal square root of 2 doubles with Newton-Raphson refinement.
66
+ *
67
+ * Uses `vrsqrteq_f64` (~8-bit initial estimate) followed by three Newton-Raphson iterations
68
+ * via `vrsqrtsq_f64`, achieving ~48-bit precision — reasonable for f64 distance computations
69
+ * where the final result is often narrowed to f32. For full 52-bit mantissa fidelity,
70
+ * prefer `vsqrtq_f64` instead.
71
+ */
72
+ NK_INTERNAL float64x2_t nk_rsqrt_f64x2_neon_(float64x2_t x) {
73
+ float64x2_t rsqrt = vrsqrteq_f64(x);
74
+ rsqrt = vmulq_f64(rsqrt, vrsqrtsq_f64(vmulq_f64(x, rsqrt), rsqrt));
75
+ rsqrt = vmulq_f64(rsqrt, vrsqrtsq_f64(vmulq_f64(x, rsqrt), rsqrt));
76
+ rsqrt = vmulq_f64(rsqrt, vrsqrtsq_f64(vmulq_f64(x, rsqrt), rsqrt));
77
+ return rsqrt;
78
+ }
79
+
80
+ NK_INTERNAL nk_f32_t nk_angular_normalize_f32_neon_(nk_f32_t ab, nk_f32_t a2, nk_f32_t b2) {
81
+ if (a2 == 0 && b2 == 0) return 0;
82
+ if (ab == 0) return 1;
83
+ nk_f32_t squares_arr[2] = {a2, b2};
84
+ float32x2_t squares = vld1_f32(squares_arr);
85
+ // Unlike x86, Arm NEON manuals don't explicitly mention the accuracy of their `rsqrt` approximation.
86
+ // Third-party research suggests that it's less accurate than SSE instructions, having an error of 1.5×2⁻¹².
87
+ // One or two rounds of Newton-Raphson refinement are recommended to improve the accuracy.
88
+ // https://github.com/lighttransport/embree-aarch64/issues/24
89
+ // https://github.com/lighttransport/embree-aarch64/blob/3f75f8cb4e553d13dced941b5fefd4c826835a6b/common/math/math.h#L137-L145
90
+ float32x2_t rsqrts = vrsqrte_f32(squares);
91
+ // Perform two rounds of Newton-Raphson refinement:
92
+ // https://en.wikipedia.org/wiki/Newton%27s_method
93
+ rsqrts = vmul_f32(rsqrts, vrsqrts_f32(vmul_f32(squares, rsqrts), rsqrts));
94
+ rsqrts = vmul_f32(rsqrts, vrsqrts_f32(vmul_f32(squares, rsqrts), rsqrts));
95
+ vst1_f32(squares_arr, rsqrts);
96
+ nk_f32_t result = 1 - ab * squares_arr[0] * squares_arr[1];
97
+ return result > 0 ? result : 0;
98
+ }
99
+
100
+ NK_INTERNAL nk_f64_t nk_angular_normalize_f64_neon_(nk_f64_t ab, nk_f64_t a2, nk_f64_t b2) {
101
+ if (a2 == 0 && b2 == 0) return 0;
102
+ if (ab == 0) return 1;
103
+ nk_f64_t squares_arr[2] = {a2, b2};
104
+ float64x2_t squares = vld1q_f64(squares_arr);
105
+
106
+ // Unlike x86, Arm NEON manuals don't explicitly mention the accuracy of their `rsqrt` approximation.
107
+ // Third-party research suggests that it's less accurate than SSE instructions, having an error of 1.5×2⁻¹².
108
+ // One or two rounds of Newton-Raphson refinement are recommended to improve the accuracy.
109
+ // https://github.com/lighttransport/embree-aarch64/issues/24
110
+ // https://github.com/lighttransport/embree-aarch64/blob/3f75f8cb4e553d13dced941b5fefd4c826835a6b/common/math/math.h#L137-L145
111
+ float64x2_t rsqrts_f64x2 = vrsqrteq_f64(squares);
112
+ // Perform three rounds of Newton-Raphson refinement for f64 precision (~48 bits):
113
+ // https://en.wikipedia.org/wiki/Newton%27s_method
114
+ rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares, rsqrts_f64x2), rsqrts_f64x2));
115
+ rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares, rsqrts_f64x2), rsqrts_f64x2));
116
+ rsqrts_f64x2 = vmulq_f64(rsqrts_f64x2, vrsqrtsq_f64(vmulq_f64(squares, rsqrts_f64x2), rsqrts_f64x2));
117
+ vst1q_f64(squares_arr, rsqrts_f64x2);
118
+ nk_f64_t result = 1 - ab * squares_arr[0] * squares_arr[1];
119
+ return result > 0 ? result : 0;
120
+ }
121
+
122
+ #pragma region - Traditional Floats
123
+
124
+ NK_PUBLIC void nk_sqeuclidean_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
125
+ // Accumulate in f64 for numerical stability (2 f32s per iteration, avoids slow vget_low/high)
126
+ float64x2_t sum_f64x2 = vdupq_n_f64(0);
127
+ nk_size_t i = 0;
128
+ for (; i + 2 <= n; i += 2) {
129
+ float32x2_t a_f32x2 = vld1_f32(a + i);
130
+ float32x2_t b_f32x2 = vld1_f32(b + i);
131
+ float32x2_t diff_f32x2 = vsub_f32(a_f32x2, b_f32x2);
132
+ float64x2_t diff_f64x2 = vcvt_f64_f32(diff_f32x2);
133
+ sum_f64x2 = vfmaq_f64(sum_f64x2, diff_f64x2, diff_f64x2);
134
+ }
135
+ nk_f64_t sum_f64 = vaddvq_f64(sum_f64x2);
136
+ for (; i < n; ++i) {
137
+ nk_f64_t diff_f64 = (nk_f64_t)a[i] - (nk_f64_t)b[i];
138
+ sum_f64 += diff_f64 * diff_f64;
139
+ }
140
+ *result = sum_f64;
141
+ }
142
+
143
+ NK_PUBLIC void nk_euclidean_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
144
+ nk_sqeuclidean_f32_neon(a, b, n, result);
145
+ *result = nk_f64_sqrt_neon(*result);
146
+ }
147
+
148
+ NK_PUBLIC void nk_angular_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *result) {
149
+ // Accumulate in f64 for numerical stability (2 f32s per iteration, avoids slow vget_low/high)
150
+ float64x2_t ab_f64x2 = vdupq_n_f64(0);
151
+ float64x2_t a2_f64x2 = vdupq_n_f64(0);
152
+ float64x2_t b2_f64x2 = vdupq_n_f64(0);
153
+ nk_size_t i = 0;
154
+ for (; i + 2 <= n; i += 2) {
155
+ float32x2_t a_f32x2 = vld1_f32(a + i);
156
+ float32x2_t b_f32x2 = vld1_f32(b + i);
157
+ float64x2_t a_f64x2 = vcvt_f64_f32(a_f32x2);
158
+ float64x2_t b_f64x2 = vcvt_f64_f32(b_f32x2);
159
+ ab_f64x2 = vfmaq_f64(ab_f64x2, a_f64x2, b_f64x2);
160
+ a2_f64x2 = vfmaq_f64(a2_f64x2, a_f64x2, a_f64x2);
161
+ b2_f64x2 = vfmaq_f64(b2_f64x2, b_f64x2, b_f64x2);
162
+ }
163
+ nk_f64_t ab_f64 = vaddvq_f64(ab_f64x2);
164
+ nk_f64_t a2_f64 = vaddvq_f64(a2_f64x2);
165
+ nk_f64_t b2_f64 = vaddvq_f64(b2_f64x2);
166
+ for (; i < n; ++i) {
167
+ nk_f64_t ai = (nk_f64_t)a[i], bi = (nk_f64_t)b[i];
168
+ ab_f64 += ai * bi, a2_f64 += ai * ai, b2_f64 += bi * bi;
169
+ }
170
+ *result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
171
+ }
172
+
173
+ NK_PUBLIC void nk_sqeuclidean_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
174
+ float64x2_t sum_f64x2 = vdupq_n_f64(0);
175
+ float64x2_t a_f64x2, b_f64x2;
176
+
177
+ nk_sqeuclidean_f64_neon_cycle:
178
+ if (n < 2) {
179
+ nk_b128_vec_t a_tail, b_tail;
180
+ nk_partial_load_b64x2_serial_(a, &a_tail, n);
181
+ nk_partial_load_b64x2_serial_(b, &b_tail, n);
182
+ a_f64x2 = a_tail.f64x2;
183
+ b_f64x2 = b_tail.f64x2;
184
+ n = 0;
185
+ }
186
+ else {
187
+ a_f64x2 = vld1q_f64(a);
188
+ b_f64x2 = vld1q_f64(b);
189
+ a += 2, b += 2, n -= 2;
190
+ }
191
+ float64x2_t diff_f64x2 = vsubq_f64(a_f64x2, b_f64x2);
192
+ sum_f64x2 = vfmaq_f64(sum_f64x2, diff_f64x2, diff_f64x2);
193
+ if (n) goto nk_sqeuclidean_f64_neon_cycle;
194
+
195
+ *result = vaddvq_f64(sum_f64x2);
196
+ }
197
+
198
+ NK_PUBLIC void nk_euclidean_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
199
+ nk_sqeuclidean_f64_neon(a, b, n, result);
200
+ *result = nk_f64_sqrt_neon(*result);
201
+ }
202
+
203
+ NK_PUBLIC void nk_angular_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *result) {
204
+ // Dot2 (Ogita-Rump-Oishi) for cross-product ab (may have cancellation),
205
+ // simple FMA for self-products a2/b2 (all positive, no cancellation)
206
+ float64x2_t ab_sum_f64x2 = vdupq_n_f64(0);
207
+ float64x2_t ab_compensation_f64x2 = vdupq_n_f64(0);
208
+ float64x2_t a2_f64x2 = vdupq_n_f64(0);
209
+ float64x2_t b2_f64x2 = vdupq_n_f64(0);
210
+ float64x2_t a_f64x2, b_f64x2;
211
+
212
+ nk_angular_f64_neon_cycle:
213
+ if (n < 2) {
214
+ nk_b128_vec_t a_tail, b_tail;
215
+ nk_partial_load_b64x2_serial_(a, &a_tail, n);
216
+ nk_partial_load_b64x2_serial_(b, &b_tail, n);
217
+ a_f64x2 = a_tail.f64x2;
218
+ b_f64x2 = b_tail.f64x2;
219
+ n = 0;
220
+ }
221
+ else {
222
+ a_f64x2 = vld1q_f64(a);
223
+ b_f64x2 = vld1q_f64(b);
224
+ a += 2, b += 2, n -= 2;
225
+ }
226
+ // TwoProd for ab: product = a*b, error = fma(a,b,-product)
227
+ float64x2_t product_f64x2 = vmulq_f64(a_f64x2, b_f64x2);
228
+ float64x2_t product_error_f64x2 = vnegq_f64(vfmsq_f64(product_f64x2, a_f64x2, b_f64x2));
229
+ // TwoSum: (t, q) = TwoSum(sum, product)
230
+ float64x2_t tentative_sum_f64x2 = vaddq_f64(ab_sum_f64x2, product_f64x2);
231
+ float64x2_t virtual_addend_f64x2 = vsubq_f64(tentative_sum_f64x2, ab_sum_f64x2);
232
+ float64x2_t sum_error_f64x2 = vaddq_f64(
233
+ vsubq_f64(ab_sum_f64x2, vsubq_f64(tentative_sum_f64x2, virtual_addend_f64x2)),
234
+ vsubq_f64(product_f64x2, virtual_addend_f64x2));
235
+ ab_sum_f64x2 = tentative_sum_f64x2;
236
+ ab_compensation_f64x2 = vaddq_f64(ab_compensation_f64x2, vaddq_f64(sum_error_f64x2, product_error_f64x2));
237
+ // Simple FMA for self-products (no cancellation)
238
+ a2_f64x2 = vfmaq_f64(a2_f64x2, a_f64x2, a_f64x2);
239
+ b2_f64x2 = vfmaq_f64(b2_f64x2, b_f64x2, b_f64x2);
240
+ if (n) goto nk_angular_f64_neon_cycle;
241
+
242
+ *result = nk_angular_normalize_f64_neon_( //
243
+ nk_dot_stable_sum_f64x2_neon_(ab_sum_f64x2, ab_compensation_f64x2), vaddvq_f64(a2_f64x2), vaddvq_f64(b2_f64x2));
244
+ }
245
+
246
+ #pragma endregion - Traditional Floats
247
+ #pragma region - Smaller Floats
248
+
249
+ NK_PUBLIC void nk_sqeuclidean_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
250
+ uint16x8_t a_u16x8, b_u16x8;
251
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
252
+ nk_sqeuclidean_bf16_neon_cycle:
253
+ if (n < 8) {
254
+ nk_b128_vec_t a_vec, b_vec;
255
+ nk_partial_load_b16x8_serial_(a, &a_vec, n);
256
+ nk_partial_load_b16x8_serial_(b, &b_vec, n);
257
+ a_u16x8 = a_vec.u16x8;
258
+ b_u16x8 = b_vec.u16x8;
259
+ n = 0;
260
+ }
261
+ else {
262
+ a_u16x8 = vld1q_u16((nk_u16_t const *)a);
263
+ b_u16x8 = vld1q_u16((nk_u16_t const *)b);
264
+ a += 8, b += 8, n -= 8;
265
+ }
266
+ float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
267
+ float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a_u16x8), 16));
268
+ float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
269
+ float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b_u16x8), 16));
270
+ float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
271
+ float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
272
+ sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
273
+ sum_f32x4 = vfmaq_f32(sum_f32x4, diff_high_f32x4, diff_high_f32x4);
274
+ if (n) goto nk_sqeuclidean_bf16_neon_cycle;
275
+ *result = vaddvq_f32(sum_f32x4);
276
+ }
277
+
278
+ NK_PUBLIC void nk_euclidean_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
279
+ nk_sqeuclidean_bf16_neon(a, b, n, result);
280
+ *result = nk_f32_sqrt_neon(*result);
281
+ }
282
+
283
+ NK_PUBLIC void nk_angular_bf16_neon(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
284
+ uint16x8_t a_u16x8, b_u16x8;
285
+ float32x4_t ab_f32x4 = vdupq_n_f32(0);
286
+ float32x4_t a2_f32x4 = vdupq_n_f32(0);
287
+ float32x4_t b2_f32x4 = vdupq_n_f32(0);
288
+ nk_angular_bf16_neon_cycle:
289
+ if (n < 8) {
290
+ nk_b128_vec_t a_vec, b_vec;
291
+ nk_partial_load_b16x8_serial_(a, &a_vec, n);
292
+ nk_partial_load_b16x8_serial_(b, &b_vec, n);
293
+ a_u16x8 = a_vec.u16x8;
294
+ b_u16x8 = b_vec.u16x8;
295
+ n = 0;
296
+ }
297
+ else {
298
+ a_u16x8 = vld1q_u16((nk_u16_t const *)a);
299
+ b_u16x8 = vld1q_u16((nk_u16_t const *)b);
300
+ a += 8, b += 8, n -= 8;
301
+ }
302
+ float32x4_t a_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a_u16x8), 16));
303
+ float32x4_t a_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a_u16x8), 16));
304
+ float32x4_t b_low_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b_u16x8), 16));
305
+ float32x4_t b_high_f32x4 = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b_u16x8), 16));
306
+ ab_f32x4 = vfmaq_f32(ab_f32x4, a_low_f32x4, b_low_f32x4);
307
+ ab_f32x4 = vfmaq_f32(ab_f32x4, a_high_f32x4, b_high_f32x4);
308
+ a2_f32x4 = vfmaq_f32(a2_f32x4, a_low_f32x4, a_low_f32x4);
309
+ a2_f32x4 = vfmaq_f32(a2_f32x4, a_high_f32x4, a_high_f32x4);
310
+ b2_f32x4 = vfmaq_f32(b2_f32x4, b_low_f32x4, b_low_f32x4);
311
+ b2_f32x4 = vfmaq_f32(b2_f32x4, b_high_f32x4, b_high_f32x4);
312
+ if (n) goto nk_angular_bf16_neon_cycle;
313
+ nk_f32_t ab = vaddvq_f32(ab_f32x4);
314
+ nk_f32_t a2 = vaddvq_f32(a2_f32x4);
315
+ nk_f32_t b2 = vaddvq_f32(b2_f32x4);
316
+ *result = nk_angular_normalize_f32_neon_(ab, a2, b2);
317
+ }
318
+
319
+ NK_PUBLIC void nk_sqeuclidean_e2m3_neon(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
320
+ float16x8_t a_f16x8, b_f16x8;
321
+ float32x4_t sum_f32x4 = vdupq_n_f32(0);
322
+ nk_sqeuclidean_e2m3_neon_cycle:
323
+ if (n < 8) {
324
+ nk_b64_vec_t a_vec, b_vec;
325
+ nk_partial_load_b8x8_serial_(a, &a_vec, n);
326
+ nk_partial_load_b8x8_serial_(b, &b_vec, n);
327
+ a_f16x8 = nk_e2m3x8_to_f16x8_neon_(a_vec.u8x8);
328
+ b_f16x8 = nk_e2m3x8_to_f16x8_neon_(b_vec.u8x8);
329
+ n = 0;
330
+ }
331
+ else {
332
+ a_f16x8 = nk_e2m3x8_to_f16x8_neon_(vld1_u8(a));
333
+ b_f16x8 = nk_e2m3x8_to_f16x8_neon_(vld1_u8(b));
334
+ a += 8, b += 8, n -= 8;
335
+ }
336
+ float32x4_t a_low_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
337
+ float32x4_t a_high_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
338
+ float32x4_t b_low_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
339
+ float32x4_t b_high_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
340
+ float32x4_t diff_low_f32x4 = vsubq_f32(a_low_f32x4, b_low_f32x4);
341
+ float32x4_t diff_high_f32x4 = vsubq_f32(a_high_f32x4, b_high_f32x4);
342
+ sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
343
+ sum_f32x4 = vfmaq_f32(sum_f32x4, diff_high_f32x4, diff_high_f32x4);
344
+ if (n) goto nk_sqeuclidean_e2m3_neon_cycle;
345
+ *result = vaddvq_f32(sum_f32x4);
346
+ }
347
+
348
/** @brief Euclidean distance between two e2m3 vectors: square root of `nk_sqeuclidean_e2m3_neon`. */
NK_PUBLIC void nk_euclidean_e2m3_neon(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e2m3_neon(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon(squared_distance);
}
352
+
353
/**
 *  @brief Angular distance between two e2m3 vectors of length `n`.
 *
 *  Accumulates the dot product and both sums of squares in f32 (8 elements decoded to f16 per
 *  step, low half then high half), reduces each accumulator horizontally and delegates the final
 *  normalization to `nk_angular_normalize_f32_neon_`.
 *
 *  @param a      First input vector.
 *  @param b      Second input vector.
 *  @param n      Number of elements in each vector.
 *  @param result Output: the angular distance.
 */
NK_PUBLIC void nk_angular_e2m3_neon(nk_e2m3_t const *a, nk_e2m3_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t a_sumsq_f32x4 = vdupq_n_f32(0);
    float32x4_t b_sumsq_f32x4 = vdupq_n_f32(0);
    do {
        float16x8_t a_f16x8, b_f16x8;
        if (n >= 8) {
            a_f16x8 = nk_e2m3x8_to_f16x8_neon_(vld1_u8(a));
            b_f16x8 = nk_e2m3x8_to_f16x8_neon_(vld1_u8(b));
            a += 8;
            b += 8;
            n -= 8;
        }
        else {
            // NOTE(review): assumes the partial load zero-fills the unused lanes - confirm.
            nk_b64_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b8x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b8x8_serial_(b, &b_tail_vec, n);
            a_f16x8 = nk_e2m3x8_to_f16x8_neon_(a_tail_vec.u8x8);
            b_f16x8 = nk_e2m3x8_to_f16x8_neon_(b_tail_vec.u8x8);
            n = 0;
        }
        float32x4_t a_lo_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
        float32x4_t a_hi_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
        float32x4_t b_lo_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
        float32x4_t b_hi_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
        // Inner FMA (low half) happens before the outer one (high half) for every accumulator.
        dot_f32x4 = vfmaq_f32(vfmaq_f32(dot_f32x4, a_lo_f32x4, b_lo_f32x4), a_hi_f32x4, b_hi_f32x4);
        a_sumsq_f32x4 = vfmaq_f32(vfmaq_f32(a_sumsq_f32x4, a_lo_f32x4, a_lo_f32x4), a_hi_f32x4, a_hi_f32x4);
        b_sumsq_f32x4 = vfmaq_f32(vfmaq_f32(b_sumsq_f32x4, b_lo_f32x4, b_lo_f32x4), b_hi_f32x4, b_hi_f32x4);
    } while (n != 0);
    nk_f32_t dot = vaddvq_f32(dot_f32x4);
    nk_f32_t a_sumsq = vaddvq_f32(a_sumsq_f32x4);
    nk_f32_t b_sumsq = vaddvq_f32(b_sumsq_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, a_sumsq, b_sumsq);
}
388
+
389
/**
 *  @brief Squared Euclidean distance Σ(aᵢ − bᵢ)² between two e3m2 vectors of length `n`.
 *
 *  Each step decodes 8 elements per input to f16, widens both halves to f32 and accumulates the
 *  squared differences with fused multiply-adds into a single f32x4 accumulator, which is reduced
 *  horizontally at the end. The last step (fewer than 8 elements, including `n == 0`) goes through
 *  a partial load; the loop body always runs at least once.
 *
 *  @param a      First input vector.
 *  @param b      Second input vector.
 *  @param n      Number of elements in each vector.
 *  @param result Output: the squared distance.
 */
NK_PUBLIC void nk_sqeuclidean_e3m2_neon(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t sum_f32x4 = vdupq_n_f32(0);
    do {
        float16x8_t a_f16x8, b_f16x8;
        if (n >= 8) {
            a_f16x8 = nk_e3m2x8_to_f16x8_neon_(vld1_u8(a));
            b_f16x8 = nk_e3m2x8_to_f16x8_neon_(vld1_u8(b));
            a += 8;
            b += 8;
            n -= 8;
        }
        else {
            // NOTE(review): assumes the partial load fills the unused lanes with zero bytes, so they
            // contribute nothing to the sum - confirm against nk_partial_load_b8x8_serial_.
            nk_b64_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b8x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b8x8_serial_(b, &b_tail_vec, n);
            a_f16x8 = nk_e3m2x8_to_f16x8_neon_(a_tail_vec.u8x8);
            b_f16x8 = nk_e3m2x8_to_f16x8_neon_(b_tail_vec.u8x8);
            n = 0;
        }
        float32x4_t diff_low_f32x4 =
            vsubq_f32(vcvt_f32_f16(vget_low_f16(a_f16x8)), vcvt_f32_f16(vget_low_f16(b_f16x8)));
        float32x4_t diff_high_f32x4 =
            vsubq_f32(vcvt_f32_f16(vget_high_f16(a_f16x8)), vcvt_f32_f16(vget_high_f16(b_f16x8)));
        sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
        sum_f32x4 = vfmaq_f32(sum_f32x4, diff_high_f32x4, diff_high_f32x4);
    } while (n != 0);
    *result = vaddvq_f32(sum_f32x4);
}
417
+
418
/** @brief Euclidean distance between two e3m2 vectors: square root of `nk_sqeuclidean_e3m2_neon`. */
NK_PUBLIC void nk_euclidean_e3m2_neon(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e3m2_neon(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon(squared_distance);
}
422
+
423
/**
 *  @brief Angular distance between two e3m2 vectors of length `n`.
 *
 *  Accumulates the dot product and both sums of squares in f32 (8 elements decoded to f16 per
 *  step, low half then high half), reduces each accumulator horizontally and delegates the final
 *  normalization to `nk_angular_normalize_f32_neon_`.
 *
 *  @param a      First input vector.
 *  @param b      Second input vector.
 *  @param n      Number of elements in each vector.
 *  @param result Output: the angular distance.
 */
NK_PUBLIC void nk_angular_e3m2_neon(nk_e3m2_t const *a, nk_e3m2_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t a_sumsq_f32x4 = vdupq_n_f32(0);
    float32x4_t b_sumsq_f32x4 = vdupq_n_f32(0);
    do {
        float16x8_t a_f16x8, b_f16x8;
        if (n >= 8) {
            a_f16x8 = nk_e3m2x8_to_f16x8_neon_(vld1_u8(a));
            b_f16x8 = nk_e3m2x8_to_f16x8_neon_(vld1_u8(b));
            a += 8;
            b += 8;
            n -= 8;
        }
        else {
            // NOTE(review): assumes the partial load zero-fills the unused lanes - confirm.
            nk_b64_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b8x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b8x8_serial_(b, &b_tail_vec, n);
            a_f16x8 = nk_e3m2x8_to_f16x8_neon_(a_tail_vec.u8x8);
            b_f16x8 = nk_e3m2x8_to_f16x8_neon_(b_tail_vec.u8x8);
            n = 0;
        }
        float32x4_t a_lo_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
        float32x4_t a_hi_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
        float32x4_t b_lo_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
        float32x4_t b_hi_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
        // Inner FMA (low half) happens before the outer one (high half) for every accumulator.
        dot_f32x4 = vfmaq_f32(vfmaq_f32(dot_f32x4, a_lo_f32x4, b_lo_f32x4), a_hi_f32x4, b_hi_f32x4);
        a_sumsq_f32x4 = vfmaq_f32(vfmaq_f32(a_sumsq_f32x4, a_lo_f32x4, a_lo_f32x4), a_hi_f32x4, a_hi_f32x4);
        b_sumsq_f32x4 = vfmaq_f32(vfmaq_f32(b_sumsq_f32x4, b_lo_f32x4, b_lo_f32x4), b_hi_f32x4, b_hi_f32x4);
    } while (n != 0);
    nk_f32_t dot = vaddvq_f32(dot_f32x4);
    nk_f32_t a_sumsq = vaddvq_f32(a_sumsq_f32x4);
    nk_f32_t b_sumsq = vaddvq_f32(b_sumsq_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, a_sumsq, b_sumsq);
}
458
+
459
/**
 *  @brief Squared Euclidean distance Σ(aᵢ − bᵢ)² between two e4m3 vectors of length `n`.
 *
 *  Each step decodes 8 elements per input to f16, widens both halves to f32 and accumulates the
 *  squared differences with fused multiply-adds into a single f32x4 accumulator, which is reduced
 *  horizontally at the end. The last step (fewer than 8 elements, including `n == 0`) goes through
 *  a partial load; the loop body always runs at least once.
 *
 *  @param a      First input vector.
 *  @param b      Second input vector.
 *  @param n      Number of elements in each vector.
 *  @param result Output: the squared distance.
 */
NK_PUBLIC void nk_sqeuclidean_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t sum_f32x4 = vdupq_n_f32(0);
    do {
        float16x8_t a_f16x8, b_f16x8;
        if (n >= 8) {
            a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a));
            b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b));
            a += 8;
            b += 8;
            n -= 8;
        }
        else {
            // NOTE(review): assumes the partial load fills the unused lanes with zero bytes, so they
            // contribute nothing to the sum - confirm against nk_partial_load_b8x8_serial_.
            nk_b64_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b8x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b8x8_serial_(b, &b_tail_vec, n);
            a_f16x8 = nk_e4m3x8_to_f16x8_neon_(a_tail_vec.u8x8);
            b_f16x8 = nk_e4m3x8_to_f16x8_neon_(b_tail_vec.u8x8);
            n = 0;
        }
        float32x4_t diff_low_f32x4 =
            vsubq_f32(vcvt_f32_f16(vget_low_f16(a_f16x8)), vcvt_f32_f16(vget_low_f16(b_f16x8)));
        float32x4_t diff_high_f32x4 =
            vsubq_f32(vcvt_f32_f16(vget_high_f16(a_f16x8)), vcvt_f32_f16(vget_high_f16(b_f16x8)));
        sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
        sum_f32x4 = vfmaq_f32(sum_f32x4, diff_high_f32x4, diff_high_f32x4);
    } while (n != 0);
    *result = vaddvq_f32(sum_f32x4);
}
487
+
488
/** @brief Euclidean distance between two e4m3 vectors: square root of `nk_sqeuclidean_e4m3_neon`. */
NK_PUBLIC void nk_euclidean_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e4m3_neon(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon(squared_distance);
}
492
+
493
/**
 *  @brief Angular distance between two e4m3 vectors of length `n`.
 *
 *  Accumulates the dot product and both sums of squares in f32 (8 elements decoded to f16 per
 *  step, low half then high half), reduces each accumulator horizontally and delegates the final
 *  normalization to `nk_angular_normalize_f32_neon_`.
 *
 *  @param a      First input vector.
 *  @param b      Second input vector.
 *  @param n      Number of elements in each vector.
 *  @param result Output: the angular distance.
 */
NK_PUBLIC void nk_angular_e4m3_neon(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t a_sumsq_f32x4 = vdupq_n_f32(0);
    float32x4_t b_sumsq_f32x4 = vdupq_n_f32(0);
    do {
        float16x8_t a_f16x8, b_f16x8;
        if (n >= 8) {
            a_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(a));
            b_f16x8 = nk_e4m3x8_to_f16x8_neon_(vld1_u8(b));
            a += 8;
            b += 8;
            n -= 8;
        }
        else {
            // NOTE(review): assumes the partial load zero-fills the unused lanes - confirm.
            nk_b64_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b8x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b8x8_serial_(b, &b_tail_vec, n);
            a_f16x8 = nk_e4m3x8_to_f16x8_neon_(a_tail_vec.u8x8);
            b_f16x8 = nk_e4m3x8_to_f16x8_neon_(b_tail_vec.u8x8);
            n = 0;
        }
        float32x4_t a_lo_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
        float32x4_t a_hi_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
        float32x4_t b_lo_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
        float32x4_t b_hi_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
        // Inner FMA (low half) happens before the outer one (high half) for every accumulator.
        dot_f32x4 = vfmaq_f32(vfmaq_f32(dot_f32x4, a_lo_f32x4, b_lo_f32x4), a_hi_f32x4, b_hi_f32x4);
        a_sumsq_f32x4 = vfmaq_f32(vfmaq_f32(a_sumsq_f32x4, a_lo_f32x4, a_lo_f32x4), a_hi_f32x4, a_hi_f32x4);
        b_sumsq_f32x4 = vfmaq_f32(vfmaq_f32(b_sumsq_f32x4, b_lo_f32x4, b_lo_f32x4), b_hi_f32x4, b_hi_f32x4);
    } while (n != 0);
    nk_f32_t dot = vaddvq_f32(dot_f32x4);
    nk_f32_t a_sumsq = vaddvq_f32(a_sumsq_f32x4);
    nk_f32_t b_sumsq = vaddvq_f32(b_sumsq_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, a_sumsq, b_sumsq);
}
528
+
529
/**
 *  @brief Squared Euclidean distance Σ(aᵢ − bᵢ)² between two e5m2 vectors of length `n`.
 *
 *  Each step decodes 8 elements per input to f16, widens both halves to f32 and accumulates the
 *  squared differences with fused multiply-adds into a single f32x4 accumulator, which is reduced
 *  horizontally at the end. The last step (fewer than 8 elements, including `n == 0`) goes through
 *  a partial load; the loop body always runs at least once.
 *
 *  @param a      First input vector.
 *  @param b      Second input vector.
 *  @param n      Number of elements in each vector.
 *  @param result Output: the squared distance.
 */
NK_PUBLIC void nk_sqeuclidean_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t sum_f32x4 = vdupq_n_f32(0);
    do {
        float16x8_t a_f16x8, b_f16x8;
        if (n >= 8) {
            a_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a));
            b_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b));
            a += 8;
            b += 8;
            n -= 8;
        }
        else {
            // NOTE(review): assumes the partial load fills the unused lanes with zero bytes, so they
            // contribute nothing to the sum - confirm against nk_partial_load_b8x8_serial_.
            nk_b64_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b8x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b8x8_serial_(b, &b_tail_vec, n);
            a_f16x8 = nk_e5m2x8_to_f16x8_neon_(a_tail_vec.u8x8);
            b_f16x8 = nk_e5m2x8_to_f16x8_neon_(b_tail_vec.u8x8);
            n = 0;
        }
        float32x4_t diff_low_f32x4 =
            vsubq_f32(vcvt_f32_f16(vget_low_f16(a_f16x8)), vcvt_f32_f16(vget_low_f16(b_f16x8)));
        float32x4_t diff_high_f32x4 =
            vsubq_f32(vcvt_f32_f16(vget_high_f16(a_f16x8)), vcvt_f32_f16(vget_high_f16(b_f16x8)));
        sum_f32x4 = vfmaq_f32(sum_f32x4, diff_low_f32x4, diff_low_f32x4);
        sum_f32x4 = vfmaq_f32(sum_f32x4, diff_high_f32x4, diff_high_f32x4);
    } while (n != 0);
    *result = vaddvq_f32(sum_f32x4);
}
557
+
558
/** @brief Euclidean distance between two e5m2 vectors: square root of `nk_sqeuclidean_e5m2_neon`. */
NK_PUBLIC void nk_euclidean_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    nk_f32_t squared_distance;
    nk_sqeuclidean_e5m2_neon(a, b, n, &squared_distance);
    *result = nk_f32_sqrt_neon(squared_distance);
}
562
+
563
/**
 *  @brief Angular distance between two e5m2 vectors of length `n`.
 *
 *  Accumulates the dot product and both sums of squares in f32 (8 elements decoded to f16 per
 *  step, low half then high half), reduces each accumulator horizontally and delegates the final
 *  normalization to `nk_angular_normalize_f32_neon_`.
 *
 *  @param a      First input vector.
 *  @param b      Second input vector.
 *  @param n      Number of elements in each vector.
 *  @param result Output: the angular distance.
 */
NK_PUBLIC void nk_angular_e5m2_neon(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
    float32x4_t dot_f32x4 = vdupq_n_f32(0);
    float32x4_t a_sumsq_f32x4 = vdupq_n_f32(0);
    float32x4_t b_sumsq_f32x4 = vdupq_n_f32(0);
    do {
        float16x8_t a_f16x8, b_f16x8;
        if (n >= 8) {
            a_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(a));
            b_f16x8 = nk_e5m2x8_to_f16x8_neon_(vld1_u8(b));
            a += 8;
            b += 8;
            n -= 8;
        }
        else {
            // NOTE(review): assumes the partial load zero-fills the unused lanes - confirm.
            nk_b64_vec_t a_tail_vec, b_tail_vec;
            nk_partial_load_b8x8_serial_(a, &a_tail_vec, n);
            nk_partial_load_b8x8_serial_(b, &b_tail_vec, n);
            a_f16x8 = nk_e5m2x8_to_f16x8_neon_(a_tail_vec.u8x8);
            b_f16x8 = nk_e5m2x8_to_f16x8_neon_(b_tail_vec.u8x8);
            n = 0;
        }
        float32x4_t a_lo_f32x4 = vcvt_f32_f16(vget_low_f16(a_f16x8));
        float32x4_t a_hi_f32x4 = vcvt_f32_f16(vget_high_f16(a_f16x8));
        float32x4_t b_lo_f32x4 = vcvt_f32_f16(vget_low_f16(b_f16x8));
        float32x4_t b_hi_f32x4 = vcvt_f32_f16(vget_high_f16(b_f16x8));
        // Inner FMA (low half) happens before the outer one (high half) for every accumulator.
        dot_f32x4 = vfmaq_f32(vfmaq_f32(dot_f32x4, a_lo_f32x4, b_lo_f32x4), a_hi_f32x4, b_hi_f32x4);
        a_sumsq_f32x4 = vfmaq_f32(vfmaq_f32(a_sumsq_f32x4, a_lo_f32x4, a_lo_f32x4), a_hi_f32x4, a_hi_f32x4);
        b_sumsq_f32x4 = vfmaq_f32(vfmaq_f32(b_sumsq_f32x4, b_lo_f32x4, b_lo_f32x4), b_hi_f32x4, b_hi_f32x4);
    } while (n != 0);
    nk_f32_t dot = vaddvq_f32(dot_f32x4);
    nk_f32_t a_sumsq = vaddvq_f32(a_sumsq_f32x4);
    nk_f32_t b_sumsq = vaddvq_f32(b_sumsq_f32x4);
    *result = nk_angular_normalize_f32_neon_(dot, a_sumsq, b_sumsq);
}
598
+
599
/**
 *  @brief Angular from_dot: computes 1 − dot / √(query_sumsq × target_sumsq) for 4 pairs in f64.
 *
 *  The normalization uses a true square root and division rather than the `vrsqrteq_f64` estimate
 *  refined by Newton-Raphson. The hardware estimate carries only ~8 correct bits, and each NR step
 *  roughly doubles that, so two steps stop near ~32 bits. That is well short of the 53-bit f64
 *  mantissa this "through f64" path is meant to deliver.
 *
 *  Edge cases, per lane:
 *   - product == 0 and dot == 0 → 0
 *   - product == 0 and dot != 0 → 1
 *   - otherwise the result is clamped to [0, +inf).
 *
 *  @param dots          Four dot products, lanes in `dots.f64x2s[0]` then `dots.f64x2s[1]`.
 *  @param query_sumsq   Sum of squares of the query vector, shared by all 4 pairs.
 *  @param target_sumsqs Four target sums of squares, lane-aligned with `dots`.
 *  @param results       Output: four angular distances.
 */
NK_INTERNAL void nk_angular_through_f64_from_dot_neon_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
                                                       nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
    float64x2_t const zeros_f64x2 = vdupq_n_f64(0.0);
    float64x2_t const ones_f64x2 = vdupq_n_f64(1.0);
    float64x2_t query_sumsq_f64x2 = vdupq_n_f64(query_sumsq);
    float64x2_t dots_ab_f64x2 = dots.f64x2s[0];
    float64x2_t dots_cd_f64x2 = dots.f64x2s[1];

    // products = query_sumsq × target_sumsq
    float64x2_t products_ab_f64x2 = vmulq_f64(query_sumsq_f64x2, target_sumsqs.f64x2s[0]);
    float64x2_t products_cd_f64x2 = vmulq_f64(query_sumsq_f64x2, target_sumsqs.f64x2s[1]);

    // angular = 1 − dot / √product, clamped at 0. Zero products produce ±inf/NaN here and are
    // overridden by the edge-case selects below.
    float64x2_t result_ab_f64x2 =
        vsubq_f64(ones_f64x2, vdivq_f64(dots_ab_f64x2, vsqrtq_f64(products_ab_f64x2)));
    float64x2_t result_cd_f64x2 =
        vsubq_f64(ones_f64x2, vdivq_f64(dots_cd_f64x2, vsqrtq_f64(products_cd_f64x2)));
    result_ab_f64x2 = vmaxq_f64(result_ab_f64x2, zeros_f64x2);
    result_cd_f64x2 = vmaxq_f64(result_cd_f64x2, zeros_f64x2);

    // Zero product: 0 when the dot is also zero, 1 otherwise.
    float64x2_t fallback_ab_f64x2 = vbslq_f64(vceqq_f64(dots_ab_f64x2, zeros_f64x2), zeros_f64x2, ones_f64x2);
    float64x2_t fallback_cd_f64x2 = vbslq_f64(vceqq_f64(dots_cd_f64x2, zeros_f64x2), zeros_f64x2, ones_f64x2);
    result_ab_f64x2 = vbslq_f64(vceqq_f64(products_ab_f64x2, zeros_f64x2), fallback_ab_f64x2, result_ab_f64x2);
    result_cd_f64x2 = vbslq_f64(vceqq_f64(products_cd_f64x2, zeros_f64x2), fallback_cd_f64x2, result_cd_f64x2);

    results->f64x2s[0] = result_ab_f64x2;
    results->f64x2s[1] = result_cd_f64x2;
}
656
+
657
/**
 *  @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs in f64.
 *
 *  The squared distance is formed with a fused multiply-add (sum + (−2) × dot), clamped at zero so
 *  that rounding cannot feed a negative value into the square root, then square-rooted in f64.
 *  The 4 pairs are processed as two f64x2 halves.
 *
 *  @param dots          Four dot products, lanes in `dots.f64x2s[0]` then `dots.f64x2s[1]`.
 *  @param query_sumsq   Sum of squares of the query vector, shared by all 4 pairs.
 *  @param target_sumsqs Four target sums of squares, lane-aligned with `dots`.
 *  @param results       Output: four Euclidean distances.
 */
NK_INTERNAL void nk_euclidean_through_f64_from_dot_neon_(nk_b256_vec_t dots, nk_f64_t query_sumsq,
                                                         nk_b256_vec_t target_sumsqs, nk_b256_vec_t *results) {
    float64x2_t const query_sumsq_f64x2 = vdupq_n_f64(query_sumsq);
    float64x2_t const neg_two_f64x2 = vdupq_n_f64(-2.0);
    float64x2_t const zeros_f64x2 = vdupq_n_f64(0.0);
    for (int half = 0; half < 2; ++half) {
        float64x2_t sum_sq_f64x2 = vaddq_f64(query_sumsq_f64x2, target_sumsqs.f64x2s[half]);
        float64x2_t dist_sq_f64x2 = vfmaq_f64(sum_sq_f64x2, neg_two_f64x2, dots.f64x2s[half]);
        results->f64x2s[half] = vsqrtq_f64(vmaxq_f64(dist_sq_f64x2, zeros_f64x2));
    }
}
683
+
684
/**
 *  @brief Angular from_dot: computes 1 − dot × rsqrt(query_sumsq × target_sumsq) for 4 pairs in f32.
 *
 *  The reciprocal square root is the `vrsqrteq_f32` estimate refined by two Newton-Raphson steps.
 *  A zero product (for example, an all-zero vector) needs explicit handling. The estimate of 0 is
 *  +inf, so the refinement computes 0 × inf = NaN. That NaN then survives `vmaxq_f32` and would be
 *  returned. Such lanes are resolved like the f64 variant:
 *   - dot == 0 → 0
 *   - dot != 0 → 1
 *  All other lanes are clamped to [0, +inf).
 *
 *  @param dots          Four dot products.
 *  @param query_sumsq   Sum of squares of the query vector, shared by all 4 pairs.
 *  @param target_sumsqs Four target sums of squares, lane-aligned with `dots`.
 *  @param results       Output: four angular distances.
 */
NK_INTERNAL void nk_angular_through_f32_from_dot_neon_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
                                                       nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
    float32x4_t const zeros_f32x4 = vdupq_n_f32(0.0f);
    float32x4_t const ones_f32x4 = vdupq_n_f32(1.0f);
    float32x4_t dots_f32x4 = dots.f32x4;
    float32x4_t query_sumsq_f32x4 = vdupq_n_f32(query_sumsq);
    float32x4_t products_f32x4 = vmulq_f32(query_sumsq_f32x4, target_sumsqs.f32x4);

    // rsqrt with Newton-Raphson refinement (2 iterations)
    float32x4_t rsqrt_f32x4 = vrsqrteq_f32(products_f32x4);
    rsqrt_f32x4 = vmulq_f32(rsqrt_f32x4, vrsqrtsq_f32(vmulq_f32(products_f32x4, rsqrt_f32x4), rsqrt_f32x4));
    rsqrt_f32x4 = vmulq_f32(rsqrt_f32x4, vrsqrtsq_f32(vmulq_f32(products_f32x4, rsqrt_f32x4), rsqrt_f32x4));

    float32x4_t normalized_f32x4 = vmulq_f32(dots_f32x4, rsqrt_f32x4);
    float32x4_t angular_f32x4 = vmaxq_f32(vsubq_f32(ones_f32x4, normalized_f32x4), zeros_f32x4);

    // Zero product: 0 when the dot is also zero, 1 otherwise (matches the f64 variant).
    uint32x4_t products_zero_u32x4 = vceqq_f32(products_f32x4, zeros_f32x4);
    float32x4_t fallback_f32x4 = vbslq_f32(vceqq_f32(dots_f32x4, zeros_f32x4), zeros_f32x4, ones_f32x4);
    results->f32x4 = vbslq_f32(products_zero_u32x4, fallback_f32x4, angular_f32x4);
}
700
+
701
/**
 *  @brief Euclidean from_dot: computes √(query_sumsq + target_sumsq − 2 × dot) for 4 pairs in f32.
 *
 *  The squared distance comes from a fused multiply-subtract (sum − 2 × dot). It is clamped at zero
 *  to keep rounding from producing a negative argument to the square root.
 *
 *  @param dots          Four dot products.
 *  @param query_sumsq   Sum of squares of the query vector, shared by all 4 pairs.
 *  @param target_sumsqs Four target sums of squares, lane-aligned with `dots`.
 *  @param results       Output: four Euclidean distances.
 */
NK_INTERNAL void nk_euclidean_through_f32_from_dot_neon_(nk_b128_vec_t dots, nk_f32_t query_sumsq,
                                                         nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
    float32x4_t const two_f32x4 = vdupq_n_f32(2.0f);
    float32x4_t const zeros_f32x4 = vdupq_n_f32(0.0f);
    float32x4_t sum_sq_f32x4 = vaddq_f32(vdupq_n_f32(query_sumsq), target_sumsqs.f32x4);
    float32x4_t dist_sq_f32x4 = vfmsq_f32(sum_sq_f32x4, two_f32x4, dots.f32x4);
    results->f32x4 = vsqrtq_f32(vmaxq_f32(dist_sq_f32x4, zeros_f32x4));
}
713
+
714
/**
 *  @brief Angular from_dot for i32 accumulators: 1 − dot × rsqrt(query_sumsq × target_sumsq), 4 pairs.
 *
 *  Dots and sums of squares are converted to f32; the reciprocal square root comes from
 *  `nk_rsqrt_f32x4_neon_`. Lanes with a zero product (one of the vectors is all zeros) are resolved
 *  explicitly, matching the f64 variant:
 *   - dot == 0 → 0
 *   - dot != 0 → 1
 *  Otherwise a zero dot multiplied by the rsqrt of 0 could yield NaN. NOTE(review): whether this
 *  happens depends on what nk_rsqrt_f32x4_neon_ returns for 0. The raw estimate is +inf, and
 *  0 × inf = NaN. All other lanes are clamped to [0, +inf).
 *
 *  @param dots          Four i32 dot products.
 *  @param query_sumsq   Sum of squares of the query vector, shared by all 4 pairs.
 *  @param target_sumsqs Four i32 target sums of squares, lane-aligned with `dots`.
 *  @param results       Output: four f32 angular distances.
 */
NK_INTERNAL void nk_angular_through_i32_from_dot_neon_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
                                                       nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
    float32x4_t const zeros_f32x4 = vdupq_n_f32(0.0f);
    float32x4_t const ones_f32x4 = vdupq_n_f32(1.0f);
    float32x4_t dots_f32x4 = vcvtq_f32_s32(dots.i32x4);
    float32x4_t query_sumsq_f32x4 = vdupq_n_f32((nk_f32_t)query_sumsq);
    float32x4_t products_f32x4 = vmulq_f32(query_sumsq_f32x4, vcvtq_f32_s32(target_sumsqs.i32x4));
    float32x4_t rsqrt_f32x4 = nk_rsqrt_f32x4_neon_(products_f32x4);
    float32x4_t normalized_f32x4 = vmulq_f32(dots_f32x4, rsqrt_f32x4);
    float32x4_t angular_f32x4 = vmaxq_f32(vsubq_f32(ones_f32x4, normalized_f32x4), zeros_f32x4);

    // Zero product: 0 when the dot is also zero, 1 otherwise (matches the f64 variant).
    uint32x4_t products_zero_u32x4 = vceqq_f32(products_f32x4, zeros_f32x4);
    float32x4_t fallback_f32x4 = vbslq_f32(vceqq_f32(dots_f32x4, zeros_f32x4), zeros_f32x4, ones_f32x4);
    results->f32x4 = vbslq_f32(products_zero_u32x4, fallback_f32x4, angular_f32x4);
}
725
+
726
/**
 *  @brief Euclidean from_dot for i32 accumulators: √(a² + b² − 2ab) for 4 pairs, result in f32.
 *
 *  The combination a² + b² − 2ab is evaluated exactly in 64-bit integers before any conversion to
 *  floating point. Converting each i32 term to f32 first (24-bit mantissa) drops low-order bits
 *  once the sums exceed 2^24. The subtraction then cancels the significant bits, and for nearly
 *  identical vectors the error can exceed the distance itself.
 *
 *  With |terms| < 2^31, the exact result lies within ±2^33. It therefore fits i64 and converts to
 *  f64 without loss. The value is clamped at zero (guarding against inconsistent inputs),
 *  square-rooted in f64 and narrowed to f32.
 *
 *  @param dots          Four i32 dot products.
 *  @param query_sumsq   Sum of squares of the query vector, shared by all 4 pairs.
 *  @param target_sumsqs Four i32 target sums of squares, lane-aligned with `dots`.
 *  @param results       Output: four f32 Euclidean distances.
 */
NK_INTERNAL void nk_euclidean_through_i32_from_dot_neon_(nk_b128_vec_t dots, nk_i32_t query_sumsq,
                                                         nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
    int32x4_t dots_i32x4 = dots.i32x4;
    int32x4_t target_sumsqs_i32x4 = target_sumsqs.i32x4;
    int64x2_t query_sumsq_i64x2 = vdupq_n_s64(query_sumsq);

    // dist_sq = query_sumsq + target_sumsq − 2 × dot, widened to i64 so nothing overflows
    int64x2_t dist_sq_low_i64x2 = vmlsl_n_s32(vaddw_s32(query_sumsq_i64x2, vget_low_s32(target_sumsqs_i32x4)),
                                              vget_low_s32(dots_i32x4), 2);
    int64x2_t dist_sq_high_i64x2 =
        vmlsl_high_n_s32(vaddw_high_s32(query_sumsq_i64x2, target_sumsqs_i32x4), dots_i32x4, 2);

    // Clamp and sqrt in f64, then narrow lanes 0-1 and 2-3 back into one f32x4
    float64x2_t zeros_f64x2 = vdupq_n_f64(0.0);
    float64x2_t dist_low_f64x2 = vsqrtq_f64(vmaxq_f64(vcvtq_f64_s64(dist_sq_low_i64x2), zeros_f64x2));
    float64x2_t dist_high_f64x2 = vsqrtq_f64(vmaxq_f64(vcvtq_f64_s64(dist_sq_high_i64x2), zeros_f64x2));
    results->f32x4 = vcvt_high_f32_f64(vcvt_f32_f64(dist_low_f64x2), dist_high_f64x2);
}
736
+
737
/**
 *  @brief Angular from_dot for u32 accumulators: 1 − dot × rsqrt(query_sumsq × target_sumsq), 4 pairs.
 *
 *  Dots and sums of squares are converted to f32; the reciprocal square root comes from
 *  `nk_rsqrt_f32x4_neon_`. Lanes with a zero product (one of the vectors is all zeros) are resolved
 *  explicitly, matching the f64 variant:
 *   - dot == 0 → 0
 *   - dot != 0 → 1
 *  Otherwise a zero dot multiplied by the rsqrt of 0 could yield NaN. NOTE(review): whether this
 *  happens depends on what nk_rsqrt_f32x4_neon_ returns for 0. The raw estimate is +inf, and
 *  0 × inf = NaN. All other lanes are clamped to [0, +inf).
 *
 *  @param dots          Four u32 dot products.
 *  @param query_sumsq   Sum of squares of the query vector, shared by all 4 pairs.
 *  @param target_sumsqs Four u32 target sums of squares, lane-aligned with `dots`.
 *  @param results       Output: four f32 angular distances.
 */
NK_INTERNAL void nk_angular_through_u32_from_dot_neon_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
                                                       nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
    float32x4_t const zeros_f32x4 = vdupq_n_f32(0.0f);
    float32x4_t const ones_f32x4 = vdupq_n_f32(1.0f);
    float32x4_t dots_f32x4 = vcvtq_f32_u32(dots.u32x4);
    float32x4_t query_sumsq_f32x4 = vdupq_n_f32((nk_f32_t)query_sumsq);
    float32x4_t products_f32x4 = vmulq_f32(query_sumsq_f32x4, vcvtq_f32_u32(target_sumsqs.u32x4));
    float32x4_t rsqrt_f32x4 = nk_rsqrt_f32x4_neon_(products_f32x4);
    float32x4_t normalized_f32x4 = vmulq_f32(dots_f32x4, rsqrt_f32x4);
    float32x4_t angular_f32x4 = vmaxq_f32(vsubq_f32(ones_f32x4, normalized_f32x4), zeros_f32x4);

    // Zero product: 0 when the dot is also zero, 1 otherwise (matches the f64 variant).
    uint32x4_t products_zero_u32x4 = vceqq_f32(products_f32x4, zeros_f32x4);
    float32x4_t fallback_f32x4 = vbslq_f32(vceqq_f32(dots_f32x4, zeros_f32x4), zeros_f32x4, ones_f32x4);
    results->f32x4 = vbslq_f32(products_zero_u32x4, fallback_f32x4, angular_f32x4);
}
748
+
749
/**
 *  @brief Euclidean from_dot for u32 accumulators: √(a² + b² − 2ab) for 4 pairs, result in f32.
 *
 *  The combination a² + b² − 2ab is evaluated exactly in 64-bit integers before any conversion to
 *  floating point. Converting each u32 term to f32 first (24-bit mantissa) drops low-order bits
 *  once the sums exceed 2^24. The subtraction then cancels the significant bits, and for nearly
 *  identical vectors the error can exceed the distance itself.
 *
 *  The arithmetic is done in u64, which wraps modulo 2^64. Since the exact result lies within
 *  ±2^33, reinterpreting it as i64 recovers the true signed value. That value converts to f64
 *  without loss and is clamped at zero (guarding against inconsistent inputs). It is then
 *  square-rooted in f64 and narrowed to f32.
 *
 *  @param dots          Four u32 dot products.
 *  @param query_sumsq   Sum of squares of the query vector, shared by all 4 pairs.
 *  @param target_sumsqs Four u32 target sums of squares, lane-aligned with `dots`.
 *  @param results       Output: four f32 Euclidean distances.
 */
NK_INTERNAL void nk_euclidean_through_u32_from_dot_neon_(nk_b128_vec_t dots, nk_u32_t query_sumsq,
                                                         nk_b128_vec_t target_sumsqs, nk_b128_vec_t *results) {
    uint32x4_t dots_u32x4 = dots.u32x4;
    uint32x4_t target_sumsqs_u32x4 = target_sumsqs.u32x4;
    uint64x2_t query_sumsq_u64x2 = vdupq_n_u64(query_sumsq);

    // dist_sq = query_sumsq + target_sumsq − 2 × dot, widened to 64 bits so nothing is lost
    uint64x2_t dist_sq_low_u64x2 = vmlsl_n_u32(vaddw_u32(query_sumsq_u64x2, vget_low_u32(target_sumsqs_u32x4)),
                                               vget_low_u32(dots_u32x4), 2);
    uint64x2_t dist_sq_high_u64x2 =
        vmlsl_high_n_u32(vaddw_high_u32(query_sumsq_u64x2, target_sumsqs_u32x4), dots_u32x4, 2);

    // Reinterpret as signed (exact for |value| < 2^63), clamp and sqrt in f64, then narrow to f32
    float64x2_t zeros_f64x2 = vdupq_n_f64(0.0);
    float64x2_t dist_low_f64x2 =
        vsqrtq_f64(vmaxq_f64(vcvtq_f64_s64(vreinterpretq_s64_u64(dist_sq_low_u64x2)), zeros_f64x2));
    float64x2_t dist_high_f64x2 =
        vsqrtq_f64(vmaxq_f64(vcvtq_f64_s64(vreinterpretq_s64_u64(dist_sq_high_u64x2)), zeros_f64x2));
    results->f32x4 = vcvt_high_f32_f64(vcvt_f32_f64(dist_low_f64x2), dist_high_f64x2);
}
759
+
760
+ #if defined(__clang__)
761
+ #pragma clang attribute pop
762
+ #elif defined(__GNUC__)
763
+ #pragma GCC pop_options
764
+ #endif
765
+
766
+ #if defined(__cplusplus)
767
+ } // extern "C"
768
+ #endif
769
+
770
+ #pragma endregion - Smaller Floats
771
+ #endif // NK_TARGET_NEON
772
+ #endif // NK_TARGET_ARM_
773
+ #endif // NK_SPATIAL_NEON_H