numkong 7.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (294)
  1. package/LICENSE +201 -0
  2. package/README.md +495 -0
  3. package/binding.gyp +540 -0
  4. package/c/dispatch.h +512 -0
  5. package/c/dispatch_bf16.c +389 -0
  6. package/c/dispatch_bf16c.c +52 -0
  7. package/c/dispatch_e2m3.c +263 -0
  8. package/c/dispatch_e3m2.c +243 -0
  9. package/c/dispatch_e4m3.c +276 -0
  10. package/c/dispatch_e5m2.c +272 -0
  11. package/c/dispatch_f16.c +376 -0
  12. package/c/dispatch_f16c.c +58 -0
  13. package/c/dispatch_f32.c +378 -0
  14. package/c/dispatch_f32c.c +99 -0
  15. package/c/dispatch_f64.c +296 -0
  16. package/c/dispatch_f64c.c +98 -0
  17. package/c/dispatch_i16.c +96 -0
  18. package/c/dispatch_i32.c +89 -0
  19. package/c/dispatch_i4.c +150 -0
  20. package/c/dispatch_i64.c +86 -0
  21. package/c/dispatch_i8.c +289 -0
  22. package/c/dispatch_other.c +330 -0
  23. package/c/dispatch_u1.c +148 -0
  24. package/c/dispatch_u16.c +124 -0
  25. package/c/dispatch_u32.c +118 -0
  26. package/c/dispatch_u4.c +150 -0
  27. package/c/dispatch_u64.c +102 -0
  28. package/c/dispatch_u8.c +303 -0
  29. package/c/numkong.c +950 -0
  30. package/include/README.md +573 -0
  31. package/include/module.modulemap +129 -0
  32. package/include/numkong/attention/sapphireamx.h +1361 -0
  33. package/include/numkong/attention/sme.h +2066 -0
  34. package/include/numkong/attention.h +49 -0
  35. package/include/numkong/capabilities.h +748 -0
  36. package/include/numkong/cast/README.md +262 -0
  37. package/include/numkong/cast/haswell.h +975 -0
  38. package/include/numkong/cast/icelake.h +470 -0
  39. package/include/numkong/cast/neon.h +1192 -0
  40. package/include/numkong/cast/rvv.h +1021 -0
  41. package/include/numkong/cast/sapphire.h +262 -0
  42. package/include/numkong/cast/serial.h +2262 -0
  43. package/include/numkong/cast/skylake.h +856 -0
  44. package/include/numkong/cast/v128relaxed.h +180 -0
  45. package/include/numkong/cast.h +230 -0
  46. package/include/numkong/curved/README.md +223 -0
  47. package/include/numkong/curved/genoa.h +182 -0
  48. package/include/numkong/curved/haswell.h +276 -0
  49. package/include/numkong/curved/neon.h +205 -0
  50. package/include/numkong/curved/neonbfdot.h +212 -0
  51. package/include/numkong/curved/neonhalf.h +212 -0
  52. package/include/numkong/curved/rvv.h +305 -0
  53. package/include/numkong/curved/serial.h +207 -0
  54. package/include/numkong/curved/skylake.h +457 -0
  55. package/include/numkong/curved/smef64.h +506 -0
  56. package/include/numkong/curved.h +517 -0
  57. package/include/numkong/curved.hpp +144 -0
  58. package/include/numkong/dot/README.md +425 -0
  59. package/include/numkong/dot/alder.h +563 -0
  60. package/include/numkong/dot/genoa.h +315 -0
  61. package/include/numkong/dot/haswell.h +1688 -0
  62. package/include/numkong/dot/icelake.h +883 -0
  63. package/include/numkong/dot/neon.h +818 -0
  64. package/include/numkong/dot/neonbfdot.h +244 -0
  65. package/include/numkong/dot/neonfhm.h +360 -0
  66. package/include/numkong/dot/neonhalf.h +198 -0
  67. package/include/numkong/dot/neonsdot.h +508 -0
  68. package/include/numkong/dot/rvv.h +714 -0
  69. package/include/numkong/dot/rvvbb.h +72 -0
  70. package/include/numkong/dot/rvvbf16.h +123 -0
  71. package/include/numkong/dot/rvvhalf.h +129 -0
  72. package/include/numkong/dot/sapphire.h +141 -0
  73. package/include/numkong/dot/serial.h +838 -0
  74. package/include/numkong/dot/sierra.h +405 -0
  75. package/include/numkong/dot/skylake.h +1084 -0
  76. package/include/numkong/dot/sve.h +379 -0
  77. package/include/numkong/dot/svebfdot.h +74 -0
  78. package/include/numkong/dot/svehalf.h +123 -0
  79. package/include/numkong/dot/v128relaxed.h +1258 -0
  80. package/include/numkong/dot.h +1070 -0
  81. package/include/numkong/dot.hpp +94 -0
  82. package/include/numkong/dots/README.md +496 -0
  83. package/include/numkong/dots/alder.h +114 -0
  84. package/include/numkong/dots/genoa.h +94 -0
  85. package/include/numkong/dots/haswell.h +295 -0
  86. package/include/numkong/dots/icelake.h +171 -0
  87. package/include/numkong/dots/neon.h +120 -0
  88. package/include/numkong/dots/neonbfdot.h +58 -0
  89. package/include/numkong/dots/neonfhm.h +94 -0
  90. package/include/numkong/dots/neonhalf.h +57 -0
  91. package/include/numkong/dots/neonsdot.h +108 -0
  92. package/include/numkong/dots/rvv.h +2486 -0
  93. package/include/numkong/dots/sapphireamx.h +3973 -0
  94. package/include/numkong/dots/serial.h +2844 -0
  95. package/include/numkong/dots/sierra.h +97 -0
  96. package/include/numkong/dots/skylake.h +196 -0
  97. package/include/numkong/dots/sme.h +5372 -0
  98. package/include/numkong/dots/smebi32.h +461 -0
  99. package/include/numkong/dots/smef64.h +1318 -0
  100. package/include/numkong/dots/smehalf.h +47 -0
  101. package/include/numkong/dots/v128relaxed.h +294 -0
  102. package/include/numkong/dots.h +2804 -0
  103. package/include/numkong/dots.hpp +639 -0
  104. package/include/numkong/each/README.md +469 -0
  105. package/include/numkong/each/haswell.h +1658 -0
  106. package/include/numkong/each/icelake.h +272 -0
  107. package/include/numkong/each/neon.h +1104 -0
  108. package/include/numkong/each/neonbfdot.h +212 -0
  109. package/include/numkong/each/neonhalf.h +410 -0
  110. package/include/numkong/each/rvv.h +1121 -0
  111. package/include/numkong/each/sapphire.h +477 -0
  112. package/include/numkong/each/serial.h +260 -0
  113. package/include/numkong/each/skylake.h +1562 -0
  114. package/include/numkong/each.h +2146 -0
  115. package/include/numkong/each.hpp +434 -0
  116. package/include/numkong/geospatial/README.md +147 -0
  117. package/include/numkong/geospatial/haswell.h +593 -0
  118. package/include/numkong/geospatial/neon.h +571 -0
  119. package/include/numkong/geospatial/rvv.h +701 -0
  120. package/include/numkong/geospatial/serial.h +309 -0
  121. package/include/numkong/geospatial/skylake.h +577 -0
  122. package/include/numkong/geospatial/v128relaxed.h +613 -0
  123. package/include/numkong/geospatial.h +453 -0
  124. package/include/numkong/geospatial.hpp +235 -0
  125. package/include/numkong/matrix.hpp +336 -0
  126. package/include/numkong/maxsim/README.md +187 -0
  127. package/include/numkong/maxsim/alder.h +511 -0
  128. package/include/numkong/maxsim/genoa.h +115 -0
  129. package/include/numkong/maxsim/haswell.h +553 -0
  130. package/include/numkong/maxsim/icelake.h +480 -0
  131. package/include/numkong/maxsim/neonsdot.h +394 -0
  132. package/include/numkong/maxsim/sapphireamx.h +877 -0
  133. package/include/numkong/maxsim/serial.h +490 -0
  134. package/include/numkong/maxsim/sme.h +929 -0
  135. package/include/numkong/maxsim/v128relaxed.h +280 -0
  136. package/include/numkong/maxsim.h +571 -0
  137. package/include/numkong/maxsim.hpp +133 -0
  138. package/include/numkong/mesh/README.md +227 -0
  139. package/include/numkong/mesh/haswell.h +2235 -0
  140. package/include/numkong/mesh/neon.h +1329 -0
  141. package/include/numkong/mesh/neonbfdot.h +842 -0
  142. package/include/numkong/mesh/neonhalf.h +616 -0
  143. package/include/numkong/mesh/rvv.h +916 -0
  144. package/include/numkong/mesh/serial.h +742 -0
  145. package/include/numkong/mesh/skylake.h +1135 -0
  146. package/include/numkong/mesh/v128relaxed.h +1052 -0
  147. package/include/numkong/mesh.h +652 -0
  148. package/include/numkong/mesh.hpp +762 -0
  149. package/include/numkong/numkong.h +78 -0
  150. package/include/numkong/numkong.hpp +57 -0
  151. package/include/numkong/probability/README.md +173 -0
  152. package/include/numkong/probability/haswell.h +267 -0
  153. package/include/numkong/probability/neon.h +225 -0
  154. package/include/numkong/probability/rvv.h +409 -0
  155. package/include/numkong/probability/serial.h +169 -0
  156. package/include/numkong/probability/skylake.h +324 -0
  157. package/include/numkong/probability.h +383 -0
  158. package/include/numkong/probability.hpp +120 -0
  159. package/include/numkong/random.h +50 -0
  160. package/include/numkong/random.hpp +285 -0
  161. package/include/numkong/reduce/README.md +547 -0
  162. package/include/numkong/reduce/alder.h +632 -0
  163. package/include/numkong/reduce/genoa.h +201 -0
  164. package/include/numkong/reduce/haswell.h +3783 -0
  165. package/include/numkong/reduce/icelake.h +549 -0
  166. package/include/numkong/reduce/neon.h +3841 -0
  167. package/include/numkong/reduce/neonbfdot.h +353 -0
  168. package/include/numkong/reduce/neonfhm.h +665 -0
  169. package/include/numkong/reduce/neonhalf.h +157 -0
  170. package/include/numkong/reduce/neonsdot.h +357 -0
  171. package/include/numkong/reduce/rvv.h +3407 -0
  172. package/include/numkong/reduce/serial.h +757 -0
  173. package/include/numkong/reduce/sierra.h +338 -0
  174. package/include/numkong/reduce/skylake.h +3792 -0
  175. package/include/numkong/reduce/v128relaxed.h +2302 -0
  176. package/include/numkong/reduce.h +1597 -0
  177. package/include/numkong/reduce.hpp +633 -0
  178. package/include/numkong/scalar/README.md +89 -0
  179. package/include/numkong/scalar/haswell.h +113 -0
  180. package/include/numkong/scalar/neon.h +122 -0
  181. package/include/numkong/scalar/neonhalf.h +70 -0
  182. package/include/numkong/scalar/rvv.h +211 -0
  183. package/include/numkong/scalar/sapphire.h +63 -0
  184. package/include/numkong/scalar/serial.h +332 -0
  185. package/include/numkong/scalar/v128relaxed.h +56 -0
  186. package/include/numkong/scalar.h +683 -0
  187. package/include/numkong/set/README.md +179 -0
  188. package/include/numkong/set/haswell.h +334 -0
  189. package/include/numkong/set/icelake.h +485 -0
  190. package/include/numkong/set/neon.h +364 -0
  191. package/include/numkong/set/rvv.h +226 -0
  192. package/include/numkong/set/rvvbb.h +117 -0
  193. package/include/numkong/set/serial.h +174 -0
  194. package/include/numkong/set/sve.h +185 -0
  195. package/include/numkong/set/v128relaxed.h +240 -0
  196. package/include/numkong/set.h +457 -0
  197. package/include/numkong/set.hpp +114 -0
  198. package/include/numkong/sets/README.md +149 -0
  199. package/include/numkong/sets/haswell.h +63 -0
  200. package/include/numkong/sets/icelake.h +66 -0
  201. package/include/numkong/sets/neon.h +61 -0
  202. package/include/numkong/sets/serial.h +43 -0
  203. package/include/numkong/sets/smebi32.h +1099 -0
  204. package/include/numkong/sets/v128relaxed.h +58 -0
  205. package/include/numkong/sets.h +339 -0
  206. package/include/numkong/sparse/README.md +156 -0
  207. package/include/numkong/sparse/icelake.h +463 -0
  208. package/include/numkong/sparse/neon.h +288 -0
  209. package/include/numkong/sparse/serial.h +117 -0
  210. package/include/numkong/sparse/sve2.h +507 -0
  211. package/include/numkong/sparse/turin.h +322 -0
  212. package/include/numkong/sparse.h +363 -0
  213. package/include/numkong/sparse.hpp +113 -0
  214. package/include/numkong/spatial/README.md +435 -0
  215. package/include/numkong/spatial/alder.h +607 -0
  216. package/include/numkong/spatial/genoa.h +290 -0
  217. package/include/numkong/spatial/haswell.h +960 -0
  218. package/include/numkong/spatial/icelake.h +586 -0
  219. package/include/numkong/spatial/neon.h +773 -0
  220. package/include/numkong/spatial/neonbfdot.h +165 -0
  221. package/include/numkong/spatial/neonhalf.h +118 -0
  222. package/include/numkong/spatial/neonsdot.h +261 -0
  223. package/include/numkong/spatial/rvv.h +984 -0
  224. package/include/numkong/spatial/rvvbf16.h +123 -0
  225. package/include/numkong/spatial/rvvhalf.h +117 -0
  226. package/include/numkong/spatial/sapphire.h +343 -0
  227. package/include/numkong/spatial/serial.h +346 -0
  228. package/include/numkong/spatial/sierra.h +323 -0
  229. package/include/numkong/spatial/skylake.h +606 -0
  230. package/include/numkong/spatial/sve.h +224 -0
  231. package/include/numkong/spatial/svebfdot.h +122 -0
  232. package/include/numkong/spatial/svehalf.h +109 -0
  233. package/include/numkong/spatial/v128relaxed.h +717 -0
  234. package/include/numkong/spatial.h +1425 -0
  235. package/include/numkong/spatial.hpp +183 -0
  236. package/include/numkong/spatials/README.md +580 -0
  237. package/include/numkong/spatials/alder.h +94 -0
  238. package/include/numkong/spatials/genoa.h +94 -0
  239. package/include/numkong/spatials/haswell.h +219 -0
  240. package/include/numkong/spatials/icelake.h +113 -0
  241. package/include/numkong/spatials/neon.h +109 -0
  242. package/include/numkong/spatials/neonbfdot.h +60 -0
  243. package/include/numkong/spatials/neonfhm.h +92 -0
  244. package/include/numkong/spatials/neonhalf.h +58 -0
  245. package/include/numkong/spatials/neonsdot.h +109 -0
  246. package/include/numkong/spatials/rvv.h +1960 -0
  247. package/include/numkong/spatials/sapphireamx.h +1149 -0
  248. package/include/numkong/spatials/serial.h +226 -0
  249. package/include/numkong/spatials/sierra.h +96 -0
  250. package/include/numkong/spatials/skylake.h +184 -0
  251. package/include/numkong/spatials/sme.h +1901 -0
  252. package/include/numkong/spatials/smef64.h +465 -0
  253. package/include/numkong/spatials/v128relaxed.h +240 -0
  254. package/include/numkong/spatials.h +3021 -0
  255. package/include/numkong/spatials.hpp +508 -0
  256. package/include/numkong/tensor.hpp +1592 -0
  257. package/include/numkong/trigonometry/README.md +184 -0
  258. package/include/numkong/trigonometry/haswell.h +652 -0
  259. package/include/numkong/trigonometry/neon.h +639 -0
  260. package/include/numkong/trigonometry/rvv.h +699 -0
  261. package/include/numkong/trigonometry/serial.h +703 -0
  262. package/include/numkong/trigonometry/skylake.h +721 -0
  263. package/include/numkong/trigonometry/v128relaxed.h +666 -0
  264. package/include/numkong/trigonometry.h +467 -0
  265. package/include/numkong/trigonometry.hpp +166 -0
  266. package/include/numkong/types.h +1384 -0
  267. package/include/numkong/types.hpp +5603 -0
  268. package/include/numkong/vector.hpp +698 -0
  269. package/javascript/README.md +246 -0
  270. package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
  271. package/javascript/dist/cjs/numkong-wasm.js +617 -0
  272. package/javascript/dist/cjs/numkong.d.ts +343 -0
  273. package/javascript/dist/cjs/numkong.js +523 -0
  274. package/javascript/dist/cjs/package.json +3 -0
  275. package/javascript/dist/cjs/types.d.ts +284 -0
  276. package/javascript/dist/cjs/types.js +653 -0
  277. package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
  278. package/javascript/dist/esm/numkong-wasm.js +595 -0
  279. package/javascript/dist/esm/numkong.d.ts +343 -0
  280. package/javascript/dist/esm/numkong.js +452 -0
  281. package/javascript/dist/esm/package.json +3 -0
  282. package/javascript/dist/esm/types.d.ts +284 -0
  283. package/javascript/dist/esm/types.js +630 -0
  284. package/javascript/dist-package-cjs.json +3 -0
  285. package/javascript/dist-package-esm.json +3 -0
  286. package/javascript/node-gyp-build.d.ts +1 -0
  287. package/javascript/numkong-wasm.ts +756 -0
  288. package/javascript/numkong.c +689 -0
  289. package/javascript/numkong.ts +575 -0
  290. package/javascript/tsconfig-base.json +39 -0
  291. package/javascript/tsconfig-cjs.json +8 -0
  292. package/javascript/tsconfig-esm.json +8 -0
  293. package/javascript/types.ts +674 -0
  294. package/package.json +87 -0
@@ -0,0 +1,616 @@
1
+ /**
2
+ * @brief SIMD-accelerated Point Cloud Alignment for NEON FP16.
3
+ * @file include/numkong/mesh/neonhalf.h
4
+ * @author Ash Vardanian
5
+ * @date December 27, 2025
6
+ *
7
+ * @sa include/numkong/mesh.h
8
+ *
9
+ * @section mesh_neonhalf_instructions ARM NEON FP16 Instructions (ARMv8.2-FP16)
10
+ *
11
+ * Intrinsic Instruction Latency Throughput
12
+ * A76 M4+/V1+/Oryon
13
+ * vld3_f16 LD3 (V.4H x 3) 6cy 1/cy 2/cy
14
+ * vcvt_f32_f16 FCVTL (V.4S, V.4H) 3cy 2/cy 4/cy
15
+ * vfmaq_f32 FMLA (V.4S, V.4S, V.4S) 4cy 2/cy 4/cy
16
+ * vaddq_f32 FADD (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
17
+ * vsubq_f32 FSUB (V.4S, V.4S, V.4S) 2cy 2/cy 4/cy
18
+ * vmulq_f32 FMUL (V.4S, V.4S, V.4S) 3cy 2/cy 4/cy
19
+ * vdupq_n_f32 DUP (V.4S, scalar) 2cy 2/cy 4/cy
20
+ * vaddvq_f32 FADDP+FADDP (V.4S) 4cy 1/cy 2/cy
21
+ *
22
+ * Mesh alignment algorithms (RMSD, Kabsch, Umeyama) for 3D point cloud registration using F16 input
23
+ * with F32 intermediate precision. VLD3 provides efficient stride-3 deinterleaving for XYZ triplets,
24
+ * then FCVTL widens to F32 for rotation matrix and centroid computations.
25
+ *
26
+ * These algorithms compute optimal rigid body (Kabsch) or similarity (Umeyama) transformations
27
+ * between point sets, commonly used in structural biology (protein alignment) and computer vision.
28
+ * F16 storage halves memory for large point clouds while F32 arithmetic ensures numerical stability.
29
+ */
30
+ #ifndef NK_MESH_NEONHALF_H
31
+ #define NK_MESH_NEONHALF_H
32
+
33
+ #if NK_TARGET_ARM_
34
+ #if NK_TARGET_NEONHALF
35
+
36
+ #include "numkong/types.h"
37
+ #include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
38
+
39
+ #if defined(__cplusplus)
40
+ extern "C" {
41
+ #endif
42
+
43
+ #if defined(__clang__)
44
+ #pragma clang attribute push(__attribute__((target("arch=armv8.2-a+simd+fp16"))), apply_to = function)
45
+ #elif defined(__GNUC__)
46
+ #pragma GCC push_options
47
+ #pragma GCC target("arch=armv8.2-a+simd+fp16")
48
+ #endif
49
+
50
+ NK_INTERNAL void nk_deinterleave_f16x4_to_f32x4_neonhalf_(nk_f16_t const *ptr, float32x4_t *x_out, float32x4_t *y_out,
51
+ float32x4_t *z_out) {
52
+ // Deinterleave 12 f16 values (4 xyz triplets) into separate x, y, z vectors.
53
+ // Uses NEON vld3_f16 for efficient stride-3 deinterleaving, then converts to f32.
54
+ //
55
+ // Input: 12 contiguous f16 values [x0,y0,z0, x1,y1,z1, x2,y2,z2, x3,y3,z3]
56
+ // Output: x[4], y[4], z[4] vectors in f32
57
+ float16x4x3_t xyz = vld3_f16((nk_f16_for_arm_simd_t const *)ptr);
58
+ *x_out = vcvt_f32_f16(xyz.val[0]);
59
+ *y_out = vcvt_f32_f16(xyz.val[1]);
60
+ *z_out = vcvt_f32_f16(xyz.val[2]);
61
+ }
62
+
63
+ NK_INTERNAL void nk_partial_deinterleave_f16_to_f32x4_neonhalf_(nk_f16_t const *ptr, nk_size_t n_points,
64
+ float32x4_t *x_out, float32x4_t *y_out,
65
+ float32x4_t *z_out) {
66
+ nk_u16_t buf[12] = {0};
67
+ nk_u16_t const *src = (nk_u16_t const *)ptr;
68
+ for (nk_size_t k = 0; k < n_points * 3; ++k) buf[k] = src[k];
69
+ nk_deinterleave_f16x4_to_f32x4_neonhalf_((nk_f16_t const *)buf, x_out, y_out, z_out);
70
+ }
71
+
72
+ NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_neonhalf_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
73
+ nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
74
+ nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
75
+ nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
76
+ nk_f32_t centroid_b_z) {
77
+ // Compute sum of squared differences after rigid transformation.
78
+ // Used by Kabsch algorithm for RMSD computation after rotation is applied.
79
+ float32x4_t const centroid_a_x_f32x4 = vdupq_n_f32(centroid_a_x);
80
+ float32x4_t const centroid_a_y_f32x4 = vdupq_n_f32(centroid_a_y);
81
+ float32x4_t const centroid_a_z_f32x4 = vdupq_n_f32(centroid_a_z);
82
+ float32x4_t const centroid_b_x_f32x4 = vdupq_n_f32(centroid_b_x);
83
+ float32x4_t const centroid_b_y_f32x4 = vdupq_n_f32(centroid_b_y);
84
+ float32x4_t const centroid_b_z_f32x4 = vdupq_n_f32(centroid_b_z);
85
+ float32x4_t const scale_f32x4 = vdupq_n_f32(scale);
86
+
87
+ // Load rotation matrix elements
88
+ float32x4_t const r00_f32x4 = vdupq_n_f32(r[0]), r01_f32x4 = vdupq_n_f32(r[1]), r02_f32x4 = vdupq_n_f32(r[2]);
89
+ float32x4_t const r10_f32x4 = vdupq_n_f32(r[3]), r11_f32x4 = vdupq_n_f32(r[4]), r12_f32x4 = vdupq_n_f32(r[5]);
90
+ float32x4_t const r20_f32x4 = vdupq_n_f32(r[6]), r21_f32x4 = vdupq_n_f32(r[7]), r22_f32x4 = vdupq_n_f32(r[8]);
91
+
92
+ float32x4_t sum_squared_f32x4 = vdupq_n_f32(0);
93
+ float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
94
+
95
+ nk_size_t j = 0;
96
+ for (; j + 4 <= n; j += 4) {
97
+ nk_deinterleave_f16x4_to_f32x4_neonhalf_(a + j * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
98
+ nk_deinterleave_f16x4_to_f32x4_neonhalf_(b + j * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
99
+
100
+ // Center points
101
+ float32x4_t pa_x_f32x4 = vsubq_f32(a_x_f32x4, centroid_a_x_f32x4);
102
+ float32x4_t pa_y_f32x4 = vsubq_f32(a_y_f32x4, centroid_a_y_f32x4);
103
+ float32x4_t pa_z_f32x4 = vsubq_f32(a_z_f32x4, centroid_a_z_f32x4);
104
+ float32x4_t pb_x_f32x4 = vsubq_f32(b_x_f32x4, centroid_b_x_f32x4);
105
+ float32x4_t pb_y_f32x4 = vsubq_f32(b_y_f32x4, centroid_b_y_f32x4);
106
+ float32x4_t pb_z_f32x4 = vsubq_f32(b_z_f32x4, centroid_b_z_f32x4);
107
+
108
+ // Apply rotation: R * pa (with optional scaling)
109
+ float32x4_t ra_x_f32x4 = vmulq_f32(
110
+ scale_f32x4,
111
+ vfmaq_f32(vfmaq_f32(vmulq_f32(r00_f32x4, pa_x_f32x4), r01_f32x4, pa_y_f32x4), r02_f32x4, pa_z_f32x4));
112
+ float32x4_t ra_y_f32x4 = vmulq_f32(
113
+ scale_f32x4,
114
+ vfmaq_f32(vfmaq_f32(vmulq_f32(r10_f32x4, pa_x_f32x4), r11_f32x4, pa_y_f32x4), r12_f32x4, pa_z_f32x4));
115
+ float32x4_t ra_z_f32x4 = vmulq_f32(
116
+ scale_f32x4,
117
+ vfmaq_f32(vfmaq_f32(vmulq_f32(r20_f32x4, pa_x_f32x4), r21_f32x4, pa_y_f32x4), r22_f32x4, pa_z_f32x4));
118
+
119
+ // Compute squared differences
120
+ float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
121
+ float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
122
+ float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
123
+
124
+ sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_x_f32x4, delta_x_f32x4);
125
+ sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_y_f32x4, delta_y_f32x4);
126
+ sum_squared_f32x4 = vfmaq_f32(sum_squared_f32x4, delta_z_f32x4, delta_z_f32x4);
127
+ }
128
+
129
+ // Reduce to scalar
130
+ nk_f32_t sum_squared = vaddvq_f32(sum_squared_f32x4);
131
+
132
+ if (j < n) {
133
+ float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
134
+ nk_partial_deinterleave_f16_to_f32x4_neonhalf_(a + j * 3, n - j, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
135
+ nk_partial_deinterleave_f16_to_f32x4_neonhalf_(b + j * 3, n - j, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
136
+
137
+ uint32x4_t lane_u32x4 = {0, 1, 2, 3};
138
+ uint32x4_t valid_u32x4 = vcltq_u32(lane_u32x4, vdupq_n_u32((uint32_t)(n - j)));
139
+ float32x4_t zero_f32x4 = vdupq_n_f32(0);
140
+ a_x_f32x4 = vbslq_f32(valid_u32x4, a_x_f32x4, zero_f32x4);
141
+ a_y_f32x4 = vbslq_f32(valid_u32x4, a_y_f32x4, zero_f32x4);
142
+ a_z_f32x4 = vbslq_f32(valid_u32x4, a_z_f32x4, zero_f32x4);
143
+ b_x_f32x4 = vbslq_f32(valid_u32x4, b_x_f32x4, zero_f32x4);
144
+ b_y_f32x4 = vbslq_f32(valid_u32x4, b_y_f32x4, zero_f32x4);
145
+ b_z_f32x4 = vbslq_f32(valid_u32x4, b_z_f32x4, zero_f32x4);
146
+
147
+ float32x4_t pa_x_f32x4 = vsubq_f32(a_x_f32x4, centroid_a_x_f32x4);
148
+ float32x4_t pa_y_f32x4 = vsubq_f32(a_y_f32x4, centroid_a_y_f32x4);
149
+ float32x4_t pa_z_f32x4 = vsubq_f32(a_z_f32x4, centroid_a_z_f32x4);
150
+ float32x4_t pb_x_f32x4 = vsubq_f32(b_x_f32x4, centroid_b_x_f32x4);
151
+ float32x4_t pb_y_f32x4 = vsubq_f32(b_y_f32x4, centroid_b_y_f32x4);
152
+ float32x4_t pb_z_f32x4 = vsubq_f32(b_z_f32x4, centroid_b_z_f32x4);
153
+
154
+ float32x4_t ra_x_f32x4 = vmulq_f32(
155
+ scale_f32x4,
156
+ vfmaq_f32(vfmaq_f32(vmulq_f32(r00_f32x4, pa_x_f32x4), r01_f32x4, pa_y_f32x4), r02_f32x4, pa_z_f32x4));
157
+ float32x4_t ra_y_f32x4 = vmulq_f32(
158
+ scale_f32x4,
159
+ vfmaq_f32(vfmaq_f32(vmulq_f32(r10_f32x4, pa_x_f32x4), r11_f32x4, pa_y_f32x4), r12_f32x4, pa_z_f32x4));
160
+ float32x4_t ra_z_f32x4 = vmulq_f32(
161
+ scale_f32x4,
162
+ vfmaq_f32(vfmaq_f32(vmulq_f32(r20_f32x4, pa_x_f32x4), r21_f32x4, pa_y_f32x4), r22_f32x4, pa_z_f32x4));
163
+
164
+ float32x4_t delta_x_f32x4 = vsubq_f32(ra_x_f32x4, pb_x_f32x4);
165
+ float32x4_t delta_y_f32x4 = vsubq_f32(ra_y_f32x4, pb_y_f32x4);
166
+ float32x4_t delta_z_f32x4 = vsubq_f32(ra_z_f32x4, pb_z_f32x4);
167
+
168
+ float32x4_t tail_sum_f32x4 = vmulq_f32(delta_x_f32x4, delta_x_f32x4);
169
+ tail_sum_f32x4 = vfmaq_f32(tail_sum_f32x4, delta_y_f32x4, delta_y_f32x4);
170
+ tail_sum_f32x4 = vfmaq_f32(tail_sum_f32x4, delta_z_f32x4, delta_z_f32x4);
171
+ sum_squared += vaddvq_f32(tail_sum_f32x4);
172
+ }
173
+
174
+ return sum_squared;
175
+ }
176
+
177
+ /**
178
+ * @brief RMSD (Root Mean Square Deviation) computation using NEON FP16 with widening to FP32.
179
+ * Computes the RMS of distances between corresponding points after centroid alignment.
180
+ */
181
+ NK_PUBLIC void nk_rmsd_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
182
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
183
+ // RMSD uses identity rotation and scale=1.0
184
+ if (rotation) {
185
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0;
186
+ rotation[3] = 0, rotation[4] = 1, rotation[5] = 0;
187
+ rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
188
+ }
189
+ if (scale) *scale = 1.0f;
190
+
191
+ float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
192
+
193
+ // Accumulators for centroids and squared differences (all in f32)
194
+ float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
195
+ float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
196
+ float32x4_t sum_squared_x_f32x4 = zeros_f32x4, sum_squared_y_f32x4 = zeros_f32x4, sum_squared_z_f32x4 = zeros_f32x4;
197
+
198
+ float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
199
+ nk_size_t i = 0;
200
+
201
+ // Main loop processing 4 points at a time
202
+ for (; i + 4 <= n; i += 4) {
203
+ nk_deinterleave_f16x4_to_f32x4_neonhalf_(a + i * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
204
+ nk_deinterleave_f16x4_to_f32x4_neonhalf_(b + i * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
205
+
206
+ sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
207
+ sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
208
+ sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
209
+ sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
210
+ sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
211
+ sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
212
+
213
+ float32x4_t delta_x_f32x4 = vsubq_f32(a_x_f32x4, b_x_f32x4);
214
+ float32x4_t delta_y_f32x4 = vsubq_f32(a_y_f32x4, b_y_f32x4);
215
+ float32x4_t delta_z_f32x4 = vsubq_f32(a_z_f32x4, b_z_f32x4);
216
+
217
+ sum_squared_x_f32x4 = vfmaq_f32(sum_squared_x_f32x4, delta_x_f32x4, delta_x_f32x4);
218
+ sum_squared_y_f32x4 = vfmaq_f32(sum_squared_y_f32x4, delta_y_f32x4, delta_y_f32x4);
219
+ sum_squared_z_f32x4 = vfmaq_f32(sum_squared_z_f32x4, delta_z_f32x4, delta_z_f32x4);
220
+ }
221
+
222
+ if (i < n) {
223
+ nk_partial_deinterleave_f16_to_f32x4_neonhalf_(a + i * 3, n - i, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
224
+ nk_partial_deinterleave_f16_to_f32x4_neonhalf_(b + i * 3, n - i, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
225
+
226
+ sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
227
+ sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
228
+ sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
229
+ sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
230
+ sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
231
+ sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
232
+
233
+ float32x4_t delta_x_f32x4 = vsubq_f32(a_x_f32x4, b_x_f32x4);
234
+ float32x4_t delta_y_f32x4 = vsubq_f32(a_y_f32x4, b_y_f32x4);
235
+ float32x4_t delta_z_f32x4 = vsubq_f32(a_z_f32x4, b_z_f32x4);
236
+
237
+ sum_squared_x_f32x4 = vfmaq_f32(sum_squared_x_f32x4, delta_x_f32x4, delta_x_f32x4);
238
+ sum_squared_y_f32x4 = vfmaq_f32(sum_squared_y_f32x4, delta_y_f32x4, delta_y_f32x4);
239
+ sum_squared_z_f32x4 = vfmaq_f32(sum_squared_z_f32x4, delta_z_f32x4, delta_z_f32x4);
240
+ }
241
+
242
+ // Reduce vectors to scalars
243
+ nk_f32_t total_ax = vaddvq_f32(sum_a_x_f32x4);
244
+ nk_f32_t total_ay = vaddvq_f32(sum_a_y_f32x4);
245
+ nk_f32_t total_az = vaddvq_f32(sum_a_z_f32x4);
246
+ nk_f32_t total_bx = vaddvq_f32(sum_b_x_f32x4);
247
+ nk_f32_t total_by = vaddvq_f32(sum_b_y_f32x4);
248
+ nk_f32_t total_bz = vaddvq_f32(sum_b_z_f32x4);
249
+ nk_f32_t total_sq_x = vaddvq_f32(sum_squared_x_f32x4);
250
+ nk_f32_t total_sq_y = vaddvq_f32(sum_squared_y_f32x4);
251
+ nk_f32_t total_sq_z = vaddvq_f32(sum_squared_z_f32x4);
252
+
253
+ // Compute centroids
254
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
255
+ nk_f32_t centroid_a_x = total_ax * inv_n;
256
+ nk_f32_t centroid_a_y = total_ay * inv_n;
257
+ nk_f32_t centroid_a_z = total_az * inv_n;
258
+ nk_f32_t centroid_b_x = total_bx * inv_n;
259
+ nk_f32_t centroid_b_y = total_by * inv_n;
260
+ nk_f32_t centroid_b_z = total_bz * inv_n;
261
+
262
+ if (a_centroid) {
263
+ a_centroid[0] = centroid_a_x;
264
+ a_centroid[1] = centroid_a_y;
265
+ a_centroid[2] = centroid_a_z;
266
+ }
267
+ if (b_centroid) {
268
+ b_centroid[0] = centroid_b_x;
269
+ b_centroid[1] = centroid_b_y;
270
+ b_centroid[2] = centroid_b_z;
271
+ }
272
+
273
+ // Compute RMSD
274
+ nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
275
+ nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
276
+ nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
277
+ nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
278
+ nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
279
+
280
+ *result = nk_f32_sqrt_neon(sum_squared * inv_n - mean_diff_sq);
281
+ }
282
+
283
+ /**
284
+ * @brief Kabsch algorithm for optimal rigid body superposition using NEON FP16 with widening to FP32.
285
+ * Finds the rotation matrix R that minimizes RMSD between two point sets.
286
+ */
287
NK_PUBLIC void nk_kabsch_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                      nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    // Kabsch alignment of two interleaved-xyz f16 point clouds, widened to f32.
    // A single fused pass accumulates per-axis coordinate sums (for the centroids)
    // and the raw cross products sum(a·bᵀ); the centered covariance is recovered
    // afterwards as H = sum(a·bᵀ) − n·centroid_a·centroid_bᵀ.
    float32x4_t const zero_f32x4 = vdupq_n_f32(0);

    // Per-axis coordinate sums for both clouds.
    float32x4_t acc_a_x = zero_f32x4, acc_a_y = zero_f32x4, acc_a_z = zero_f32x4;
    float32x4_t acc_b_x = zero_f32x4, acc_b_y = zero_f32x4, acc_b_z = zero_f32x4;

    // One vector accumulator per entry of the raw 3x3 outer-product sum.
    float32x4_t acc_xx = zero_f32x4, acc_xy = zero_f32x4, acc_xz = zero_f32x4;
    float32x4_t acc_yx = zero_f32x4, acc_yy = zero_f32x4, acc_yz = zero_f32x4;
    float32x4_t acc_zx = zero_f32x4, acc_zy = zero_f32x4, acc_zz = zero_f32x4;

    float32x4_t ax, ay, az, bx, by, bz;
    nk_size_t idx = 0;
    for (; idx + 4 <= n; idx += 4) {
        nk_deinterleave_f16x4_to_f32x4_neonhalf_(a + idx * 3, &ax, &ay, &az);
        nk_deinterleave_f16x4_to_f32x4_neonhalf_(b + idx * 3, &bx, &by, &bz);

        // Coordinate sums for the centroids.
        acc_a_x = vaddq_f32(acc_a_x, ax);
        acc_a_y = vaddq_f32(acc_a_y, ay);
        acc_a_z = vaddq_f32(acc_a_z, az);
        acc_b_x = vaddq_f32(acc_b_x, bx);
        acc_b_y = vaddq_f32(acc_b_y, by);
        acc_b_z = vaddq_f32(acc_b_z, bz);

        // Raw outer products a_i · b_jᵀ.
        acc_xx = vfmaq_f32(acc_xx, ax, bx);
        acc_xy = vfmaq_f32(acc_xy, ax, by);
        acc_xz = vfmaq_f32(acc_xz, ax, bz);
        acc_yx = vfmaq_f32(acc_yx, ay, bx);
        acc_yy = vfmaq_f32(acc_yy, ay, by);
        acc_yz = vfmaq_f32(acc_yz, ay, bz);
        acc_zx = vfmaq_f32(acc_zx, az, bx);
        acc_zy = vfmaq_f32(acc_zy, az, by);
        acc_zz = vfmaq_f32(acc_zz, az, bz);
    }

    if (idx < n) {
        // Tail of 1..3 points: the partial deinterleave presumably zero-fills the
        // unused lanes (TODO confirm against its definition), so the same
        // accumulation steps remain valid.
        nk_partial_deinterleave_f16_to_f32x4_neonhalf_(a + idx * 3, n - idx, &ax, &ay, &az);
        nk_partial_deinterleave_f16_to_f32x4_neonhalf_(b + idx * 3, n - idx, &bx, &by, &bz);

        acc_a_x = vaddq_f32(acc_a_x, ax);
        acc_a_y = vaddq_f32(acc_a_y, ay);
        acc_a_z = vaddq_f32(acc_a_z, az);
        acc_b_x = vaddq_f32(acc_b_x, bx);
        acc_b_y = vaddq_f32(acc_b_y, by);
        acc_b_z = vaddq_f32(acc_b_z, bz);

        acc_xx = vfmaq_f32(acc_xx, ax, bx);
        acc_xy = vfmaq_f32(acc_xy, ax, by);
        acc_xz = vfmaq_f32(acc_xz, ax, bz);
        acc_yx = vfmaq_f32(acc_yx, ay, bx);
        acc_yy = vfmaq_f32(acc_yy, ay, by);
        acc_yz = vfmaq_f32(acc_yz, ay, bz);
        acc_zx = vfmaq_f32(acc_zx, az, bx);
        acc_zy = vfmaq_f32(acc_zy, az, by);
        acc_zz = vfmaq_f32(acc_zz, az, bz);
    }

    // Horizontal reductions to scalars.
    nk_f32_t sum_a[3] = {vaddvq_f32(acc_a_x), vaddvq_f32(acc_a_y), vaddvq_f32(acc_a_z)};
    nk_f32_t sum_b[3] = {vaddvq_f32(acc_b_x), vaddvq_f32(acc_b_y), vaddvq_f32(acc_b_z)};
    nk_f32_t raw_cov[9] = {
        vaddvq_f32(acc_xx), vaddvq_f32(acc_xy), vaddvq_f32(acc_xz), //
        vaddvq_f32(acc_yx), vaddvq_f32(acc_yy), vaddvq_f32(acc_yz), //
        vaddvq_f32(acc_zx), vaddvq_f32(acc_zy), vaddvq_f32(acc_zz),
    };

    // Centroids of both clouds.
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t mean_a[3] = {sum_a[0] * inv_n, sum_a[1] * inv_n, sum_a[2] * inv_n};
    nk_f32_t mean_b[3] = {sum_b[0] * inv_n, sum_b[1] * inv_n, sum_b[2] * inv_n};

    if (a_centroid) {
        a_centroid[0] = mean_a[0];
        a_centroid[1] = mean_a[1];
        a_centroid[2] = mean_a[2];
    }
    if (b_centroid) {
        b_centroid[0] = mean_b[0];
        b_centroid[1] = mean_b[1];
        b_centroid[2] = mean_b[2];
    }

    // Centered covariance: H = (A − centroid_A)ᵀ (B − centroid_B)
    //                        = sum(a·bᵀ) − n · centroid_a · centroid_bᵀ.
    nk_f32_t h[9];
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col) h[row * 3 + col] = raw_cov[row * 3 + col] - n * mean_a[row] * mean_b[col];

    // Optimal rotation from the SVD H = U·S·Vᵀ, taking R = V·Uᵀ.
    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(h, svd_u, svd_s, svd_v);

    nk_f32_t r[9];
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col)
            r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + //
                               svd_v[row * 3 + 1] * svd_u[col * 3 + 1] + //
                               svd_v[row * 3 + 2] * svd_u[col * 3 + 2];

    // A negative determinant means R is a reflection; flip the last column of V
    // and rebuild so the result is a proper rotation.
    nk_f32_t det_r = nk_det3x3_f32_(r);
    if (det_r < 0) {
        svd_v[2] = -svd_v[2];
        svd_v[5] = -svd_v[5];
        svd_v[8] = -svd_v[8];
        for (int row = 0; row < 3; ++row)
            for (int col = 0; col < 3; ++col)
                r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + //
                                   svd_v[row * 3 + 1] * svd_u[col * 3 + 1] + //
                                   svd_v[row * 3 + 2] * svd_u[col * 3 + 2];
    }

    if (rotation) {
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];
    }
    // Kabsch solves the rigid problem: scale is unity by definition.
    if (scale) *scale = 1.0f;

    // RMSD of B against rotated, re-centered A.
    nk_f32_t sum_squared = nk_transformed_ssd_f16_neonhalf_(a, b, n, r, 1.0f, mean_a[0], mean_a[1], mean_a[2],
                                                            mean_b[0], mean_b[1], mean_b[2]);
    *result = nk_f32_sqrt_neon(sum_squared * inv_n);
}
444
+
445
NK_PUBLIC void nk_umeyama_f16_neonhalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
                                       nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
    // Umeyama similarity alignment (rotation + uniform scale) of interleaved-xyz
    // f16 clouds, widened to f32. One fused pass accumulates per-axis sums, the
    // raw cross products sum(a·bᵀ), and sum(|a|²) for the variance of A.
    float32x4_t const zero_f32x4 = vdupq_n_f32(0);

    float32x4_t acc_a_x = zero_f32x4, acc_a_y = zero_f32x4, acc_a_z = zero_f32x4;
    float32x4_t acc_b_x = zero_f32x4, acc_b_y = zero_f32x4, acc_b_z = zero_f32x4;
    float32x4_t acc_xx = zero_f32x4, acc_xy = zero_f32x4, acc_xz = zero_f32x4;
    float32x4_t acc_yx = zero_f32x4, acc_yy = zero_f32x4, acc_yz = zero_f32x4;
    float32x4_t acc_zx = zero_f32x4, acc_zy = zero_f32x4, acc_zz = zero_f32x4;
    float32x4_t acc_a_sq = zero_f32x4; // running sum of squared A coordinates

    float32x4_t ax, ay, az, bx, by, bz;
    nk_size_t idx = 0;
    for (; idx + 4 <= n; idx += 4) {
        nk_deinterleave_f16x4_to_f32x4_neonhalf_(a + idx * 3, &ax, &ay, &az);
        nk_deinterleave_f16x4_to_f32x4_neonhalf_(b + idx * 3, &bx, &by, &bz);

        // Coordinate sums for the centroids.
        acc_a_x = vaddq_f32(acc_a_x, ax);
        acc_a_y = vaddq_f32(acc_a_y, ay);
        acc_a_z = vaddq_f32(acc_a_z, az);
        acc_b_x = vaddq_f32(acc_b_x, bx);
        acc_b_y = vaddq_f32(acc_b_y, by);
        acc_b_z = vaddq_f32(acc_b_z, bz);

        // Raw outer products a_i · b_jᵀ.
        acc_xx = vfmaq_f32(acc_xx, ax, bx);
        acc_xy = vfmaq_f32(acc_xy, ax, by);
        acc_xz = vfmaq_f32(acc_xz, ax, bz);
        acc_yx = vfmaq_f32(acc_yx, ay, bx);
        acc_yy = vfmaq_f32(acc_yy, ay, by);
        acc_yz = vfmaq_f32(acc_yz, ay, bz);
        acc_zx = vfmaq_f32(acc_zx, az, bx);
        acc_zy = vfmaq_f32(acc_zy, az, by);
        acc_zz = vfmaq_f32(acc_zz, az, bz);

        // Squared magnitudes of A, needed for the scale denominator.
        acc_a_sq = vfmaq_f32(acc_a_sq, ax, ax);
        acc_a_sq = vfmaq_f32(acc_a_sq, ay, ay);
        acc_a_sq = vfmaq_f32(acc_a_sq, az, az);
    }

    if (idx < n) {
        // Tail of 1..3 points: the partial deinterleave presumably zero-fills the
        // unused lanes (TODO confirm against its definition), so the same
        // accumulation steps remain valid.
        nk_partial_deinterleave_f16_to_f32x4_neonhalf_(a + idx * 3, n - idx, &ax, &ay, &az);
        nk_partial_deinterleave_f16_to_f32x4_neonhalf_(b + idx * 3, n - idx, &bx, &by, &bz);

        acc_a_x = vaddq_f32(acc_a_x, ax);
        acc_a_y = vaddq_f32(acc_a_y, ay);
        acc_a_z = vaddq_f32(acc_a_z, az);
        acc_b_x = vaddq_f32(acc_b_x, bx);
        acc_b_y = vaddq_f32(acc_b_y, by);
        acc_b_z = vaddq_f32(acc_b_z, bz);

        acc_xx = vfmaq_f32(acc_xx, ax, bx);
        acc_xy = vfmaq_f32(acc_xy, ax, by);
        acc_xz = vfmaq_f32(acc_xz, ax, bz);
        acc_yx = vfmaq_f32(acc_yx, ay, bx);
        acc_yy = vfmaq_f32(acc_yy, ay, by);
        acc_yz = vfmaq_f32(acc_yz, ay, bz);
        acc_zx = vfmaq_f32(acc_zx, az, bx);
        acc_zy = vfmaq_f32(acc_zy, az, by);
        acc_zz = vfmaq_f32(acc_zz, az, bz);

        acc_a_sq = vfmaq_f32(acc_a_sq, ax, ax);
        acc_a_sq = vfmaq_f32(acc_a_sq, ay, ay);
        acc_a_sq = vfmaq_f32(acc_a_sq, az, az);
    }

    // Horizontal reductions to scalars.
    nk_f32_t sum_a[3] = {vaddvq_f32(acc_a_x), vaddvq_f32(acc_a_y), vaddvq_f32(acc_a_z)};
    nk_f32_t sum_b[3] = {vaddvq_f32(acc_b_x), vaddvq_f32(acc_b_y), vaddvq_f32(acc_b_z)};
    nk_f32_t raw_cov[9] = {
        vaddvq_f32(acc_xx), vaddvq_f32(acc_xy), vaddvq_f32(acc_xz), //
        vaddvq_f32(acc_yx), vaddvq_f32(acc_yy), vaddvq_f32(acc_yz), //
        vaddvq_f32(acc_zx), vaddvq_f32(acc_zy), vaddvq_f32(acc_zz),
    };
    nk_f32_t sq_a_sum = vaddvq_f32(acc_a_sq);

    // Centroids of both clouds.
    nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
    nk_f32_t mean_a[3] = {sum_a[0] * inv_n, sum_a[1] * inv_n, sum_a[2] * inv_n};
    nk_f32_t mean_b[3] = {sum_b[0] * inv_n, sum_b[1] * inv_n, sum_b[2] * inv_n};

    if (a_centroid) {
        a_centroid[0] = mean_a[0];
        a_centroid[1] = mean_a[1];
        a_centroid[2] = mean_a[2];
    }
    if (b_centroid) {
        b_centroid[0] = mean_b[0];
        b_centroid[1] = mean_b[1];
        b_centroid[2] = mean_b[2];
    }

    // Per-point variance of A around its centroid: E[|a|²] − |E[a]|².
    nk_f32_t variance_a =
        sq_a_sum * inv_n - (mean_a[0] * mean_a[0] + mean_a[1] * mean_a[1] + mean_a[2] * mean_a[2]);

    // Centered covariance: H = sum(a·bᵀ) − n · centroid_a · centroid_bᵀ.
    nk_f32_t h[9];
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col) h[row * 3 + col] = raw_cov[row * 3 + col] - n * mean_a[row] * mean_b[col];

    // SVD H = U·S·Vᵀ, rotation candidate R = V·Uᵀ.
    nk_f32_t svd_u[9], svd_s[9], svd_v[9];
    nk_svd3x3_f32_(h, svd_u, svd_s, svd_v);

    nk_f32_t r[9];
    for (int row = 0; row < 3; ++row)
        for (int col = 0; col < 3; ++col)
            r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + //
                               svd_v[row * 3 + 1] * svd_u[col * 3 + 1] + //
                               svd_v[row * 3 + 2] * svd_u[col * 3 + 2];

    // Umeyama scale c = trace(D·S) / (n · var(A)), where D flips the smallest
    // singular value's sign when the candidate R is a reflection.
    nk_f32_t det_r = nk_det3x3_f32_(r);
    nk_f32_t sign_det = det_r < 0 ? -1.0f : 1.0f;
    nk_f32_t trace_scaled_s = svd_s[0] + svd_s[4] + sign_det * svd_s[8];
    nk_f32_t scale_factor = trace_scaled_s / ((nk_f32_t)n * variance_a);
    if (scale) *scale = scale_factor;

    // Reflection fix: flip the last column of V and rebuild R.
    if (det_r < 0) {
        svd_v[2] = -svd_v[2];
        svd_v[5] = -svd_v[5];
        svd_v[8] = -svd_v[8];
        for (int row = 0; row < 3; ++row)
            for (int col = 0; col < 3; ++col)
                r[row * 3 + col] = svd_v[row * 3 + 0] * svd_u[col * 3 + 0] + //
                                   svd_v[row * 3 + 1] * svd_u[col * 3 + 1] + //
                                   svd_v[row * 3 + 2] * svd_u[col * 3 + 2];
    }

    if (rotation) {
        for (int j = 0; j < 9; ++j) rotation[j] = r[j];
    }

    // RMSD of B against the similarity-transformed A.
    nk_f32_t sum_squared = nk_transformed_ssd_f16_neonhalf_(a, b, n, r, scale_factor, mean_a[0], mean_a[1],
                                                            mean_a[2], mean_b[0], mean_b[1], mean_b[2]);
    *result = nk_f32_sqrt_neon(sum_squared * inv_n);
}
603
+
604
+ #if defined(__clang__)
605
+ #pragma clang attribute pop
606
+ #elif defined(__GNUC__)
607
+ #pragma GCC pop_options
608
+ #endif
609
+
610
+ #if defined(__cplusplus)
611
+ } // extern "C"
612
+ #endif
613
+
614
+ #endif // NK_TARGET_NEONHALF
615
+ #endif // NK_TARGET_ARM_
616
+ #endif // NK_MESH_NEONHALF_H