numkong 7.4.5 → 7.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. package/README.md +1 -0
  2. package/binding.gyp +99 -5
  3. package/c/dispatch_e5m2.c +23 -3
  4. package/c/dispatch_f16.c +23 -0
  5. package/c/numkong.c +0 -13
  6. package/include/numkong/attention/sme.h +34 -31
  7. package/include/numkong/capabilities.h +2 -15
  8. package/include/numkong/cast/README.md +3 -0
  9. package/include/numkong/cast/haswell.h +28 -64
  10. package/include/numkong/cast/neon.h +15 -0
  11. package/include/numkong/cast/serial.h +17 -0
  12. package/include/numkong/cast/skylake.h +67 -52
  13. package/include/numkong/cast.h +1 -0
  14. package/include/numkong/curved/smef64.h +82 -62
  15. package/include/numkong/dot/README.md +1 -0
  16. package/include/numkong/dot/haswell.h +92 -13
  17. package/include/numkong/dot/rvvbf16.h +1 -1
  18. package/include/numkong/dot/rvvhalf.h +1 -1
  19. package/include/numkong/dot/serial.h +15 -0
  20. package/include/numkong/dot/skylake.h +61 -14
  21. package/include/numkong/dot/sve.h +6 -5
  22. package/include/numkong/dot/svebfdot.h +2 -1
  23. package/include/numkong/dot/svehalf.h +6 -5
  24. package/include/numkong/dot/svesdot.h +3 -2
  25. package/include/numkong/dots/README.md +2 -0
  26. package/include/numkong/dots/graniteamx.h +1167 -0
  27. package/include/numkong/dots/haswell.h +28 -28
  28. package/include/numkong/dots/sapphireamx.h +1 -1
  29. package/include/numkong/dots/serial.h +33 -11
  30. package/include/numkong/dots/skylake.h +28 -23
  31. package/include/numkong/dots/sme.h +172 -140
  32. package/include/numkong/dots/smebi32.h +14 -11
  33. package/include/numkong/dots/smef64.h +31 -26
  34. package/include/numkong/dots.h +41 -3
  35. package/include/numkong/each/serial.h +39 -0
  36. package/include/numkong/geospatial/haswell.h +1 -1
  37. package/include/numkong/geospatial/neon.h +1 -1
  38. package/include/numkong/geospatial/serial.h +15 -4
  39. package/include/numkong/geospatial/skylake.h +1 -1
  40. package/include/numkong/maxsim/serial.h +15 -0
  41. package/include/numkong/maxsim/sme.h +34 -33
  42. package/include/numkong/mesh/README.md +50 -44
  43. package/include/numkong/mesh/genoa.h +462 -0
  44. package/include/numkong/mesh/haswell.h +806 -933
  45. package/include/numkong/mesh/neon.h +871 -943
  46. package/include/numkong/mesh/neonbfdot.h +382 -522
  47. package/include/numkong/mesh/neonfhm.h +676 -0
  48. package/include/numkong/mesh/rvv.h +404 -319
  49. package/include/numkong/mesh/serial.h +225 -161
  50. package/include/numkong/mesh/skylake.h +1029 -1585
  51. package/include/numkong/mesh/v128relaxed.h +403 -377
  52. package/include/numkong/mesh.h +38 -0
  53. package/include/numkong/reduce/neon.h +29 -0
  54. package/include/numkong/reduce/neonbfdot.h +2 -2
  55. package/include/numkong/reduce/neonfhm.h +4 -4
  56. package/include/numkong/reduce/serial.h +15 -1
  57. package/include/numkong/reduce/sve.h +52 -0
  58. package/include/numkong/reduce.h +4 -0
  59. package/include/numkong/set/sve.h +6 -5
  60. package/include/numkong/sets/smebi32.h +35 -30
  61. package/include/numkong/sparse/serial.h +17 -2
  62. package/include/numkong/sparse/sve2.h +3 -2
  63. package/include/numkong/spatial/genoa.h +0 -68
  64. package/include/numkong/spatial/haswell.h +98 -56
  65. package/include/numkong/spatial/serial.h +15 -0
  66. package/include/numkong/spatial/skylake.h +114 -54
  67. package/include/numkong/spatial/sve.h +7 -6
  68. package/include/numkong/spatial/svebfdot.h +7 -4
  69. package/include/numkong/spatial/svehalf.h +5 -4
  70. package/include/numkong/spatial/svesdot.h +9 -8
  71. package/include/numkong/spatial.h +0 -12
  72. package/include/numkong/spatials/graniteamx.h +301 -0
  73. package/include/numkong/spatials/serial.h +39 -0
  74. package/include/numkong/spatials/skylake.h +2 -2
  75. package/include/numkong/spatials/sme.h +391 -350
  76. package/include/numkong/spatials/smef64.h +79 -70
  77. package/include/numkong/spatials.h +54 -4
  78. package/include/numkong/tensor.hpp +107 -23
  79. package/include/numkong/types.h +59 -0
  80. package/javascript/dist/cjs/numkong.js +13 -0
  81. package/javascript/dist/esm/numkong.js +13 -0
  82. package/javascript/numkong.c +59 -14
  83. package/javascript/numkong.ts +13 -0
  84. package/package.json +7 -7
  85. package/probes/probe.js +2 -2
  86. package/wasm/numkong.wasm +0 -0
@@ -105,67 +105,73 @@ Each kernel runs for at least 20 seconds per configuration.
105
105
  Benchmark threads are pinned to specific cores; on machines with heterogeneous core types (e.g., Apple P/E cores), only the fastest cores are used.
106
106
  Workloads that significantly degrade CPU frequencies (Intel AMX, Apple SME) run in separate passes to avoid affecting throughput measurements of other kernels.
107
107
 
108
- ### Intel Sapphire Rapids
108
+ ### Intel Granite Rapids
109
+
110
+ Xeon 6776P, 2.3 GHz base, `cpu_scaling_enabled=false`.
111
+ Serial kernels compiled with `-fno-tree-vectorize`.
109
112
 
110
113
  #### Native
111
114
 
112
115
  | Kernel | 256 | 1024 | 4096 |
113
116
  | :------------------------ | -----------------------: | -----------------------: | -----------------------: |
114
117
  | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
115
- | `nk_rmsd_f64_serial` | 354 mp/s, 1.4 ulp | 176 mp/s, 2.7 ulp | 159 mp/s, 5.0 ulp |
116
- | `nk_kabsch_f64_serial` | 71.1 mp/s, 1.4 ulp | 70.8 mp/s, 2.7 ulp | 80.3 mp/s, 5.2 ulp |
117
- | `nk_umeyama_f64_serial` | 70.1 mp/s, 1.0 ulp | 75.1 mp/s, 1.8 ulp | 79.1 mp/s, 3.9 ulp |
118
- | `nk_rmsd_f64_haswell` | 405 mp/s, 0.3 ulp | 260 mp/s, 0.4 ulp | 192 mp/s, 0.8 ulp |
119
- | `nk_kabsch_f64_haswell` | 82.1 mp/s, 0.9 ulp | 105 mp/s, 1.3 ulp | 133 mp/s, 2.3 ulp |
120
- | `nk_umeyama_f64_haswell` | 82.6 mp/s, 0.4 ulp | 119 mp/s, 0.8 ulp | 134 mp/s, 1.5 ulp |
121
- | `nk_rmsd_f64_skylake` | 540 mp/s, 0.3 ulp | 219 mp/s, 0.3 ulp | 213 mp/s, 0.5 ulp |
122
- | `nk_kabsch_f64_skylake` | 96.8 mp/s, 0.7 ulp | 115 mp/s, 0.9 ulp | 159 mp/s, 1.1 ulp |
123
- | `nk_umeyama_f64_skylake` | 101 mp/s, 0.2 ulp | 119 mp/s, 0.4 ulp | 157 mp/s, 0.8 ulp |
118
+ | `nk_rmsd_f64_serial` | 93.7 mp/s, 0.5 ulp | 87.4 mp/s, 0.5 ulp | 69.8 mp/s, 0.5 ulp |
119
+ | `nk_kabsch_f64_serial` | 11.8 mp/s, 0.8 ulp | 13.6 mp/s, 0.8 ulp | 12.8 mp/s, 0.8 ulp |
120
+ | `nk_umeyama_f64_serial` | 10.4 mp/s, 0.3 ulp | 11.7 mp/s, 0.3 ulp | 11.5 mp/s, 0.3 ulp |
121
+ | `nk_rmsd_f64_haswell` | 523 mp/s, 0.3 ulp | 564 mp/s, 0.4 ulp | 449 mp/s, 0.8 ulp |
122
+ | `nk_kabsch_f64_haswell` | 65.3 mp/s, 0.5 ulp | 203 mp/s, 0.9 ulp | 326 mp/s, 1.5 ulp |
123
+ | `nk_umeyama_f64_haswell` | 68.0 mp/s, 0.5 ulp | 200 mp/s, 0.8 ulp | 324 mp/s, 1.5 ulp |
124
+ | `nk_rmsd_f64_skylake` | 546 mp/s, 0.2 ulp | 587 mp/s, 0.3 ulp | 583 mp/s, 0.4 ulp |
125
+ | `nk_kabsch_f64_skylake` | 34.5 mp/s, 0.4 ulp | 107 mp/s, 0.5 ulp | 261 mp/s, 0.8 ulp |
126
+ | `nk_umeyama_f64_skylake` | 24.3 mp/s, 0.3 ulp | 82.7 mp/s, 0.5 ulp | 201 mp/s, 0.8 ulp |
124
127
  | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
125
- | `nk_rmsd_f32_serial` | 480 mp/s, 1.4 ulp | 314 mp/s, 2.7 ulp | 270 mp/s, 5.4 ulp |
126
- | `nk_kabsch_f32_serial` | 83.2 mp/s, 1.5 ulp | 91.6 mp/s, 2.6 ulp | 110 mp/s, 5.3 ulp |
127
- | `nk_umeyama_f32_serial` | 80.4 mp/s, 1.0 ulp | 104 mp/s, 1.9 ulp | 106 mp/s, 3.7 ulp |
128
- | `nk_rmsd_f32_haswell` | 447 mp/s, 0.3 ulp | 484 mp/s, 0.3 ulp | 350 mp/s, 0.4 ulp |
129
- | `nk_kabsch_f32_haswell` | 101 mp/s, 0.7 ulp | 192 mp/s, 0.9 ulp | 213 mp/s, 1.3 ulp |
130
- | `nk_umeyama_f32_haswell` | 97.4 mp/s, 0.3 ulp | 155 mp/s, 0.4 ulp | 207 mp/s, 0.8 ulp |
131
- | `nk_rmsd_f32_skylake` | 1,000 mp/s, 0.7 ulp | 974 mp/s, 1.2 ulp | 786 mp/s, 2.4 ulp |
132
- | `nk_kabsch_f32_skylake` | 97.5 mp/s, 0.7 ulp | 232 mp/s, 0.7 ulp | 332 mp/s, 0.9 ulp |
133
- | `nk_umeyama_f32_skylake` | 92.5 mp/s, 0.2 ulp | 227 mp/s, 0.2 ulp | 325 mp/s, 0.3 ulp |
128
+ | `nk_rmsd_f32_serial` | 68.9 mp/s, 0.5 ulp | 70.7 mp/s, 0.5 ulp | 72.1 mp/s, 0.5 ulp |
129
+ | `nk_kabsch_f32_serial` | 11.2 mp/s, 0.8 ulp | 12.8 mp/s, 0.8 ulp | 14.0 mp/s, 0.9 ulp |
130
+ | `nk_umeyama_f32_serial` | 10.1 mp/s, 0.3 ulp | 11.2 mp/s, 0.3 ulp | 12.1 mp/s, 0.4 ulp |
131
+ | `nk_rmsd_f32_haswell` | 686 mp/s, 0.3 ulp | 848 mp/s, 0.5 ulp | 841 mp/s, 0.9 ulp |
132
+ | `nk_kabsch_f32_haswell` | 90.4 mp/s, 0.9 ulp | 250 mp/s, 1.3 ulp | 455 mp/s, 7.6 ulp |
133
+ | `nk_umeyama_f32_haswell` | 87.7 mp/s, 0.3 ulp | 250 mp/s, 0.4 ulp | 374 mp/s, 0.7 ulp |
134
+ | `nk_rmsd_f32_skylake` | 1,016 mp/s, 1.2 ulp | 1,112 mp/s, 1.2 ulp | 1,042 mp/s, 4.3 ulp |
135
+ | `nk_kabsch_f32_skylake` | 81.8 mp/s, 0.9 ulp | 241 mp/s, 4.1 ulp | 549 mp/s, 3.1 ulp |
136
+ | `nk_umeyama_f32_skylake` | 58.0 mp/s, 0.6 ulp | 168 mp/s, 2.9 ulp | 459 mp/s, 2.1 ulp |
134
137
  | __bf16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
135
- | `nk_rmsd_bf16_haswell` | 511 mp/s, 0.3 ulp | 481 mp/s, 3.5 ulp | 497 mp/s, 12.8 ulp |
136
- | `nk_kabsch_bf16_haswell` | 52.4 mp/s, 0.7 ulp | 65.3 mp/s, 0.9 ulp | 74.8 mp/s, 1.3 ulp |
137
- | `nk_umeyama_bf16_haswell` | 51.5 mp/s, 0.2 ulp | 69.2 mp/s, 0.4 ulp | 74.6 mp/s, 0.8 ulp |
138
- | `nk_rmsd_bf16_skylake` | 1,765 mp/s, 0.3 ulp | 1,945 mp/s, 0.5 ulp | 2,056 mp/s, 6.0 ulp |
139
- | `nk_kabsch_bf16_skylake` | 132 mp/s, 0.7 ulp | 370 mp/s, 0.8 ulp | 689 mp/s, 0.9 ulp |
140
- | `nk_umeyama_bf16_skylake` | 130 mp/s, 0.2 ulp | 366 mp/s, 0.3 ulp | 689 mp/s, 0.5 ulp |
138
+ | `nk_rmsd_bf16_haswell` | 284 mp/s, 0.3 ulp | 281 mp/s, 3.5 ulp | 273 mp/s, 12.8 ulp |
139
+ | `nk_kabsch_bf16_haswell` | 36.2 mp/s, 0.4 ulp | 106 mp/s, 7.6 ulp | 186 mp/s, 33.0 ulp |
140
+ | `nk_umeyama_bf16_haswell` | 34.5 mp/s, 0.3 ulp | 102 mp/s, 5.3 ulp | 186 mp/s, 23.1 ulp |
141
+ | `nk_rmsd_bf16_skylake` | 1,837 mp/s, 0.4 ulp | 2,357 mp/s, 5.4 ulp | 2,422 mp/s, 11.8 ulp |
142
+ | `nk_kabsch_bf16_skylake` | 34.1 mp/s, 0.3 ulp | 131 mp/s, 3.2 ulp | 487 mp/s, 20.4 ulp |
143
+ | `nk_umeyama_bf16_skylake` | 34.6 mp/s, 0.3 ulp | 130 mp/s, 2.2 ulp | 394 mp/s, 14.3 ulp |
144
+ | `nk_rmsd_bf16_genoa` | 1,743 mp/s, 0.3 ulp | 2,323 mp/s, 3.1 ulp | 2,066 mp/s, 20.2 ulp |
145
+ | `nk_kabsch_bf16_genoa` | 33.4 mp/s, 0.3 ulp | 133 mp/s, 3.2 ulp | 405 mp/s, 20.3 ulp |
146
+ | `nk_umeyama_bf16_genoa` | 33.2 mp/s, 0.3 ulp | 129 mp/s, 2.2 ulp | 439 mp/s, 14.3 ulp |
141
147
  | __f16__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
142
- | `nk_rmsd_f16_haswell` | 415 mp/s, 0.3 ulp | 497 mp/s, 0.7 ulp | 458 mp/s, 2.5 ulp |
143
- | `nk_kabsch_f16_haswell` | 151 mp/s, 0.7 ulp | 222 mp/s, 0.9 ulp | 221 mp/s, 1.4 ulp |
144
- | `nk_umeyama_f16_haswell` | 186 mp/s, 0.2 ulp | 232 mp/s, 0.5 ulp | 222 mp/s, 0.9 ulp |
145
- | `nk_rmsd_f16_skylake` | 1,813 mp/s, 0.3 ulp | 1,982 mp/s, 0.4 ulp | 2,049 mp/s, 1.8 ulp |
146
- | `nk_kabsch_f16_skylake` | 367 mp/s, 0.7 ulp | 695 mp/s, 0.7 ulp | 903 mp/s, 0.9 ulp |
147
- | `nk_umeyama_f16_skylake` | 341 mp/s, 0.2 ulp | 686 mp/s, 0.2 ulp | 882 mp/s, 0.4 ulp |
148
+ | `nk_rmsd_f16_haswell` | 273 mp/s, 0.2 ulp | 274 mp/s, 0.7 ulp | 291 mp/s, 2.5 ulp |
149
+ | `nk_kabsch_f16_haswell` | 34.4 mp/s, 0.5 ulp | 98.0 mp/s, 1.8 ulp | 197 mp/s, 8.2 ulp |
150
+ | `nk_umeyama_f16_haswell` | 35.5 mp/s, 0.4 ulp | 97.9 mp/s, 1.2 ulp | 196 mp/s, 5.7 ulp |
151
+ | `nk_rmsd_f16_skylake` | 1,834 mp/s, 0.3 ulp | 2,341 mp/s, 1.3 ulp | 2,418 mp/s, 3.9 ulp |
152
+ | `nk_kabsch_f16_skylake` | 34.0 mp/s, 0.7 ulp | 132 mp/s, 0.5 ulp | 480 mp/s, 4.7 ulp |
153
+ | `nk_umeyama_f16_skylake` | 33.8 mp/s, 0.5 ulp | 127 mp/s, 0.4 ulp | 481 mp/s, 3.3 ulp |
148
154
 
149
155
  #### WASM
150
156
 
151
- Measured with Wasmtime v42 (Cranelift backend).
157
+ Measured with Wasmtime v43 (Cranelift backend), WASI-SDK 24, `-msimd128 -mrelaxed-simd`.
152
158
 
153
159
  | Kernel | 256 | 1024 | 4096 |
154
160
  | :--------------------------- | -----------------------: | -----------------------: | -----------------------: |
155
161
  | __f64__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
156
- | `nk_rmsd_f64_serial` | 178 mp/s, 1.4 ulp | 158 mp/s, 2.6 ulp | ? mp/s, 5.3 ulp |
157
- | `nk_rmsd_f64_v128relaxed` | 273 mp/s, 0.4 ulp | 307 mp/s, 0.7 ulp | ? mp/s, 1.3 ulp |
158
- | `nk_kabsch_f64_serial` | 37.7 mp/s, 1.4 ulp | 51.7 mp/s, 2.5 ulp | ? mp/s, 5.2 ulp |
159
- | `nk_kabsch_f64_v128relaxed` | 31.7 mp/s, 1.2 ulp | 56.9 mp/s, 2.3 ulp | ? mp/s, 4.5 ulp |
160
- | `nk_umeyama_f64_serial` | 36.5 mp/s, 0.9 ulp | 49.6 mp/s, 1.9 ulp | ? mp/s, 3.6 ulp |
161
- | `nk_umeyama_f64_v128relaxed` | 32.6 mp/s, 0.8 ulp | 55.5 mp/s, 1.5 ulp | ? mp/s, 3.2 ulp |
162
+ | `nk_rmsd_f64_serial` | 89.9 mp/s, 0.5 ulp | 86.1 mp/s, 0.5 ulp | 73.4 mp/s, 0.5 ulp |
163
+ | `nk_rmsd_f64_v128relaxed` | 485 mp/s, 0.4 ulp | 552 mp/s, 0.7 ulp | 412 mp/s, 1.3 ulp |
164
+ | `nk_kabsch_f64_serial` | 12.1 mp/s, 0.8 ulp | 13.9 mp/s, 0.8 ulp | 14.0 mp/s, 0.9 ulp |
165
+ | `nk_kabsch_f64_v128relaxed` | 66.0 mp/s, 0.9 ulp | 188 mp/s, 1.7 ulp | 177 mp/s, 3.1 ulp |
166
+ | `nk_umeyama_f64_serial` | 10.8 mp/s, 0.3 ulp | 12.3 mp/s, 0.3 ulp | 12.2 mp/s, 0.4 ulp |
167
+ | `nk_umeyama_f64_v128relaxed` | 64.0 mp/s, 0.8 ulp | 187 mp/s, 1.6 ulp | 178 mp/s, 3.2 ulp |
162
168
  | __f32__ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ | ░░░░░░░░░░░░░░░░░░░░░░░░ |
163
- | `nk_rmsd_f32_serial` | 105 mp/s, 1.4 ulp | 122 mp/s, 2.7 ulp | ? mp/s, 5.2 ulp |
164
- | `nk_rmsd_f32_v128relaxed` | 213 mp/s, 0.3 ulp | 258 mp/s, 0.4 ulp | ? mp/s, 0.8 ulp |
165
- | `nk_kabsch_f32_serial` | 15.5 mp/s, 1.4 ulp | 32.8 mp/s, 2.6 ulp | ? mp/s, 5.1 ulp |
166
- | `nk_kabsch_f32_v128relaxed` | 13.5 mp/s, 0.9 ulp | 46.2 mp/s, 1.3 ulp | ? mp/s, 2.5 ulp |
167
- | `nk_umeyama_f32_serial` | 15.2 mp/s, 1.0 ulp | 37.4 mp/s, 1.8 ulp | ? mp/s, 3.7 ulp |
168
- | `nk_umeyama_f32_v128relaxed` | 18.3 mp/s, 0.4 ulp | 38.9 mp/s, 0.8 ulp | ? mp/s, 1.5 ulp |
169
+ | `nk_rmsd_f32_serial` | 80.6 mp/s, 0.5 ulp | 82.7 mp/s, 0.5 ulp | 70.3 mp/s, 0.5 ulp |
170
+ | `nk_rmsd_f32_v128relaxed` | 452 mp/s, 1.5 ulp | 416 mp/s, 1.3 ulp | 399 mp/s, 4.8 ulp |
171
+ | `nk_kabsch_f32_serial` | 11.4 mp/s, 0.8 ulp | 12.8 mp/s, 0.9 ulp | 12.7 mp/s, 0.8 ulp |
172
+ | `nk_kabsch_f32_v128relaxed` | 79.5 mp/s, 4.2 ulp | 132 mp/s, 3.9 ulp | 177 mp/s, 14.3 ulp |
173
+ | `nk_umeyama_f32_serial` | 10.1 mp/s, 0.3 ulp | 11.2 mp/s, 0.3 ulp | 11.2 mp/s, 0.3 ulp |
174
+ | `nk_umeyama_f32_v128relaxed` | 79.4 mp/s, 2.8 ulp | 138 mp/s, 2.8 ulp | 194 mp/s, 10.1 ulp |
169
175
 
170
176
 
171
177
  ### Apple M5
@@ -0,0 +1,462 @@
1
+ /**
2
+ * @brief SIMD-accelerated Point Cloud Alignment for Genoa (AVX-512-BF16).
3
+ * @file include/numkong/mesh/genoa.h
4
+ * @author Ash Vardanian
5
+ * @date December 28, 2025
6
+ *
7
+ * @sa include/numkong/mesh.h
8
+ *
9
+ * @section genoa_mesh_instructions Key AVX-512 BF16 Mesh Instructions
10
+ *
11
+ * Intrinsic Instruction Genoa Sapphire
12
+ * _mm512_dpbf16_ps VDPBF16PS (ZMM, ZMM, ZMM) 6cy @ p01 6cy @ p05
13
+ * _mm512_permutexvar_epi16 VPERMW (ZMM, ZMM, ZMM) 3cy @ p5 6cy @ p5
14
+ * _mm512_maskz_loadu_epi16 VMOVDQU16 (ZMM{k}, M) 9cy @ L1 9cy @ L1
15
+ *
16
+ * The bf16 mesh kernels use a 15-lane channel-grouped layout: 10 xyz triplets per ZMM (30 bf16
17
+ * values laid out as [x0..x9, y0..y9, z0..z9, _, _] after a single VPERMW). That maps cleanly
18
+ * onto VDPBF16PS, which pairs adjacent bf16 values per fp32 lane; 5 channel-consecutive pairs
19
+ * give a single H-cell per lane-range. Three product accumulators (a*b, a*rot1(b), a*rot2(b))
20
+ * cover the 9 cross-covariance cells, matching the Skylake structure.
21
+ */
22
+ #ifndef NK_MESH_GENOA_H
23
+ #define NK_MESH_GENOA_H
24
+
25
+ #if NK_TARGET_X8664_
26
+ #if NK_TARGET_GENOA
27
+
28
+ #include "numkong/types.h"
29
+ #include "numkong/mesh/serial.h"
30
+ #include "numkong/spatial/haswell.h" // `nk_f32_sqrt_haswell`
31
+
32
+ #if defined(__cplusplus)
33
+ extern "C" {
34
+ #endif
35
+
36
+ #if defined(__clang__)
37
+ #pragma clang attribute push( \
38
+ __attribute__((target("avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512bf16,f16c,fma,bmi,bmi2"))), \
39
+ apply_to = function)
40
+ #elif defined(__GNUC__)
41
+ #pragma GCC push_options
42
+ #pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512bf16", "f16c", "fma", "bmi", "bmi2")
43
+ #endif
44
+
45
+ NK_PUBLIC void nk_rmsd_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
46
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
47
+ if (rotation)
48
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
49
+ rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
50
+ if (scale) *scale = 1.0f;
51
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
52
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
53
+
54
+ // 32-lane bf16 chunks = 10 triplets + 2 padding bf16 per register.
55
+ // VDPBF16PS pairs adjacent bf16 per fp32 lane: lane[i] += a[2i]*b[2i] + a[2i+1]*b[2i+1].
56
+ // For RMSD we need Σ(a-b)², computed via Σ a² + Σ b² - 2 Σ a·b.
57
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
58
+ __m512 norm_squared_a_f32x16 = zeros_f32x16;
59
+ __m512 norm_squared_b_f32x16 = zeros_f32x16;
60
+ __m512 cross_product_f32x16 = zeros_f32x16;
61
+ nk_size_t index = 0;
62
+
63
+ __mmask32 const full_mask_bf16 = (__mmask32)0x3FFFFFFF; // 30 bf16 valid, 2 bf16 padding zeros
64
+
65
+ for (; index + 10 <= n; index += 10) {
66
+ __m512i a_bf16x32 = _mm512_maskz_loadu_epi16(full_mask_bf16, (__m512i const *)(a + index * 3));
67
+ __m512i b_bf16x32 = _mm512_maskz_loadu_epi16(full_mask_bf16, (__m512i const *)(b + index * 3));
68
+ norm_squared_a_f32x16 = _mm512_dpbf16_ps(norm_squared_a_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
69
+ nk_m512bh_from_m512i_(a_bf16x32));
70
+ norm_squared_b_f32x16 = _mm512_dpbf16_ps(norm_squared_b_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
71
+ nk_m512bh_from_m512i_(b_bf16x32));
72
+ cross_product_f32x16 = _mm512_dpbf16_ps(cross_product_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
73
+ nk_m512bh_from_m512i_(b_bf16x32));
74
+ }
75
+
76
+ if (index < n) {
77
+ __mmask32 tail_mask = (__mmask32)_bzhi_u32(0x3FFFFFFF, (nk_u32_t)((n - index) * 3));
78
+ __m512i a_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (__m512i const *)(a + index * 3));
79
+ __m512i b_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (__m512i const *)(b + index * 3));
80
+ norm_squared_a_f32x16 = _mm512_dpbf16_ps(norm_squared_a_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
81
+ nk_m512bh_from_m512i_(a_bf16x32));
82
+ norm_squared_b_f32x16 = _mm512_dpbf16_ps(norm_squared_b_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
83
+ nk_m512bh_from_m512i_(b_bf16x32));
84
+ cross_product_f32x16 = _mm512_dpbf16_ps(cross_product_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
85
+ nk_m512bh_from_m512i_(b_bf16x32));
86
+ }
87
+
88
+ nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
89
+ nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
90
+ nk_f32_t cross_product = _mm512_reduce_add_ps(cross_product_f32x16);
91
+ nk_f32_t sum_squared = norm_squared_a + norm_squared_b - 2.0f * cross_product;
92
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
93
+ *result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
94
+ }
95
+
96
// VPERMW index vectors for one chunk of 10 interleaved xyz triplets (30 bf16 values + 2 zeroed padding lanes).
// `_mm512_set_epi16` lists lanes from 31 down to 0, so each row below fills one channel slot, highest lane first:
//   bf16 lanes 30..31 = padding, 20..29 = z-slot, 10..19 = y-slot, 0..9 = x-slot.
// Viewed as fp32 lanes (two bf16 each), the x-slot is lanes 0..4, the y-slot 5..9, the z-slot 10..14.
// Source element 3*i + c is channel c (0 = x, 1 = y, 2 = z) of triplet i.

// Plain grouping: x-slot <- x, y-slot <- y, z-slot <- z.
#define NK_MESH_GENOA_CHANNEL_GROUP_INDICES_                                              \
    _mm512_set_epi16(31, 30,                              /* lanes 31..30: padding    */ \
                     29, 26, 23, 20, 17, 14, 11, 8, 5, 2, /* lanes 29..20: z9..z0     */ \
                     28, 25, 22, 19, 16, 13, 10, 7, 4, 1, /* lanes 19..10: y9..y0     */ \
                     27, 24, 21, 18, 15, 12, 9, 6, 3, 0)  /* lanes  9..0:  x9..x0     */

// Next-channel rotation of b: x-slot <- y, y-slot <- z, z-slot <- x.
// Multiplied against grouped a, it yields the covariance cells (xy, yz, zx).
#define NK_MESH_GENOA_ROTATION_1_INDICES_                                                 \
    _mm512_set_epi16(31, 30,                              /* lanes 31..30: padding    */ \
                     27, 24, 21, 18, 15, 12, 9, 6, 3, 0,  /* lanes 29..20: x9..x0     */ \
                     29, 26, 23, 20, 17, 14, 11, 8, 5, 2, /* lanes 19..10: z9..z0     */ \
                     28, 25, 22, 19, 16, 13, 10, 7, 4, 1) /* lanes  9..0:  y9..y0     */

// Previous-channel rotation of b: x-slot <- z, y-slot <- x, z-slot <- y.
// Multiplied against grouped a, it yields the covariance cells (xz, yx, zy).
#define NK_MESH_GENOA_ROTATION_2_INDICES_                                                 \
    _mm512_set_epi16(31, 30,                              /* lanes 31..30: padding    */ \
                     28, 25, 22, 19, 16, 13, 10, 7, 4, 1, /* lanes 29..20: y9..y0     */ \
                     27, 24, 21, 18, 15, 12, 9, 6, 3, 0,  /* lanes 19..10: x9..x0     */ \
                     29, 26, 23, 20, 17, 14, 11, 8, 5, 2) /* lanes  9..0:  z9..z0     */
112
+
113
+ NK_PUBLIC void nk_kabsch_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
114
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
115
+ __m512i const idx_channel_group_i16x32 = NK_MESH_GENOA_CHANNEL_GROUP_INDICES_;
116
+ __m512i const idx_rotation_1_i16x32 = NK_MESH_GENOA_ROTATION_1_INDICES_;
117
+ __m512i const idx_rotation_2_i16x32 = NK_MESH_GENOA_ROTATION_2_INDICES_;
118
+ __m512i const ones_bf16x32 = _mm512_set1_epi16(0x3F80); // bf16 representation of 1.0
119
+
120
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
121
+ __m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
122
+ __m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
123
+ __m512 product_diagonal_f32x16 = zeros_f32x16;
124
+ __m512 product_rotation_1_f32x16 = zeros_f32x16;
125
+ __m512 product_rotation_2_f32x16 = zeros_f32x16;
126
+
127
+ __mmask32 const full_mask_bf16 = (__mmask32)0x3FFFFFFF;
128
+
129
+ nk_size_t index = 0;
130
+ for (; index + 10 <= n; index += 10) {
131
+ __m512i a_raw_bf16x32 = _mm512_maskz_loadu_epi16(full_mask_bf16, (__m512i const *)(a + index * 3));
132
+ __m512i b_raw_bf16x32 = _mm512_maskz_loadu_epi16(full_mask_bf16, (__m512i const *)(b + index * 3));
133
+ __m512i a_grouped_bf16x32 = _mm512_permutexvar_epi16(idx_channel_group_i16x32, a_raw_bf16x32);
134
+ __m512i b_grouped_bf16x32 = _mm512_permutexvar_epi16(idx_channel_group_i16x32, b_raw_bf16x32);
135
+ __m512i b_rotation_1_bf16x32 = _mm512_permutexvar_epi16(idx_rotation_1_i16x32, b_raw_bf16x32);
136
+ __m512i b_rotation_2_bf16x32 = _mm512_permutexvar_epi16(idx_rotation_2_i16x32, b_raw_bf16x32);
137
+
138
+ sum_a_f32x16 = _mm512_dpbf16_ps(sum_a_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
139
+ nk_m512bh_from_m512i_(ones_bf16x32));
140
+ sum_b_f32x16 = _mm512_dpbf16_ps(sum_b_f32x16, nk_m512bh_from_m512i_(b_grouped_bf16x32),
141
+ nk_m512bh_from_m512i_(ones_bf16x32));
142
+ norm_squared_a_f32x16 = _mm512_dpbf16_ps(norm_squared_a_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
143
+ nk_m512bh_from_m512i_(a_grouped_bf16x32));
144
+ norm_squared_b_f32x16 = _mm512_dpbf16_ps(norm_squared_b_f32x16, nk_m512bh_from_m512i_(b_grouped_bf16x32),
145
+ nk_m512bh_from_m512i_(b_grouped_bf16x32));
146
+ product_diagonal_f32x16 = _mm512_dpbf16_ps(product_diagonal_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
147
+ nk_m512bh_from_m512i_(b_grouped_bf16x32));
148
+ product_rotation_1_f32x16 = _mm512_dpbf16_ps(product_rotation_1_f32x16,
149
+ nk_m512bh_from_m512i_(a_grouped_bf16x32),
150
+ nk_m512bh_from_m512i_(b_rotation_1_bf16x32));
151
+ product_rotation_2_f32x16 = _mm512_dpbf16_ps(product_rotation_2_f32x16,
152
+ nk_m512bh_from_m512i_(a_grouped_bf16x32),
153
+ nk_m512bh_from_m512i_(b_rotation_2_bf16x32));
154
+ }
155
+
156
+ if (index < n) {
157
+ __mmask32 tail_mask = (__mmask32)_bzhi_u32(0x3FFFFFFF, (nk_u32_t)((n - index) * 3));
158
+ __m512i a_raw_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (__m512i const *)(a + index * 3));
159
+ __m512i b_raw_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (__m512i const *)(b + index * 3));
160
+ __m512i a_grouped_bf16x32 = _mm512_permutexvar_epi16(idx_channel_group_i16x32, a_raw_bf16x32);
161
+ __m512i b_grouped_bf16x32 = _mm512_permutexvar_epi16(idx_channel_group_i16x32, b_raw_bf16x32);
162
+ __m512i b_rotation_1_bf16x32 = _mm512_permutexvar_epi16(idx_rotation_1_i16x32, b_raw_bf16x32);
163
+ __m512i b_rotation_2_bf16x32 = _mm512_permutexvar_epi16(idx_rotation_2_i16x32, b_raw_bf16x32);
164
+
165
+ sum_a_f32x16 = _mm512_dpbf16_ps(sum_a_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
166
+ nk_m512bh_from_m512i_(ones_bf16x32));
167
+ sum_b_f32x16 = _mm512_dpbf16_ps(sum_b_f32x16, nk_m512bh_from_m512i_(b_grouped_bf16x32),
168
+ nk_m512bh_from_m512i_(ones_bf16x32));
169
+ norm_squared_a_f32x16 = _mm512_dpbf16_ps(norm_squared_a_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
170
+ nk_m512bh_from_m512i_(a_grouped_bf16x32));
171
+ norm_squared_b_f32x16 = _mm512_dpbf16_ps(norm_squared_b_f32x16, nk_m512bh_from_m512i_(b_grouped_bf16x32),
172
+ nk_m512bh_from_m512i_(b_grouped_bf16x32));
173
+ product_diagonal_f32x16 = _mm512_dpbf16_ps(product_diagonal_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
174
+ nk_m512bh_from_m512i_(b_grouped_bf16x32));
175
+ product_rotation_1_f32x16 = _mm512_dpbf16_ps(product_rotation_1_f32x16,
176
+ nk_m512bh_from_m512i_(a_grouped_bf16x32),
177
+ nk_m512bh_from_m512i_(b_rotation_1_bf16x32));
178
+ product_rotation_2_f32x16 = _mm512_dpbf16_ps(product_rotation_2_f32x16,
179
+ nk_m512bh_from_m512i_(a_grouped_bf16x32),
180
+ nk_m512bh_from_m512i_(b_rotation_2_bf16x32));
181
+ }
182
+
183
+ // Channel demux by lane range (x=0..4, y=5..9, z=10..14, lane 15 padding).
184
+ __mmask16 const mask_channel_x_f32 = 0x001F;
185
+ __mmask16 const mask_channel_y_f32 = 0x03E0;
186
+ __mmask16 const mask_channel_z_f32 = 0x7C00;
187
+
188
+ nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
189
+ nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
190
+ nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
191
+ nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
192
+ nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
193
+ nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
194
+ nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
195
+ nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
196
+
197
+ nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
198
+ nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
199
+ nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);
200
+ nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
201
+ nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
202
+ nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
203
+ nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
204
+ nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
205
+ nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
206
+
207
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
208
+ nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
209
+ nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
210
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
211
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
212
+
213
+ // Parallel-axis correction.
214
+ nk_f32_t cross_covariance[9];
215
+ cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
216
+ cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
217
+ cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
218
+ cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
219
+ cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
220
+ cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
221
+ cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
222
+ cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
223
+ cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
224
+
225
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
226
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
227
+ centroid_a_z * centroid_a_z);
228
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
229
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
230
+ centroid_b_z * centroid_b_z);
231
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
232
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
233
+
234
+ // Identity-dominant short-circuit.
235
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
236
+ cross_covariance[4] * cross_covariance[4] +
237
+ cross_covariance[8] * cross_covariance[8];
238
+ nk_f32_t covariance_offdiagonal_norm_squared =
239
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
240
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
241
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
242
+ nk_f32_t optimal_rotation[9];
243
+ nk_f32_t trace_rotation_covariance;
244
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
245
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
246
+ optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
247
+ optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
248
+ optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
249
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
250
+ }
251
+ else {
252
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
253
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
254
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
255
+ if (nk_det3x3_f32_(optimal_rotation) < 0) {
256
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
257
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
258
+ }
259
+ trace_rotation_covariance =
260
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
261
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
262
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
263
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
264
+ optimal_rotation[8] * cross_covariance[8];
265
+ }
266
+
267
+ if (rotation)
268
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
269
+ if (scale) *scale = 1.0f;
270
+
271
+ // Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
272
+ nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
273
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
274
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
275
+ }
276
+
277
+ NK_PUBLIC void nk_umeyama_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
278
+ nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
279
+ __m512i const idx_channel_group_i16x32 = NK_MESH_GENOA_CHANNEL_GROUP_INDICES_;
280
+ __m512i const idx_rotation_1_i16x32 = NK_MESH_GENOA_ROTATION_1_INDICES_;
281
+ __m512i const idx_rotation_2_i16x32 = NK_MESH_GENOA_ROTATION_2_INDICES_;
282
+ __m512i const ones_bf16x32 = _mm512_set1_epi16(0x3F80);
283
+
284
+ __m512 const zeros_f32x16 = _mm512_setzero_ps();
285
+ __m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
286
+ __m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
287
+ __m512 product_diagonal_f32x16 = zeros_f32x16;
288
+ __m512 product_rotation_1_f32x16 = zeros_f32x16;
289
+ __m512 product_rotation_2_f32x16 = zeros_f32x16;
290
+
291
+ __mmask32 const full_mask_bf16 = (__mmask32)0x3FFFFFFF;
292
+
293
+ nk_size_t index = 0;
294
+ for (; index + 10 <= n; index += 10) {
295
+ __m512i a_raw_bf16x32 = _mm512_maskz_loadu_epi16(full_mask_bf16, (__m512i const *)(a + index * 3));
296
+ __m512i b_raw_bf16x32 = _mm512_maskz_loadu_epi16(full_mask_bf16, (__m512i const *)(b + index * 3));
297
+ __m512i a_grouped_bf16x32 = _mm512_permutexvar_epi16(idx_channel_group_i16x32, a_raw_bf16x32);
298
+ __m512i b_grouped_bf16x32 = _mm512_permutexvar_epi16(idx_channel_group_i16x32, b_raw_bf16x32);
299
+ __m512i b_rotation_1_bf16x32 = _mm512_permutexvar_epi16(idx_rotation_1_i16x32, b_raw_bf16x32);
300
+ __m512i b_rotation_2_bf16x32 = _mm512_permutexvar_epi16(idx_rotation_2_i16x32, b_raw_bf16x32);
301
+
302
+ sum_a_f32x16 = _mm512_dpbf16_ps(sum_a_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
303
+ nk_m512bh_from_m512i_(ones_bf16x32));
304
+ sum_b_f32x16 = _mm512_dpbf16_ps(sum_b_f32x16, nk_m512bh_from_m512i_(b_grouped_bf16x32),
305
+ nk_m512bh_from_m512i_(ones_bf16x32));
306
+ norm_squared_a_f32x16 = _mm512_dpbf16_ps(norm_squared_a_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
307
+ nk_m512bh_from_m512i_(a_grouped_bf16x32));
308
+ norm_squared_b_f32x16 = _mm512_dpbf16_ps(norm_squared_b_f32x16, nk_m512bh_from_m512i_(b_grouped_bf16x32),
309
+ nk_m512bh_from_m512i_(b_grouped_bf16x32));
310
+ product_diagonal_f32x16 = _mm512_dpbf16_ps(product_diagonal_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
311
+ nk_m512bh_from_m512i_(b_grouped_bf16x32));
312
+ product_rotation_1_f32x16 = _mm512_dpbf16_ps(product_rotation_1_f32x16,
313
+ nk_m512bh_from_m512i_(a_grouped_bf16x32),
314
+ nk_m512bh_from_m512i_(b_rotation_1_bf16x32));
315
+ product_rotation_2_f32x16 = _mm512_dpbf16_ps(product_rotation_2_f32x16,
316
+ nk_m512bh_from_m512i_(a_grouped_bf16x32),
317
+ nk_m512bh_from_m512i_(b_rotation_2_bf16x32));
318
+ }
319
+
320
+ if (index < n) {
321
+ __mmask32 tail_mask = (__mmask32)_bzhi_u32(0x3FFFFFFF, (nk_u32_t)((n - index) * 3));
322
+ __m512i a_raw_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (__m512i const *)(a + index * 3));
323
+ __m512i b_raw_bf16x32 = _mm512_maskz_loadu_epi16(tail_mask, (__m512i const *)(b + index * 3));
324
+ __m512i a_grouped_bf16x32 = _mm512_permutexvar_epi16(idx_channel_group_i16x32, a_raw_bf16x32);
325
+ __m512i b_grouped_bf16x32 = _mm512_permutexvar_epi16(idx_channel_group_i16x32, b_raw_bf16x32);
326
+ __m512i b_rotation_1_bf16x32 = _mm512_permutexvar_epi16(idx_rotation_1_i16x32, b_raw_bf16x32);
327
+ __m512i b_rotation_2_bf16x32 = _mm512_permutexvar_epi16(idx_rotation_2_i16x32, b_raw_bf16x32);
328
+
329
+ sum_a_f32x16 = _mm512_dpbf16_ps(sum_a_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
330
+ nk_m512bh_from_m512i_(ones_bf16x32));
331
+ sum_b_f32x16 = _mm512_dpbf16_ps(sum_b_f32x16, nk_m512bh_from_m512i_(b_grouped_bf16x32),
332
+ nk_m512bh_from_m512i_(ones_bf16x32));
333
+ norm_squared_a_f32x16 = _mm512_dpbf16_ps(norm_squared_a_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
334
+ nk_m512bh_from_m512i_(a_grouped_bf16x32));
335
+ norm_squared_b_f32x16 = _mm512_dpbf16_ps(norm_squared_b_f32x16, nk_m512bh_from_m512i_(b_grouped_bf16x32),
336
+ nk_m512bh_from_m512i_(b_grouped_bf16x32));
337
+ product_diagonal_f32x16 = _mm512_dpbf16_ps(product_diagonal_f32x16, nk_m512bh_from_m512i_(a_grouped_bf16x32),
338
+ nk_m512bh_from_m512i_(b_grouped_bf16x32));
339
+ product_rotation_1_f32x16 = _mm512_dpbf16_ps(product_rotation_1_f32x16,
340
+ nk_m512bh_from_m512i_(a_grouped_bf16x32),
341
+ nk_m512bh_from_m512i_(b_rotation_1_bf16x32));
342
+ product_rotation_2_f32x16 = _mm512_dpbf16_ps(product_rotation_2_f32x16,
343
+ nk_m512bh_from_m512i_(a_grouped_bf16x32),
344
+ nk_m512bh_from_m512i_(b_rotation_2_bf16x32));
345
+ }
346
+
347
+ __mmask16 const mask_channel_x_f32 = 0x001F;
348
+ __mmask16 const mask_channel_y_f32 = 0x03E0;
349
+ __mmask16 const mask_channel_z_f32 = 0x7C00;
350
+
351
+ nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
352
+ nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
353
+ nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
354
+ nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
355
+ nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
356
+ nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
357
+ nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
358
+ nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
359
+
360
+ nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
361
+ nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
362
+ nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);
363
+ nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
364
+ nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
365
+ nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
366
+ nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
367
+ nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
368
+ nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
369
+
370
+ nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
371
+ nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
372
+ nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
373
+ if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
374
+ if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
375
+
376
+ nk_f32_t cross_covariance[9];
377
+ cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
378
+ cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
379
+ cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
380
+ cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
381
+ cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
382
+ cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
383
+ cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
384
+ cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
385
+ cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
386
+
387
+ nk_f32_t centered_norm_squared_a = norm_squared_a -
388
+ (nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
389
+ centroid_a_z * centroid_a_z);
390
+ nk_f32_t centered_norm_squared_b = norm_squared_b -
391
+ (nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
392
+ centroid_b_z * centroid_b_z);
393
+ if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
394
+ if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
395
+
396
+ // Identity-dominant short-circuit.
397
+ nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
398
+ cross_covariance[4] * cross_covariance[4] +
399
+ cross_covariance[8] * cross_covariance[8];
400
+ nk_f32_t covariance_offdiagonal_norm_squared =
401
+ cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
402
+ cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
403
+ cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
404
+ nk_f32_t optimal_rotation[9];
405
+ nk_f32_t c;
406
+ nk_f32_t trace_rotation_covariance;
407
+ if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
408
+ cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
409
+ optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
410
+ optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
411
+ optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
412
+ trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
413
+ c = centered_norm_squared_a > 0.0f ? trace_rotation_covariance / centered_norm_squared_a : 0.0f;
414
+ }
415
+ else {
416
+ nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
417
+ nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
418
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
419
+
420
+ // Scale factor: c = trace(D · S) / ‖a-ā‖², with reflection sign via d3.
421
+ nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
422
+ nk_f32_t d3 = det < 0.0f ? -1.0f : 1.0f;
423
+ nk_f32_t trace_ds = nk_sum_three_products_f32_(svd_diagonal[0], 1.0f, svd_diagonal[4], 1.0f, svd_diagonal[8],
424
+ d3);
425
+ c = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
426
+
427
+ if (det < 0.0f) {
428
+ svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
429
+ nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
430
+ }
431
+ trace_rotation_covariance =
432
+ optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
433
+ optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
434
+ optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
435
+ optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
436
+ optimal_rotation[8] * cross_covariance[8];
437
+ }
438
+
439
+ if (scale) *scale = c;
440
+ if (rotation)
441
+ for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
442
+
443
+ // Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
444
+ nk_f32_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
445
+ 2.0f * c * trace_rotation_covariance;
446
+ if (sum_squared < 0.0f) sum_squared = 0.0f;
447
+ *result = nk_f32_sqrt_haswell(sum_squared * inv_n);
448
+ }
449
+
450
+ #if defined(__clang__)
451
+ #pragma clang attribute pop
452
+ #elif defined(__GNUC__)
453
+ #pragma GCC pop_options
454
+ #endif
455
+
456
+ #if defined(__cplusplus)
457
+ } // extern "C"
458
+ #endif
459
+
460
+ #endif // NK_TARGET_GENOA
461
+ #endif // NK_TARGET_X8664_
462
+ #endif // NK_MESH_GENOA_H