numkong 7.5.0 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/binding.gyp +18 -0
- package/c/dispatch_e5m2.c +23 -3
- package/include/numkong/capabilities.h +1 -1
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +434 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +23 -8
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots.h +12 -0
- package/include/numkong/each/serial.h +18 -1
- package/include/numkong/geospatial/serial.h +14 -3
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +204 -162
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +128 -0
- package/include/numkong/spatials/serial.h +18 -1
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials.h +17 -0
- package/include/numkong/tensor.hpp +107 -23
- package/javascript/numkong.c +3 -2
- package/package.json +7 -7
- package/wasm/numkong.wasm +0 -0
|
@@ -88,10 +88,6 @@ NK_INTERNAL nk_f64_t nk_reduce_stable_f64x4_haswell_(__m256d values_f64x4) {
|
|
|
88
88
|
return sum + compensation;
|
|
89
89
|
}
|
|
90
90
|
|
|
91
|
-
NK_INTERNAL void nk_rotation_from_svd_f64_haswell_(nk_f64_t const *svd_u, nk_f64_t const *svd_v, nk_f64_t *rotation) {
|
|
92
|
-
nk_rotation_from_svd_f64_serial_(svd_u, svd_v, rotation);
|
|
93
|
-
}
|
|
94
|
-
|
|
95
91
|
NK_INTERNAL void nk_accumulate_square_f64x4_haswell_(__m256d *sum_f64x4, __m256d *compensation_f64x4,
|
|
96
92
|
__m256d values_f64x4) {
|
|
97
93
|
__m256d product_f64x4 = _mm256_mul_pd(values_f64x4, values_f64x4);
|
|
@@ -105,208 +101,6 @@ NK_INTERNAL void nk_accumulate_square_f64x4_haswell_(__m256d *sum_f64x4, __m256d
|
|
|
105
101
|
*compensation_f64x4 = _mm256_add_pd(*compensation_f64x4, _mm256_add_pd(sum_error_f64x4, product_error_f64x4));
|
|
106
102
|
}
|
|
107
103
|
|
|
108
|
-
/* Compute sum of squared distances after applying rotation (and optional scale).
|
|
109
|
-
* Used by kabsch (scale=1.0) and umeyama (scale=computed_scale).
|
|
110
|
-
* Returns sum_squared, caller computes sqrt(sum_squared / n).
|
|
111
|
-
*/
|
|
112
|
-
NK_INTERNAL nk_f64_t nk_transformed_ssd_f32_haswell_(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n,
|
|
113
|
-
nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
|
|
114
|
-
nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
|
|
115
|
-
nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
|
|
116
|
-
nk_f64_t centroid_b_z) {
|
|
117
|
-
__m256d scaled_rotation_x_x_f64x4 = _mm256_set1_pd(scale * r[0]);
|
|
118
|
-
__m256d scaled_rotation_x_y_f64x4 = _mm256_set1_pd(scale * r[1]);
|
|
119
|
-
__m256d scaled_rotation_x_z_f64x4 = _mm256_set1_pd(scale * r[2]);
|
|
120
|
-
__m256d scaled_rotation_y_x_f64x4 = _mm256_set1_pd(scale * r[3]);
|
|
121
|
-
__m256d scaled_rotation_y_y_f64x4 = _mm256_set1_pd(scale * r[4]);
|
|
122
|
-
__m256d scaled_rotation_y_z_f64x4 = _mm256_set1_pd(scale * r[5]);
|
|
123
|
-
__m256d scaled_rotation_z_x_f64x4 = _mm256_set1_pd(scale * r[6]);
|
|
124
|
-
__m256d scaled_rotation_z_y_f64x4 = _mm256_set1_pd(scale * r[7]);
|
|
125
|
-
__m256d scaled_rotation_z_z_f64x4 = _mm256_set1_pd(scale * r[8]);
|
|
126
|
-
__m256d centroid_a_x_f64x4 = _mm256_set1_pd(centroid_a_x), centroid_a_y_f64x4 = _mm256_set1_pd(centroid_a_y);
|
|
127
|
-
__m256d centroid_a_z_f64x4 = _mm256_set1_pd(centroid_a_z), centroid_b_x_f64x4 = _mm256_set1_pd(centroid_b_x);
|
|
128
|
-
__m256d centroid_b_y_f64x4 = _mm256_set1_pd(centroid_b_y), centroid_b_z_f64x4 = _mm256_set1_pd(centroid_b_z);
|
|
129
|
-
__m256d sum_squared_f64x4 = _mm256_setzero_pd();
|
|
130
|
-
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
131
|
-
nk_size_t index = 0;
|
|
132
|
-
|
|
133
|
-
for (; index + 8 <= n; index += 8) {
|
|
134
|
-
nk_deinterleave_f32x8_haswell_(a + index * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8),
|
|
135
|
-
nk_deinterleave_f32x8_haswell_(b + index * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
136
|
-
|
|
137
|
-
__m256d a_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_x_f32x8));
|
|
138
|
-
__m256d a_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_x_f32x8, 1));
|
|
139
|
-
__m256d a_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_y_f32x8));
|
|
140
|
-
__m256d a_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_y_f32x8, 1));
|
|
141
|
-
__m256d a_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(a_z_f32x8));
|
|
142
|
-
__m256d a_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(a_z_f32x8, 1));
|
|
143
|
-
__m256d b_x_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_x_f32x8));
|
|
144
|
-
__m256d b_x_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_x_f32x8, 1));
|
|
145
|
-
__m256d b_y_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_y_f32x8));
|
|
146
|
-
__m256d b_y_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_y_f32x8, 1));
|
|
147
|
-
__m256d b_z_low_f64x4 = _mm256_cvtps_pd(_mm256_castps256_ps128(b_z_f32x8));
|
|
148
|
-
__m256d b_z_high_f64x4 = _mm256_cvtps_pd(_mm256_extractf128_ps(b_z_f32x8, 1));
|
|
149
|
-
|
|
150
|
-
__m256d centered_a_x_low_f64x4 = _mm256_sub_pd(a_x_low_f64x4, centroid_a_x_f64x4);
|
|
151
|
-
__m256d centered_a_x_high_f64x4 = _mm256_sub_pd(a_x_high_f64x4, centroid_a_x_f64x4);
|
|
152
|
-
__m256d centered_a_y_low_f64x4 = _mm256_sub_pd(a_y_low_f64x4, centroid_a_y_f64x4);
|
|
153
|
-
__m256d centered_a_y_high_f64x4 = _mm256_sub_pd(a_y_high_f64x4, centroid_a_y_f64x4);
|
|
154
|
-
__m256d centered_a_z_low_f64x4 = _mm256_sub_pd(a_z_low_f64x4, centroid_a_z_f64x4);
|
|
155
|
-
__m256d centered_a_z_high_f64x4 = _mm256_sub_pd(a_z_high_f64x4, centroid_a_z_f64x4);
|
|
156
|
-
__m256d centered_b_x_low_f64x4 = _mm256_sub_pd(b_x_low_f64x4, centroid_b_x_f64x4);
|
|
157
|
-
__m256d centered_b_x_high_f64x4 = _mm256_sub_pd(b_x_high_f64x4, centroid_b_x_f64x4);
|
|
158
|
-
__m256d centered_b_y_low_f64x4 = _mm256_sub_pd(b_y_low_f64x4, centroid_b_y_f64x4);
|
|
159
|
-
__m256d centered_b_y_high_f64x4 = _mm256_sub_pd(b_y_high_f64x4, centroid_b_y_f64x4);
|
|
160
|
-
__m256d centered_b_z_low_f64x4 = _mm256_sub_pd(b_z_low_f64x4, centroid_b_z_f64x4);
|
|
161
|
-
__m256d centered_b_z_high_f64x4 = _mm256_sub_pd(b_z_high_f64x4, centroid_b_z_f64x4);
|
|
162
|
-
|
|
163
|
-
__m256d rotated_a_x_low_f64x4 = _mm256_fmadd_pd(
|
|
164
|
-
scaled_rotation_x_z_f64x4, centered_a_z_low_f64x4,
|
|
165
|
-
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_low_f64x4,
|
|
166
|
-
_mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_low_f64x4)));
|
|
167
|
-
__m256d rotated_a_x_high_f64x4 = _mm256_fmadd_pd(
|
|
168
|
-
scaled_rotation_x_z_f64x4, centered_a_z_high_f64x4,
|
|
169
|
-
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4, centered_a_y_high_f64x4,
|
|
170
|
-
_mm256_mul_pd(scaled_rotation_x_x_f64x4, centered_a_x_high_f64x4)));
|
|
171
|
-
__m256d rotated_a_y_low_f64x4 = _mm256_fmadd_pd(
|
|
172
|
-
scaled_rotation_y_z_f64x4, centered_a_z_low_f64x4,
|
|
173
|
-
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_low_f64x4,
|
|
174
|
-
_mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_low_f64x4)));
|
|
175
|
-
__m256d rotated_a_y_high_f64x4 = _mm256_fmadd_pd(
|
|
176
|
-
scaled_rotation_y_z_f64x4, centered_a_z_high_f64x4,
|
|
177
|
-
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4, centered_a_y_high_f64x4,
|
|
178
|
-
_mm256_mul_pd(scaled_rotation_y_x_f64x4, centered_a_x_high_f64x4)));
|
|
179
|
-
__m256d rotated_a_z_low_f64x4 = _mm256_fmadd_pd(
|
|
180
|
-
scaled_rotation_z_z_f64x4, centered_a_z_low_f64x4,
|
|
181
|
-
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_low_f64x4,
|
|
182
|
-
_mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_low_f64x4)));
|
|
183
|
-
__m256d rotated_a_z_high_f64x4 = _mm256_fmadd_pd(
|
|
184
|
-
scaled_rotation_z_z_f64x4, centered_a_z_high_f64x4,
|
|
185
|
-
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4, centered_a_y_high_f64x4,
|
|
186
|
-
_mm256_mul_pd(scaled_rotation_z_x_f64x4, centered_a_x_high_f64x4)));
|
|
187
|
-
|
|
188
|
-
__m256d delta_x_low_f64x4 = _mm256_sub_pd(rotated_a_x_low_f64x4, centered_b_x_low_f64x4);
|
|
189
|
-
__m256d delta_x_high_f64x4 = _mm256_sub_pd(rotated_a_x_high_f64x4, centered_b_x_high_f64x4);
|
|
190
|
-
__m256d delta_y_low_f64x4 = _mm256_sub_pd(rotated_a_y_low_f64x4, centered_b_y_low_f64x4);
|
|
191
|
-
__m256d delta_y_high_f64x4 = _mm256_sub_pd(rotated_a_y_high_f64x4, centered_b_y_high_f64x4);
|
|
192
|
-
__m256d delta_z_low_f64x4 = _mm256_sub_pd(rotated_a_z_low_f64x4, centered_b_z_low_f64x4);
|
|
193
|
-
__m256d delta_z_high_f64x4 = _mm256_sub_pd(rotated_a_z_high_f64x4, centered_b_z_high_f64x4);
|
|
194
|
-
|
|
195
|
-
__m256d batch_sum_squared_f64x4 = _mm256_add_pd(_mm256_mul_pd(delta_x_low_f64x4, delta_x_low_f64x4),
|
|
196
|
-
_mm256_mul_pd(delta_x_high_f64x4, delta_x_high_f64x4));
|
|
197
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_low_f64x4, delta_y_low_f64x4, batch_sum_squared_f64x4);
|
|
198
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_y_high_f64x4, delta_y_high_f64x4, batch_sum_squared_f64x4);
|
|
199
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_low_f64x4, delta_z_low_f64x4, batch_sum_squared_f64x4);
|
|
200
|
-
batch_sum_squared_f64x4 = _mm256_fmadd_pd(delta_z_high_f64x4, delta_z_high_f64x4, batch_sum_squared_f64x4);
|
|
201
|
-
sum_squared_f64x4 = _mm256_add_pd(sum_squared_f64x4, batch_sum_squared_f64x4);
|
|
202
|
-
}
|
|
203
|
-
|
|
204
|
-
nk_f64_t sum_squared = nk_reduce_add_f64x4_haswell_(sum_squared_f64x4);
|
|
205
|
-
for (; index < n; ++index) {
|
|
206
|
-
nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
|
|
207
|
-
centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
|
|
208
|
-
centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
|
|
209
|
-
nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
|
|
210
|
-
centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
|
|
211
|
-
centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
|
|
212
|
-
nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
|
|
213
|
-
rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
|
|
214
|
-
rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
|
|
215
|
-
nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
|
|
216
|
-
delta_z = rotated_a_z - centered_b_z;
|
|
217
|
-
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
218
|
-
}
|
|
219
|
-
|
|
220
|
-
return sum_squared;
|
|
221
|
-
}
|
|
222
|
-
|
|
223
|
-
/* Compute sum of squared distances for f64 after applying rotation (and optional scale).
|
|
224
|
-
* Rotation matrix, scale and data are all f64 for full precision.
|
|
225
|
-
*/
|
|
226
|
-
NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_haswell_(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
|
|
227
|
-
nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
|
|
228
|
-
nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
|
|
229
|
-
nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
|
|
230
|
-
nk_f64_t centroid_b_z) {
|
|
231
|
-
// Broadcast scaled rotation matrix elements
|
|
232
|
-
__m256d scaled_rotation_x_x_f64x4 = _mm256_set1_pd(scale * r[0]);
|
|
233
|
-
__m256d scaled_rotation_x_y_f64x4 = _mm256_set1_pd(scale * r[1]);
|
|
234
|
-
__m256d scaled_rotation_x_z_f64x4 = _mm256_set1_pd(scale * r[2]);
|
|
235
|
-
__m256d scaled_rotation_y_x_f64x4 = _mm256_set1_pd(scale * r[3]);
|
|
236
|
-
__m256d scaled_rotation_y_y_f64x4 = _mm256_set1_pd(scale * r[4]);
|
|
237
|
-
__m256d scaled_rotation_y_z_f64x4 = _mm256_set1_pd(scale * r[5]);
|
|
238
|
-
__m256d scaled_rotation_z_x_f64x4 = _mm256_set1_pd(scale * r[6]);
|
|
239
|
-
__m256d scaled_rotation_z_y_f64x4 = _mm256_set1_pd(scale * r[7]);
|
|
240
|
-
__m256d scaled_rotation_z_z_f64x4 = _mm256_set1_pd(scale * r[8]);
|
|
241
|
-
|
|
242
|
-
// Broadcast centroids
|
|
243
|
-
__m256d centroid_a_x_f64x4 = _mm256_set1_pd(centroid_a_x);
|
|
244
|
-
__m256d centroid_a_y_f64x4 = _mm256_set1_pd(centroid_a_y);
|
|
245
|
-
__m256d centroid_a_z_f64x4 = _mm256_set1_pd(centroid_a_z);
|
|
246
|
-
__m256d centroid_b_x_f64x4 = _mm256_set1_pd(centroid_b_x);
|
|
247
|
-
__m256d centroid_b_y_f64x4 = _mm256_set1_pd(centroid_b_y);
|
|
248
|
-
__m256d centroid_b_z_f64x4 = _mm256_set1_pd(centroid_b_z);
|
|
249
|
-
|
|
250
|
-
__m256d sum_squared_f64x4 = _mm256_setzero_pd();
|
|
251
|
-
__m256d sum_squared_compensation_f64x4 = _mm256_setzero_pd();
|
|
252
|
-
__m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
|
|
253
|
-
nk_size_t j = 0;
|
|
254
|
-
|
|
255
|
-
for (; j + 4 <= n; j += 4) {
|
|
256
|
-
nk_deinterleave_f64x4_haswell_(a + j * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
|
|
257
|
-
nk_deinterleave_f64x4_haswell_(b + j * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
|
|
258
|
-
|
|
259
|
-
// Center points
|
|
260
|
-
__m256d pa_x_f64x4 = _mm256_sub_pd(a_x_f64x4, centroid_a_x_f64x4);
|
|
261
|
-
__m256d pa_y_f64x4 = _mm256_sub_pd(a_y_f64x4, centroid_a_y_f64x4);
|
|
262
|
-
__m256d pa_z_f64x4 = _mm256_sub_pd(a_z_f64x4, centroid_a_z_f64x4);
|
|
263
|
-
__m256d pb_x_f64x4 = _mm256_sub_pd(b_x_f64x4, centroid_b_x_f64x4);
|
|
264
|
-
__m256d pb_y_f64x4 = _mm256_sub_pd(b_y_f64x4, centroid_b_y_f64x4);
|
|
265
|
-
__m256d pb_z_f64x4 = _mm256_sub_pd(b_z_f64x4, centroid_b_z_f64x4);
|
|
266
|
-
|
|
267
|
-
// Rotate and scale: ra = scale * R * pa
|
|
268
|
-
__m256d ra_x_f64x4 = _mm256_fmadd_pd(scaled_rotation_x_z_f64x4, pa_z_f64x4,
|
|
269
|
-
_mm256_fmadd_pd(scaled_rotation_x_y_f64x4, pa_y_f64x4,
|
|
270
|
-
_mm256_mul_pd(scaled_rotation_x_x_f64x4, pa_x_f64x4)));
|
|
271
|
-
__m256d ra_y_f64x4 = _mm256_fmadd_pd(scaled_rotation_y_z_f64x4, pa_z_f64x4,
|
|
272
|
-
_mm256_fmadd_pd(scaled_rotation_y_y_f64x4, pa_y_f64x4,
|
|
273
|
-
_mm256_mul_pd(scaled_rotation_y_x_f64x4, pa_x_f64x4)));
|
|
274
|
-
__m256d ra_z_f64x4 = _mm256_fmadd_pd(scaled_rotation_z_z_f64x4, pa_z_f64x4,
|
|
275
|
-
_mm256_fmadd_pd(scaled_rotation_z_y_f64x4, pa_y_f64x4,
|
|
276
|
-
_mm256_mul_pd(scaled_rotation_z_x_f64x4, pa_x_f64x4)));
|
|
277
|
-
|
|
278
|
-
// Delta and accumulate
|
|
279
|
-
__m256d delta_x_f64x4 = _mm256_sub_pd(ra_x_f64x4, pb_x_f64x4);
|
|
280
|
-
__m256d delta_y_f64x4 = _mm256_sub_pd(ra_y_f64x4, pb_y_f64x4);
|
|
281
|
-
__m256d delta_z_f64x4 = _mm256_sub_pd(ra_z_f64x4, pb_z_f64x4);
|
|
282
|
-
|
|
283
|
-
nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_x_f64x4);
|
|
284
|
-
nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_y_f64x4);
|
|
285
|
-
nk_accumulate_square_f64x4_haswell_(&sum_squared_f64x4, &sum_squared_compensation_f64x4, delta_z_f64x4);
|
|
286
|
-
}
|
|
287
|
-
|
|
288
|
-
nk_f64_t sum_squared = nk_dot_stable_sum_f64x4_haswell_(sum_squared_f64x4, sum_squared_compensation_f64x4);
|
|
289
|
-
nk_f64_t sum_squared_compensation = 0.0;
|
|
290
|
-
|
|
291
|
-
// Scalar tail
|
|
292
|
-
for (; j < n; ++j) {
|
|
293
|
-
nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
|
|
294
|
-
pa_z = a[j * 3 + 2] - centroid_a_z;
|
|
295
|
-
nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
|
|
296
|
-
pb_z = b[j * 3 + 2] - centroid_b_z;
|
|
297
|
-
nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
|
|
298
|
-
ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
|
|
299
|
-
ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
300
|
-
|
|
301
|
-
nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
|
|
302
|
-
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
|
|
303
|
-
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
|
|
304
|
-
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
|
|
305
|
-
}
|
|
306
|
-
|
|
307
|
-
return sum_squared + sum_squared_compensation;
|
|
308
|
-
}
|
|
309
|
-
|
|
310
104
|
NK_PUBLIC void nk_rmsd_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
311
105
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
|
|
312
106
|
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
@@ -441,6 +235,7 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
|
|
|
441
235
|
__m256d covariance_11_f64x4 = _mm256_setzero_pd(), covariance_12_f64x4 = _mm256_setzero_pd();
|
|
442
236
|
__m256d covariance_20_f64x4 = _mm256_setzero_pd(), covariance_21_f64x4 = _mm256_setzero_pd();
|
|
443
237
|
__m256d covariance_22_f64x4 = _mm256_setzero_pd();
|
|
238
|
+
__m256d norm_squared_a_f64x4 = _mm256_setzero_pd(), norm_squared_b_f64x4 = _mm256_setzero_pd();
|
|
444
239
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
445
240
|
nk_size_t index = 0;
|
|
446
241
|
|
|
@@ -494,6 +289,24 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
|
|
|
494
289
|
covariance_22_f64x4 = _mm256_add_pd(
|
|
495
290
|
covariance_22_f64x4,
|
|
496
291
|
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_z_high_f64x4)));
|
|
292
|
+
norm_squared_a_f64x4 = _mm256_add_pd(
|
|
293
|
+
norm_squared_a_f64x4,
|
|
294
|
+
_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, a_x_low_f64x4), _mm256_mul_pd(a_x_high_f64x4, a_x_high_f64x4)));
|
|
295
|
+
norm_squared_a_f64x4 = _mm256_add_pd(
|
|
296
|
+
norm_squared_a_f64x4,
|
|
297
|
+
_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, a_y_low_f64x4), _mm256_mul_pd(a_y_high_f64x4, a_y_high_f64x4)));
|
|
298
|
+
norm_squared_a_f64x4 = _mm256_add_pd(
|
|
299
|
+
norm_squared_a_f64x4,
|
|
300
|
+
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, a_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, a_z_high_f64x4)));
|
|
301
|
+
norm_squared_b_f64x4 = _mm256_add_pd(
|
|
302
|
+
norm_squared_b_f64x4,
|
|
303
|
+
_mm256_add_pd(_mm256_mul_pd(b_x_low_f64x4, b_x_low_f64x4), _mm256_mul_pd(b_x_high_f64x4, b_x_high_f64x4)));
|
|
304
|
+
norm_squared_b_f64x4 = _mm256_add_pd(
|
|
305
|
+
norm_squared_b_f64x4,
|
|
306
|
+
_mm256_add_pd(_mm256_mul_pd(b_y_low_f64x4, b_y_low_f64x4), _mm256_mul_pd(b_y_high_f64x4, b_y_high_f64x4)));
|
|
307
|
+
norm_squared_b_f64x4 = _mm256_add_pd(
|
|
308
|
+
norm_squared_b_f64x4,
|
|
309
|
+
_mm256_add_pd(_mm256_mul_pd(b_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(b_z_high_f64x4, b_z_high_f64x4)));
|
|
497
310
|
}
|
|
498
311
|
|
|
499
312
|
nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
|
|
@@ -502,21 +315,25 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
|
|
|
502
315
|
nk_f64_t sum_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
|
|
503
316
|
nk_f64_t sum_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
|
|
504
317
|
nk_f64_t sum_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
|
|
505
|
-
nk_f64_t
|
|
318
|
+
nk_f64_t cross_covariance[9] = {
|
|
506
319
|
nk_reduce_add_f64x4_haswell_(covariance_00_f64x4), nk_reduce_add_f64x4_haswell_(covariance_01_f64x4),
|
|
507
320
|
nk_reduce_add_f64x4_haswell_(covariance_02_f64x4), nk_reduce_add_f64x4_haswell_(covariance_10_f64x4),
|
|
508
321
|
nk_reduce_add_f64x4_haswell_(covariance_11_f64x4), nk_reduce_add_f64x4_haswell_(covariance_12_f64x4),
|
|
509
322
|
nk_reduce_add_f64x4_haswell_(covariance_20_f64x4), nk_reduce_add_f64x4_haswell_(covariance_21_f64x4),
|
|
510
323
|
nk_reduce_add_f64x4_haswell_(covariance_22_f64x4)};
|
|
324
|
+
nk_f64_t norm_squared_a = nk_reduce_add_f64x4_haswell_(norm_squared_a_f64x4);
|
|
325
|
+
nk_f64_t norm_squared_b = nk_reduce_add_f64x4_haswell_(norm_squared_b_f64x4);
|
|
511
326
|
|
|
512
327
|
for (; index < n; ++index) {
|
|
513
328
|
nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
|
|
514
329
|
nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
|
|
515
330
|
sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
|
|
516
331
|
sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
332
|
+
cross_covariance[0] += a_x * b_x, cross_covariance[1] += a_x * b_y, cross_covariance[2] += a_x * b_z;
|
|
333
|
+
cross_covariance[3] += a_y * b_x, cross_covariance[4] += a_y * b_y, cross_covariance[5] += a_y * b_z;
|
|
334
|
+
cross_covariance[6] += a_z * b_x, cross_covariance[7] += a_z * b_y, cross_covariance[8] += a_z * b_z;
|
|
335
|
+
norm_squared_a += a_x * a_x + a_y * a_y + a_z * a_z;
|
|
336
|
+
norm_squared_b += b_x * b_x + b_y * b_y + b_z * b_z;
|
|
520
337
|
}
|
|
521
338
|
|
|
522
339
|
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
@@ -529,41 +346,81 @@ NK_PUBLIC void nk_kabsch_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_si
|
|
|
529
346
|
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
530
347
|
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
531
348
|
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
349
|
+
cross_covariance[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x,
|
|
350
|
+
cross_covariance[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y,
|
|
351
|
+
cross_covariance[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z,
|
|
352
|
+
cross_covariance[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x,
|
|
353
|
+
cross_covariance[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y,
|
|
354
|
+
cross_covariance[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z,
|
|
355
|
+
cross_covariance[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x,
|
|
356
|
+
cross_covariance[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y,
|
|
357
|
+
cross_covariance[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
|
|
358
|
+
|
|
359
|
+
nk_f64_t centered_norm_squared_a = norm_squared_a -
|
|
360
|
+
(nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
361
|
+
centroid_a_z * centroid_a_z);
|
|
362
|
+
nk_f64_t centered_norm_squared_b = norm_squared_b -
|
|
363
|
+
(nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
364
|
+
centroid_b_z * centroid_b_z);
|
|
365
|
+
if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
|
|
366
|
+
if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
|
|
367
|
+
|
|
368
|
+
// Identity-dominant short-circuit: R = I, trace(R * H) = H[0]+H[4]+H[8]. Skips SVD + two
|
|
369
|
+
// rotation_from_svd reconstructions when the inputs are already aligned.
|
|
370
|
+
nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
371
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
372
|
+
cross_covariance[8] * cross_covariance[8];
|
|
373
|
+
nk_f64_t covariance_offdiagonal_norm_squared =
|
|
374
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
375
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
376
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
377
|
+
nk_f64_t optimal_rotation[9];
|
|
378
|
+
nk_f64_t trace_rotation_covariance;
|
|
379
|
+
if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
|
|
380
|
+
cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
|
|
381
|
+
optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
|
|
382
|
+
optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
|
|
383
|
+
optimal_rotation[8] = 1;
|
|
384
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
385
|
+
}
|
|
386
|
+
else {
|
|
387
|
+
nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
388
|
+
nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
389
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
390
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
391
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
392
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
393
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
394
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
395
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
396
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
397
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
398
|
+
if (nk_det3x3_f64_(optimal_rotation) < 0) {
|
|
399
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
400
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
401
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
402
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
403
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
404
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
405
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
406
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
407
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
408
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
409
|
+
}
|
|
410
|
+
trace_rotation_covariance =
|
|
411
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
412
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
413
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
414
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
415
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
561
416
|
}
|
|
562
417
|
|
|
563
418
|
if (rotation)
|
|
564
|
-
for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)
|
|
565
|
-
|
|
566
|
-
|
|
419
|
+
for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
|
|
420
|
+
|
|
421
|
+
// Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
|
|
422
|
+
nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
|
|
423
|
+
if (sum_squared < 0.0) sum_squared = 0.0;
|
|
567
424
|
*result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
|
|
568
425
|
}
|
|
569
426
|
|
|
@@ -576,14 +433,15 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
576
433
|
__m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
|
|
577
434
|
|
|
578
435
|
// Accumulators for covariance matrix (sum of outer products)
|
|
579
|
-
__m256d
|
|
580
|
-
__m256d
|
|
581
|
-
__m256d
|
|
436
|
+
__m256d covariance_xx_f64x4 = zeros_f64x4, covariance_xy_f64x4 = zeros_f64x4, covariance_xz_f64x4 = zeros_f64x4;
|
|
437
|
+
__m256d covariance_yx_f64x4 = zeros_f64x4, covariance_yy_f64x4 = zeros_f64x4, covariance_yz_f64x4 = zeros_f64x4;
|
|
438
|
+
__m256d covariance_zx_f64x4 = zeros_f64x4, covariance_zy_f64x4 = zeros_f64x4, covariance_zz_f64x4 = zeros_f64x4;
|
|
439
|
+
__m256d norm_squared_a_f64x4 = zeros_f64x4, norm_squared_b_f64x4 = zeros_f64x4;
|
|
582
440
|
|
|
583
441
|
nk_size_t i = 0;
|
|
584
442
|
__m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
|
|
585
443
|
|
|
586
|
-
// Fused single-pass
|
|
444
|
+
// Fused single-pass (centroids + covariance + norm-squared for folded SSD)
|
|
587
445
|
for (; i + 4 <= n; i += 4) {
|
|
588
446
|
nk_deinterleave_f64x4_haswell_(a + i * 3, &a_x_f64x4, &a_y_f64x4, &a_z_f64x4);
|
|
589
447
|
nk_deinterleave_f64x4_haswell_(b + i * 3, &b_x_f64x4, &b_y_f64x4, &b_z_f64x4);
|
|
@@ -595,15 +453,21 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
595
453
|
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
|
|
596
454
|
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
|
|
597
455
|
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
456
|
+
covariance_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, covariance_xx_f64x4);
|
|
457
|
+
covariance_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, covariance_xy_f64x4);
|
|
458
|
+
covariance_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, covariance_xz_f64x4);
|
|
459
|
+
covariance_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, covariance_yx_f64x4);
|
|
460
|
+
covariance_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, covariance_yy_f64x4);
|
|
461
|
+
covariance_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, covariance_yz_f64x4);
|
|
462
|
+
covariance_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, covariance_zx_f64x4);
|
|
463
|
+
covariance_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, covariance_zy_f64x4);
|
|
464
|
+
covariance_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, covariance_zz_f64x4);
|
|
465
|
+
norm_squared_a_f64x4 = _mm256_fmadd_pd(a_x_f64x4, a_x_f64x4, norm_squared_a_f64x4);
|
|
466
|
+
norm_squared_a_f64x4 = _mm256_fmadd_pd(a_y_f64x4, a_y_f64x4, norm_squared_a_f64x4);
|
|
467
|
+
norm_squared_a_f64x4 = _mm256_fmadd_pd(a_z_f64x4, a_z_f64x4, norm_squared_a_f64x4);
|
|
468
|
+
norm_squared_b_f64x4 = _mm256_fmadd_pd(b_x_f64x4, b_x_f64x4, norm_squared_b_f64x4);
|
|
469
|
+
norm_squared_b_f64x4 = _mm256_fmadd_pd(b_y_f64x4, b_y_f64x4, norm_squared_b_f64x4);
|
|
470
|
+
norm_squared_b_f64x4 = _mm256_fmadd_pd(b_z_f64x4, b_z_f64x4, norm_squared_b_f64x4);
|
|
607
471
|
}
|
|
608
472
|
|
|
609
473
|
// Reduce vector accumulators
|
|
@@ -614,15 +478,19 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
614
478
|
nk_f64_t sum_b_y = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), sum_b_y_compensation = 0.0;
|
|
615
479
|
nk_f64_t sum_b_z = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), sum_b_z_compensation = 0.0;
|
|
616
480
|
|
|
617
|
-
nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(
|
|
618
|
-
nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(
|
|
619
|
-
nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(
|
|
620
|
-
nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(
|
|
621
|
-
nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(
|
|
622
|
-
nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(
|
|
623
|
-
nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(
|
|
624
|
-
nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(
|
|
625
|
-
nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(
|
|
481
|
+
nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(covariance_xx_f64x4), covariance_x_x_compensation = 0.0;
|
|
482
|
+
nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(covariance_xy_f64x4), covariance_x_y_compensation = 0.0;
|
|
483
|
+
nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(covariance_xz_f64x4), covariance_x_z_compensation = 0.0;
|
|
484
|
+
nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(covariance_yx_f64x4), covariance_y_x_compensation = 0.0;
|
|
485
|
+
nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(covariance_yy_f64x4), covariance_y_y_compensation = 0.0;
|
|
486
|
+
nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(covariance_yz_f64x4), covariance_y_z_compensation = 0.0;
|
|
487
|
+
nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(covariance_zx_f64x4), covariance_z_x_compensation = 0.0;
|
|
488
|
+
nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(covariance_zy_f64x4), covariance_z_y_compensation = 0.0;
|
|
489
|
+
nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(covariance_zz_f64x4), covariance_z_z_compensation = 0.0;
|
|
490
|
+
nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x4_haswell_(norm_squared_a_f64x4),
|
|
491
|
+
norm_squared_a_compensation = 0.0;
|
|
492
|
+
nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x4_haswell_(norm_squared_b_f64x4),
|
|
493
|
+
norm_squared_b_compensation = 0.0;
|
|
626
494
|
|
|
627
495
|
// Scalar tail
|
|
628
496
|
for (; i < n; ++i) {
|
|
@@ -643,6 +511,12 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
643
511
|
nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
|
|
644
512
|
nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
|
|
645
513
|
nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
|
|
514
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
|
|
515
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
|
|
516
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
|
|
517
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
|
|
518
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
|
|
519
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
|
|
646
520
|
}
|
|
647
521
|
|
|
648
522
|
sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
|
|
@@ -653,6 +527,8 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
653
527
|
covariance_y_z += covariance_y_z_compensation;
|
|
654
528
|
covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
|
|
655
529
|
covariance_z_z += covariance_z_z_compensation;
|
|
530
|
+
norm_squared_a_sum += norm_squared_a_compensation;
|
|
531
|
+
norm_squared_b_sum += norm_squared_b_compensation;
|
|
656
532
|
|
|
657
533
|
// Compute centroids
|
|
658
534
|
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
@@ -677,29 +553,59 @@ NK_PUBLIC void nk_kabsch_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
677
553
|
covariance_z_y -= (nk_f64_t)n * centroid_a_z * centroid_b_y;
|
|
678
554
|
covariance_z_z -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
|
|
679
555
|
|
|
680
|
-
// Compute SVD and optimal rotation using f64 precision (svd_s is 9-element diagonal matrix)
|
|
681
556
|
nk_f64_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
682
557
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
683
|
-
nk_f64_t svd_u[9], svd_s[9], svd_v[9];
|
|
684
|
-
nk_svd3x3_f64_(cross_covariance, svd_u, svd_s, svd_v);
|
|
685
558
|
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
559
|
+
// Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
|
|
560
|
+
nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
|
|
561
|
+
(nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
562
|
+
centroid_a_z * centroid_a_z);
|
|
563
|
+
nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
|
|
564
|
+
(nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
565
|
+
centroid_b_z * centroid_b_z);
|
|
566
|
+
if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
|
|
567
|
+
if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
|
|
568
|
+
|
|
569
|
+
// Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals, R = I.
|
|
570
|
+
nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
571
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
572
|
+
cross_covariance[8] * cross_covariance[8];
|
|
573
|
+
nk_f64_t covariance_offdiagonal_norm_squared =
|
|
574
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
575
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
576
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
577
|
+
nk_f64_t optimal_rotation[9];
|
|
578
|
+
nk_f64_t trace_rotation_covariance;
|
|
579
|
+
if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
|
|
580
|
+
cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
|
|
581
|
+
optimal_rotation[0] = 1.0, optimal_rotation[1] = 0.0, optimal_rotation[2] = 0.0;
|
|
582
|
+
optimal_rotation[3] = 0.0, optimal_rotation[4] = 1.0, optimal_rotation[5] = 0.0;
|
|
583
|
+
optimal_rotation[6] = 0.0, optimal_rotation[7] = 0.0, optimal_rotation[8] = 1.0;
|
|
584
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
585
|
+
}
|
|
586
|
+
else {
|
|
587
|
+
nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
588
|
+
nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
589
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
590
|
+
if (nk_det3x3_f64_(optimal_rotation) < 0) {
|
|
591
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
592
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
593
|
+
}
|
|
594
|
+
trace_rotation_covariance =
|
|
595
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
596
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
597
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
598
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
599
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
693
600
|
}
|
|
694
601
|
|
|
695
602
|
// Output rotation matrix and scale=1.0
|
|
696
603
|
if (rotation)
|
|
697
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
604
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
698
605
|
if (scale) *scale = 1.0;
|
|
699
606
|
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
centroid_b_x, centroid_b_y, centroid_b_z);
|
|
607
|
+
nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
|
|
608
|
+
if (sum_squared < 0.0) sum_squared = 0.0;
|
|
703
609
|
*result = nk_f64_sqrt_haswell(sum_squared * inv_n);
|
|
704
610
|
}
|
|
705
611
|
|
|
@@ -712,7 +618,8 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
712
618
|
__m256d covariance_02_f64x4 = _mm256_setzero_pd(), covariance_10_f64x4 = _mm256_setzero_pd();
|
|
713
619
|
__m256d covariance_11_f64x4 = _mm256_setzero_pd(), covariance_12_f64x4 = _mm256_setzero_pd();
|
|
714
620
|
__m256d covariance_20_f64x4 = _mm256_setzero_pd(), covariance_21_f64x4 = _mm256_setzero_pd();
|
|
715
|
-
__m256d covariance_22_f64x4 = _mm256_setzero_pd()
|
|
621
|
+
__m256d covariance_22_f64x4 = _mm256_setzero_pd();
|
|
622
|
+
__m256d norm_squared_a_f64x4 = _mm256_setzero_pd(), norm_squared_b_f64x4 = _mm256_setzero_pd();
|
|
716
623
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
717
624
|
nk_size_t index = 0;
|
|
718
625
|
|
|
@@ -765,14 +672,22 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
765
672
|
covariance_22_f64x4 = _mm256_add_pd(
|
|
766
673
|
covariance_22_f64x4,
|
|
767
674
|
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, b_z_low_f64x4), _mm256_mul_pd(a_z_high_f64x4, b_z_high_f64x4)));
|
|
768
|
-
|
|
769
|
-
|
|
675
|
+
norm_squared_a_f64x4 = _mm256_add_pd(
|
|
676
|
+
norm_squared_a_f64x4,
|
|
770
677
|
_mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_x_low_f64x4, a_x_low_f64x4),
|
|
771
678
|
_mm256_mul_pd(a_x_high_f64x4, a_x_high_f64x4)),
|
|
772
679
|
_mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(a_y_low_f64x4, a_y_low_f64x4),
|
|
773
680
|
_mm256_mul_pd(a_y_high_f64x4, a_y_high_f64x4)),
|
|
774
681
|
_mm256_add_pd(_mm256_mul_pd(a_z_low_f64x4, a_z_low_f64x4),
|
|
775
682
|
_mm256_mul_pd(a_z_high_f64x4, a_z_high_f64x4)))));
|
|
683
|
+
norm_squared_b_f64x4 = _mm256_add_pd(
|
|
684
|
+
norm_squared_b_f64x4,
|
|
685
|
+
_mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(b_x_low_f64x4, b_x_low_f64x4),
|
|
686
|
+
_mm256_mul_pd(b_x_high_f64x4, b_x_high_f64x4)),
|
|
687
|
+
_mm256_add_pd(_mm256_add_pd(_mm256_mul_pd(b_y_low_f64x4, b_y_low_f64x4),
|
|
688
|
+
_mm256_mul_pd(b_y_high_f64x4, b_y_high_f64x4)),
|
|
689
|
+
_mm256_add_pd(_mm256_mul_pd(b_z_low_f64x4, b_z_low_f64x4),
|
|
690
|
+
_mm256_mul_pd(b_z_high_f64x4, b_z_high_f64x4)))));
|
|
776
691
|
}
|
|
777
692
|
|
|
778
693
|
nk_f64_t sum_a_x = nk_reduce_add_f64x4_haswell_(sum_a_x_f64x4);
|
|
@@ -781,23 +696,25 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
781
696
|
nk_f64_t sum_b_x = nk_reduce_add_f64x4_haswell_(sum_b_x_f64x4);
|
|
782
697
|
nk_f64_t sum_b_y = nk_reduce_add_f64x4_haswell_(sum_b_y_f64x4);
|
|
783
698
|
nk_f64_t sum_b_z = nk_reduce_add_f64x4_haswell_(sum_b_z_f64x4);
|
|
784
|
-
nk_f64_t
|
|
699
|
+
nk_f64_t cross_covariance[9] = {
|
|
785
700
|
nk_reduce_add_f64x4_haswell_(covariance_00_f64x4), nk_reduce_add_f64x4_haswell_(covariance_01_f64x4),
|
|
786
701
|
nk_reduce_add_f64x4_haswell_(covariance_02_f64x4), nk_reduce_add_f64x4_haswell_(covariance_10_f64x4),
|
|
787
702
|
nk_reduce_add_f64x4_haswell_(covariance_11_f64x4), nk_reduce_add_f64x4_haswell_(covariance_12_f64x4),
|
|
788
703
|
nk_reduce_add_f64x4_haswell_(covariance_20_f64x4), nk_reduce_add_f64x4_haswell_(covariance_21_f64x4),
|
|
789
704
|
nk_reduce_add_f64x4_haswell_(covariance_22_f64x4)};
|
|
790
|
-
nk_f64_t
|
|
705
|
+
nk_f64_t norm_squared_a_sum = nk_reduce_add_f64x4_haswell_(norm_squared_a_f64x4);
|
|
706
|
+
nk_f64_t norm_squared_b_sum = nk_reduce_add_f64x4_haswell_(norm_squared_b_f64x4);
|
|
791
707
|
|
|
792
708
|
for (; index < n; ++index) {
|
|
793
709
|
nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
|
|
794
710
|
nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
|
|
795
711
|
sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
|
|
796
712
|
sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
713
|
+
cross_covariance[0] += a_x * b_x, cross_covariance[1] += a_x * b_y, cross_covariance[2] += a_x * b_z;
|
|
714
|
+
cross_covariance[3] += a_y * b_x, cross_covariance[4] += a_y * b_y, cross_covariance[5] += a_y * b_z;
|
|
715
|
+
cross_covariance[6] += a_z * b_x, cross_covariance[7] += a_z * b_y, cross_covariance[8] += a_z * b_z;
|
|
716
|
+
norm_squared_a_sum += a_x * a_x + a_y * a_y + a_z * a_z;
|
|
717
|
+
norm_squared_b_sum += b_x * b_x + b_y * b_y + b_z * b_z;
|
|
801
718
|
}
|
|
802
719
|
|
|
803
720
|
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
@@ -810,49 +727,89 @@ NK_PUBLIC void nk_umeyama_f32_haswell(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
810
727
|
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
811
728
|
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
812
729
|
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
730
|
+
nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
|
|
731
|
+
(nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
732
|
+
centroid_a_z * centroid_a_z);
|
|
733
|
+
nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
|
|
734
|
+
(nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
735
|
+
centroid_b_z * centroid_b_z);
|
|
736
|
+
if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
|
|
737
|
+
if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
|
|
738
|
+
cross_covariance[0] -= (nk_f64_t)n * centroid_a_x * centroid_b_x,
|
|
739
|
+
cross_covariance[1] -= (nk_f64_t)n * centroid_a_x * centroid_b_y,
|
|
740
|
+
cross_covariance[2] -= (nk_f64_t)n * centroid_a_x * centroid_b_z,
|
|
741
|
+
cross_covariance[3] -= (nk_f64_t)n * centroid_a_y * centroid_b_x,
|
|
742
|
+
cross_covariance[4] -= (nk_f64_t)n * centroid_a_y * centroid_b_y,
|
|
743
|
+
cross_covariance[5] -= (nk_f64_t)n * centroid_a_y * centroid_b_z,
|
|
744
|
+
cross_covariance[6] -= (nk_f64_t)n * centroid_a_z * centroid_b_x,
|
|
745
|
+
cross_covariance[7] -= (nk_f64_t)n * centroid_a_z * centroid_b_y,
|
|
746
|
+
cross_covariance[8] -= (nk_f64_t)n * centroid_a_z * centroid_b_z;
|
|
747
|
+
|
|
748
|
+
// Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals,
|
|
749
|
+
// R = I and trace(DS) reduces to trace(H) directly.
|
|
750
|
+
nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
751
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
752
|
+
cross_covariance[8] * cross_covariance[8];
|
|
753
|
+
nk_f64_t covariance_offdiagonal_norm_squared =
|
|
754
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
755
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
756
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
757
|
+
nk_f64_t optimal_rotation[9];
|
|
758
|
+
nk_f64_t applied_scale;
|
|
759
|
+
nk_f64_t trace_rotation_covariance;
|
|
760
|
+
if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
|
|
761
|
+
cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
|
|
762
|
+
optimal_rotation[0] = 1.0, optimal_rotation[1] = 0.0, optimal_rotation[2] = 0.0;
|
|
763
|
+
optimal_rotation[3] = 0.0, optimal_rotation[4] = 1.0, optimal_rotation[5] = 0.0;
|
|
764
|
+
optimal_rotation[6] = 0.0, optimal_rotation[7] = 0.0, optimal_rotation[8] = 1.0;
|
|
765
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
766
|
+
applied_scale = centered_norm_squared_a > 0.0 ? trace_rotation_covariance / centered_norm_squared_a : 0.0;
|
|
767
|
+
}
|
|
768
|
+
else {
|
|
769
|
+
nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
770
|
+
nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
771
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
772
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
773
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
774
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
775
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
776
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
777
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
778
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
779
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
780
|
+
|
|
781
|
+
nk_f64_t det = nk_det3x3_f64_(optimal_rotation), sign_correction = det < 0 ? -1.0 : 1.0;
|
|
782
|
+
if (det < 0) {
|
|
783
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
784
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
785
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
786
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
787
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
788
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
789
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
790
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
791
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
792
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
793
|
+
}
|
|
794
|
+
nk_f64_t trace_ds = svd_diagonal[0] + svd_diagonal[4] + sign_correction * svd_diagonal[8];
|
|
795
|
+
applied_scale = centered_norm_squared_a > 0.0 ? trace_ds / centered_norm_squared_a : 0.0;
|
|
796
|
+
trace_rotation_covariance =
|
|
797
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
798
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
799
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
800
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
801
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
846
802
|
}
|
|
847
803
|
|
|
848
|
-
nk_f64_t applied_scale = (svd_s[0] + svd_s[4] + sign_correction * svd_s[8]) / ((nk_f64_t)n * variance_a);
|
|
849
804
|
if (rotation)
|
|
850
|
-
for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)
|
|
805
|
+
for (int j = 0; j != 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
|
|
851
806
|
if (scale) *scale = (nk_f32_t)applied_scale;
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
807
|
+
|
|
808
|
+
// Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
|
|
809
|
+
nk_f64_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
|
|
810
|
+
2.0 * applied_scale * trace_rotation_covariance;
|
|
811
|
+
if (sum_squared < 0.0) sum_squared = 0.0;
|
|
812
|
+
*result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
|
|
856
813
|
}
|
|
857
814
|
|
|
858
815
|
NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
@@ -862,10 +819,10 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
862
819
|
|
|
863
820
|
__m256d sum_a_x_f64x4 = zeros_f64x4, sum_a_y_f64x4 = zeros_f64x4, sum_a_z_f64x4 = zeros_f64x4;
|
|
864
821
|
__m256d sum_b_x_f64x4 = zeros_f64x4, sum_b_y_f64x4 = zeros_f64x4, sum_b_z_f64x4 = zeros_f64x4;
|
|
865
|
-
__m256d
|
|
866
|
-
__m256d
|
|
867
|
-
__m256d
|
|
868
|
-
__m256d
|
|
822
|
+
__m256d covariance_xx_f64x4 = zeros_f64x4, covariance_xy_f64x4 = zeros_f64x4, covariance_xz_f64x4 = zeros_f64x4;
|
|
823
|
+
__m256d covariance_yx_f64x4 = zeros_f64x4, covariance_yy_f64x4 = zeros_f64x4, covariance_yz_f64x4 = zeros_f64x4;
|
|
824
|
+
__m256d covariance_zx_f64x4 = zeros_f64x4, covariance_zy_f64x4 = zeros_f64x4, covariance_zz_f64x4 = zeros_f64x4;
|
|
825
|
+
__m256d norm_squared_a_f64x4 = zeros_f64x4, norm_squared_b_f64x4 = zeros_f64x4;
|
|
869
826
|
|
|
870
827
|
nk_size_t i = 0;
|
|
871
828
|
__m256d a_x_f64x4, a_y_f64x4, a_z_f64x4, b_x_f64x4, b_y_f64x4, b_z_f64x4;
|
|
@@ -881,18 +838,21 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
881
838
|
sum_b_y_f64x4 = _mm256_add_pd(sum_b_y_f64x4, b_y_f64x4);
|
|
882
839
|
sum_b_z_f64x4 = _mm256_add_pd(sum_b_z_f64x4, b_z_f64x4);
|
|
883
840
|
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
841
|
+
covariance_xx_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_x_f64x4, covariance_xx_f64x4),
|
|
842
|
+
covariance_xy_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_y_f64x4, covariance_xy_f64x4);
|
|
843
|
+
covariance_xz_f64x4 = _mm256_fmadd_pd(a_x_f64x4, b_z_f64x4, covariance_xz_f64x4);
|
|
844
|
+
covariance_yx_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_x_f64x4, covariance_yx_f64x4),
|
|
845
|
+
covariance_yy_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_y_f64x4, covariance_yy_f64x4);
|
|
846
|
+
covariance_yz_f64x4 = _mm256_fmadd_pd(a_y_f64x4, b_z_f64x4, covariance_yz_f64x4);
|
|
847
|
+
covariance_zx_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_x_f64x4, covariance_zx_f64x4),
|
|
848
|
+
covariance_zy_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_y_f64x4, covariance_zy_f64x4);
|
|
849
|
+
covariance_zz_f64x4 = _mm256_fmadd_pd(a_z_f64x4, b_z_f64x4, covariance_zz_f64x4);
|
|
850
|
+
norm_squared_a_f64x4 = _mm256_fmadd_pd(a_x_f64x4, a_x_f64x4, norm_squared_a_f64x4);
|
|
851
|
+
norm_squared_a_f64x4 = _mm256_fmadd_pd(a_y_f64x4, a_y_f64x4, norm_squared_a_f64x4);
|
|
852
|
+
norm_squared_a_f64x4 = _mm256_fmadd_pd(a_z_f64x4, a_z_f64x4, norm_squared_a_f64x4);
|
|
853
|
+
norm_squared_b_f64x4 = _mm256_fmadd_pd(b_x_f64x4, b_x_f64x4, norm_squared_b_f64x4);
|
|
854
|
+
norm_squared_b_f64x4 = _mm256_fmadd_pd(b_y_f64x4, b_y_f64x4, norm_squared_b_f64x4);
|
|
855
|
+
norm_squared_b_f64x4 = _mm256_fmadd_pd(b_z_f64x4, b_z_f64x4, norm_squared_b_f64x4);
|
|
896
856
|
}
|
|
897
857
|
|
|
898
858
|
// Reduce vector accumulators
|
|
@@ -902,16 +862,19 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
902
862
|
nk_f64_t sum_b_x = nk_reduce_stable_f64x4_haswell_(sum_b_x_f64x4), sum_b_x_compensation = 0.0;
|
|
903
863
|
nk_f64_t sum_b_y = nk_reduce_stable_f64x4_haswell_(sum_b_y_f64x4), sum_b_y_compensation = 0.0;
|
|
904
864
|
nk_f64_t sum_b_z = nk_reduce_stable_f64x4_haswell_(sum_b_z_f64x4), sum_b_z_compensation = 0.0;
|
|
905
|
-
nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(
|
|
906
|
-
nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(
|
|
907
|
-
nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(
|
|
908
|
-
nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(
|
|
909
|
-
nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(
|
|
910
|
-
nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(
|
|
911
|
-
nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(
|
|
912
|
-
nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(
|
|
913
|
-
nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(
|
|
914
|
-
nk_f64_t
|
|
865
|
+
nk_f64_t covariance_x_x = nk_reduce_stable_f64x4_haswell_(covariance_xx_f64x4), covariance_x_x_compensation = 0.0;
|
|
866
|
+
nk_f64_t covariance_x_y = nk_reduce_stable_f64x4_haswell_(covariance_xy_f64x4), covariance_x_y_compensation = 0.0;
|
|
867
|
+
nk_f64_t covariance_x_z = nk_reduce_stable_f64x4_haswell_(covariance_xz_f64x4), covariance_x_z_compensation = 0.0;
|
|
868
|
+
nk_f64_t covariance_y_x = nk_reduce_stable_f64x4_haswell_(covariance_yx_f64x4), covariance_y_x_compensation = 0.0;
|
|
869
|
+
nk_f64_t covariance_y_y = nk_reduce_stable_f64x4_haswell_(covariance_yy_f64x4), covariance_y_y_compensation = 0.0;
|
|
870
|
+
nk_f64_t covariance_y_z = nk_reduce_stable_f64x4_haswell_(covariance_yz_f64x4), covariance_y_z_compensation = 0.0;
|
|
871
|
+
nk_f64_t covariance_z_x = nk_reduce_stable_f64x4_haswell_(covariance_zx_f64x4), covariance_z_x_compensation = 0.0;
|
|
872
|
+
nk_f64_t covariance_z_y = nk_reduce_stable_f64x4_haswell_(covariance_zy_f64x4), covariance_z_y_compensation = 0.0;
|
|
873
|
+
nk_f64_t covariance_z_z = nk_reduce_stable_f64x4_haswell_(covariance_zz_f64x4), covariance_z_z_compensation = 0.0;
|
|
874
|
+
nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x4_haswell_(norm_squared_a_f64x4),
|
|
875
|
+
norm_squared_a_compensation = 0.0;
|
|
876
|
+
nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x4_haswell_(norm_squared_b_f64x4),
|
|
877
|
+
norm_squared_b_compensation = 0.0;
|
|
915
878
|
|
|
916
879
|
// Scalar tail loop for remaining points
|
|
917
880
|
for (; i < n; i++) {
|
|
@@ -932,9 +895,12 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
932
895
|
nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
|
|
933
896
|
nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
|
|
934
897
|
nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
|
|
935
|
-
nk_accumulate_square_f64_(&
|
|
936
|
-
nk_accumulate_square_f64_(&
|
|
937
|
-
nk_accumulate_square_f64_(&
|
|
898
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
|
|
899
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
|
|
900
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
|
|
901
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
|
|
902
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
|
|
903
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
|
|
938
904
|
}
|
|
939
905
|
|
|
940
906
|
sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
|
|
@@ -945,7 +911,8 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
945
911
|
covariance_y_z += covariance_y_z_compensation;
|
|
946
912
|
covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
|
|
947
913
|
covariance_z_z += covariance_z_z_compensation;
|
|
948
|
-
|
|
914
|
+
norm_squared_a_sum += norm_squared_a_compensation;
|
|
915
|
+
norm_squared_b_sum += norm_squared_b_compensation;
|
|
949
916
|
|
|
950
917
|
// Compute centroids
|
|
951
918
|
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
@@ -956,9 +923,15 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
956
923
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
957
924
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
958
925
|
|
|
959
|
-
//
|
|
960
|
-
nk_f64_t
|
|
961
|
-
|
|
926
|
+
// Centered norm-squared via parallel-axis identity; clamped at zero for numeric safety.
|
|
927
|
+
nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
|
|
928
|
+
(nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
929
|
+
centroid_a_z * centroid_a_z);
|
|
930
|
+
nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
|
|
931
|
+
(nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
932
|
+
centroid_b_z * centroid_b_z);
|
|
933
|
+
if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
|
|
934
|
+
if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
|
|
962
935
|
|
|
963
936
|
nk_f64_t cross_covariance[9];
|
|
964
937
|
cross_covariance[0] = covariance_x_x - sum_a_x * sum_b_x * inv_n;
|
|
@@ -971,34 +944,56 @@ NK_PUBLIC void nk_umeyama_f64_haswell(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
971
944
|
cross_covariance[7] = covariance_z_y - sum_a_z * sum_b_y * inv_n;
|
|
972
945
|
cross_covariance[8] = covariance_z_z - sum_a_z * sum_b_z * inv_n;
|
|
973
946
|
|
|
974
|
-
//
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
nk_f64_t
|
|
984
|
-
nk_f64_t
|
|
985
|
-
nk_f64_t
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
947
|
+
// Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals,
|
|
948
|
+
// R = I and trace(DS) reduces to trace(H) directly.
|
|
949
|
+
nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
950
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
951
|
+
cross_covariance[8] * cross_covariance[8];
|
|
952
|
+
nk_f64_t covariance_offdiagonal_norm_squared =
|
|
953
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
954
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
955
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
956
|
+
nk_f64_t optimal_rotation[9];
|
|
957
|
+
nk_f64_t c;
|
|
958
|
+
nk_f64_t trace_rotation_covariance;
|
|
959
|
+
if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
|
|
960
|
+
cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
|
|
961
|
+
optimal_rotation[0] = 1.0, optimal_rotation[1] = 0.0, optimal_rotation[2] = 0.0;
|
|
962
|
+
optimal_rotation[3] = 0.0, optimal_rotation[4] = 1.0, optimal_rotation[5] = 0.0;
|
|
963
|
+
optimal_rotation[6] = 0.0, optimal_rotation[7] = 0.0, optimal_rotation[8] = 1.0;
|
|
964
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
965
|
+
c = centered_norm_squared_a > 0.0 ? trace_rotation_covariance / centered_norm_squared_a : 0.0;
|
|
966
|
+
}
|
|
967
|
+
else {
|
|
968
|
+
nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
969
|
+
nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
970
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
971
|
+
|
|
972
|
+
nk_f64_t det = nk_det3x3_f64_(optimal_rotation);
|
|
973
|
+
nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
|
|
974
|
+
nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_diagonal[0], 1.0, svd_diagonal[4], 1.0, svd_diagonal[8], d3);
|
|
975
|
+
c = centered_norm_squared_a > 0.0 ? trace_ds / centered_norm_squared_a : 0.0;
|
|
976
|
+
|
|
977
|
+
if (det < 0) {
|
|
978
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
979
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
980
|
+
}
|
|
981
|
+
trace_rotation_covariance =
|
|
982
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
983
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
984
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
985
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
986
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
993
987
|
}
|
|
994
988
|
|
|
995
|
-
|
|
989
|
+
if (scale) *scale = c;
|
|
996
990
|
if (rotation)
|
|
997
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
991
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
998
992
|
|
|
999
|
-
//
|
|
1000
|
-
nk_f64_t sum_squared =
|
|
1001
|
-
|
|
993
|
+
// Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
|
|
994
|
+
nk_f64_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
|
|
995
|
+
2.0 * c * trace_rotation_covariance;
|
|
996
|
+
if (sum_squared < 0.0) sum_squared = 0.0;
|
|
1002
997
|
*result = nk_f64_sqrt_haswell(sum_squared * inv_n);
|
|
1003
998
|
}
|
|
1004
999
|
|
|
@@ -1046,237 +1041,34 @@ NK_INTERNAL void nk_deinterleave_bf16x8_to_f32x8_haswell_(nk_bf16_t const *ptr,
|
|
|
1046
1041
|
*z_out = nk_bf16x8_to_f32x8_haswell_(z_vec.xmms[0]);
|
|
1047
1042
|
}
|
|
1048
1043
|
|
|
1049
|
-
/* Compute sum of squared distances for f16 data after applying rotation (and optional scale).
|
|
1050
|
-
* Loads f16 data, converts to f32 during processing.
|
|
1051
|
-
* Note: rotation matrix r is f32 (from SVD), scale and computation done in f32.
|
|
1052
|
-
*/
|
|
1053
|
-
NK_INTERNAL nk_f32_t nk_transformed_ssd_f16_haswell_(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
1054
|
-
nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
|
|
1055
|
-
nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
|
|
1056
|
-
nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
|
|
1057
|
-
nk_f32_t centroid_b_z) {
|
|
1058
|
-
// Broadcast scaled rotation matrix elements
|
|
1059
|
-
__m256 scaled_rotation_x_x_f32x8 = _mm256_set1_ps(scale * r[0]);
|
|
1060
|
-
__m256 scaled_rotation_x_y_f32x8 = _mm256_set1_ps(scale * r[1]);
|
|
1061
|
-
__m256 scaled_rotation_x_z_f32x8 = _mm256_set1_ps(scale * r[2]);
|
|
1062
|
-
__m256 scaled_rotation_y_x_f32x8 = _mm256_set1_ps(scale * r[3]);
|
|
1063
|
-
__m256 scaled_rotation_y_y_f32x8 = _mm256_set1_ps(scale * r[4]);
|
|
1064
|
-
__m256 scaled_rotation_y_z_f32x8 = _mm256_set1_ps(scale * r[5]);
|
|
1065
|
-
__m256 scaled_rotation_z_x_f32x8 = _mm256_set1_ps(scale * r[6]);
|
|
1066
|
-
__m256 scaled_rotation_z_y_f32x8 = _mm256_set1_ps(scale * r[7]);
|
|
1067
|
-
__m256 scaled_rotation_z_z_f32x8 = _mm256_set1_ps(scale * r[8]);
|
|
1068
|
-
|
|
1069
|
-
// Broadcast centroids
|
|
1070
|
-
__m256 centroid_a_x_f32x8 = _mm256_set1_ps(centroid_a_x);
|
|
1071
|
-
__m256 centroid_a_y_f32x8 = _mm256_set1_ps(centroid_a_y);
|
|
1072
|
-
__m256 centroid_a_z_f32x8 = _mm256_set1_ps(centroid_a_z);
|
|
1073
|
-
__m256 centroid_b_x_f32x8 = _mm256_set1_ps(centroid_b_x);
|
|
1074
|
-
__m256 centroid_b_y_f32x8 = _mm256_set1_ps(centroid_b_y);
|
|
1075
|
-
__m256 centroid_b_z_f32x8 = _mm256_set1_ps(centroid_b_z);
|
|
1076
|
-
|
|
1077
|
-
__m256 sum_squared_f32x8 = _mm256_setzero_ps();
|
|
1078
|
-
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
1079
|
-
nk_size_t j = 0;
|
|
1080
|
-
|
|
1081
|
-
for (; j + 8 <= n; j += 8) {
|
|
1082
|
-
nk_deinterleave_f16x8_to_f32x8_haswell_(a + j * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
|
|
1083
|
-
nk_deinterleave_f16x8_to_f32x8_haswell_(b + j * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
1084
|
-
|
|
1085
|
-
// Center points
|
|
1086
|
-
__m256 pa_x_f32x8 = _mm256_sub_ps(a_x_f32x8, centroid_a_x_f32x8);
|
|
1087
|
-
__m256 pa_y_f32x8 = _mm256_sub_ps(a_y_f32x8, centroid_a_y_f32x8);
|
|
1088
|
-
__m256 pa_z_f32x8 = _mm256_sub_ps(a_z_f32x8, centroid_a_z_f32x8);
|
|
1089
|
-
__m256 pb_x_f32x8 = _mm256_sub_ps(b_x_f32x8, centroid_b_x_f32x8);
|
|
1090
|
-
__m256 pb_y_f32x8 = _mm256_sub_ps(b_y_f32x8, centroid_b_y_f32x8);
|
|
1091
|
-
__m256 pb_z_f32x8 = _mm256_sub_ps(b_z_f32x8, centroid_b_z_f32x8);
|
|
1092
|
-
|
|
1093
|
-
// Rotate and scale: ra = scale * R * pa
|
|
1094
|
-
__m256 ra_x_f32x8 = _mm256_fmadd_ps(scaled_rotation_x_z_f32x8, pa_z_f32x8,
|
|
1095
|
-
_mm256_fmadd_ps(scaled_rotation_x_y_f32x8, pa_y_f32x8,
|
|
1096
|
-
_mm256_mul_ps(scaled_rotation_x_x_f32x8, pa_x_f32x8)));
|
|
1097
|
-
__m256 ra_y_f32x8 = _mm256_fmadd_ps(scaled_rotation_y_z_f32x8, pa_z_f32x8,
|
|
1098
|
-
_mm256_fmadd_ps(scaled_rotation_y_y_f32x8, pa_y_f32x8,
|
|
1099
|
-
_mm256_mul_ps(scaled_rotation_y_x_f32x8, pa_x_f32x8)));
|
|
1100
|
-
__m256 ra_z_f32x8 = _mm256_fmadd_ps(scaled_rotation_z_z_f32x8, pa_z_f32x8,
|
|
1101
|
-
_mm256_fmadd_ps(scaled_rotation_z_y_f32x8, pa_y_f32x8,
|
|
1102
|
-
_mm256_mul_ps(scaled_rotation_z_x_f32x8, pa_x_f32x8)));
|
|
1103
|
-
|
|
1104
|
-
// Delta and accumulate
|
|
1105
|
-
__m256 delta_x_f32x8 = _mm256_sub_ps(ra_x_f32x8, pb_x_f32x8);
|
|
1106
|
-
__m256 delta_y_f32x8 = _mm256_sub_ps(ra_y_f32x8, pb_y_f32x8);
|
|
1107
|
-
__m256 delta_z_f32x8 = _mm256_sub_ps(ra_z_f32x8, pb_z_f32x8);
|
|
1108
|
-
|
|
1109
|
-
sum_squared_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_f32x8);
|
|
1110
|
-
sum_squared_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_f32x8);
|
|
1111
|
-
sum_squared_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_f32x8);
|
|
1112
|
-
}
|
|
1113
|
-
|
|
1114
|
-
nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_f32x8);
|
|
1115
|
-
|
|
1116
|
-
// Scalar tail
|
|
1117
|
-
for (; j < n; ++j) {
|
|
1118
|
-
nk_f32_t a_x_f32, a_y_f32, a_z_f32, b_x_f32, b_y_f32, b_z_f32;
|
|
1119
|
-
nk_f16_to_f32_haswell(&a[j * 3 + 0], &a_x_f32);
|
|
1120
|
-
nk_f16_to_f32_haswell(&a[j * 3 + 1], &a_y_f32);
|
|
1121
|
-
nk_f16_to_f32_haswell(&a[j * 3 + 2], &a_z_f32);
|
|
1122
|
-
nk_f16_to_f32_haswell(&b[j * 3 + 0], &b_x_f32);
|
|
1123
|
-
nk_f16_to_f32_haswell(&b[j * 3 + 1], &b_y_f32);
|
|
1124
|
-
nk_f16_to_f32_haswell(&b[j * 3 + 2], &b_z_f32);
|
|
1125
|
-
|
|
1126
|
-
nk_f32_t pa_x = a_x_f32 - centroid_a_x, pa_y = a_y_f32 - centroid_a_y, pa_z = a_z_f32 - centroid_a_z;
|
|
1127
|
-
nk_f32_t pb_x = b_x_f32 - centroid_b_x, pb_y = b_y_f32 - centroid_b_y, pb_z = b_z_f32 - centroid_b_z;
|
|
1128
|
-
nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
|
|
1129
|
-
ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
|
|
1130
|
-
ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
1131
|
-
|
|
1132
|
-
nk_f32_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
|
|
1133
|
-
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
1134
|
-
}
|
|
1135
|
-
|
|
1136
|
-
return sum_squared;
|
|
1137
|
-
}
|
|
1138
|
-
|
|
1139
|
-
/* Compute sum of squared distances for bf16 data after applying rotation (and optional scale).
|
|
1140
|
-
* Loads bf16 data, converts to f32 during processing.
|
|
1141
|
-
* Note: rotation matrix r is f32 (from SVD), scale and computation done in f32.
|
|
1142
|
-
*/
|
|
1143
|
-
NK_INTERNAL nk_f32_t nk_transformed_ssd_bf16_haswell_(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n,
|
|
1144
|
-
nk_f32_t const *r, nk_f32_t scale, nk_f32_t centroid_a_x,
|
|
1145
|
-
nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
|
|
1146
|
-
nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
|
|
1147
|
-
nk_f32_t centroid_b_z) {
|
|
1148
|
-
// Broadcast scaled rotation matrix elements
|
|
1149
|
-
__m256 scaled_rotation_x_x_f32x8 = _mm256_set1_ps(scale * r[0]);
|
|
1150
|
-
__m256 scaled_rotation_x_y_f32x8 = _mm256_set1_ps(scale * r[1]);
|
|
1151
|
-
__m256 scaled_rotation_x_z_f32x8 = _mm256_set1_ps(scale * r[2]);
|
|
1152
|
-
__m256 scaled_rotation_y_x_f32x8 = _mm256_set1_ps(scale * r[3]);
|
|
1153
|
-
__m256 scaled_rotation_y_y_f32x8 = _mm256_set1_ps(scale * r[4]);
|
|
1154
|
-
__m256 scaled_rotation_y_z_f32x8 = _mm256_set1_ps(scale * r[5]);
|
|
1155
|
-
__m256 scaled_rotation_z_x_f32x8 = _mm256_set1_ps(scale * r[6]);
|
|
1156
|
-
__m256 scaled_rotation_z_y_f32x8 = _mm256_set1_ps(scale * r[7]);
|
|
1157
|
-
__m256 scaled_rotation_z_z_f32x8 = _mm256_set1_ps(scale * r[8]);
|
|
1158
|
-
|
|
1159
|
-
// Broadcast centroids
|
|
1160
|
-
__m256 centroid_a_x_f32x8 = _mm256_set1_ps(centroid_a_x);
|
|
1161
|
-
__m256 centroid_a_y_f32x8 = _mm256_set1_ps(centroid_a_y);
|
|
1162
|
-
__m256 centroid_a_z_f32x8 = _mm256_set1_ps(centroid_a_z);
|
|
1163
|
-
__m256 centroid_b_x_f32x8 = _mm256_set1_ps(centroid_b_x);
|
|
1164
|
-
__m256 centroid_b_y_f32x8 = _mm256_set1_ps(centroid_b_y);
|
|
1165
|
-
__m256 centroid_b_z_f32x8 = _mm256_set1_ps(centroid_b_z);
|
|
1166
|
-
|
|
1167
|
-
__m256 sum_squared_f32x8 = _mm256_setzero_ps();
|
|
1168
|
-
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
1169
|
-
nk_size_t j = 0;
|
|
1170
|
-
|
|
1171
|
-
for (; j + 8 <= n; j += 8) {
|
|
1172
|
-
nk_deinterleave_bf16x8_to_f32x8_haswell_(a + j * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
|
|
1173
|
-
nk_deinterleave_bf16x8_to_f32x8_haswell_(b + j * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
1174
|
-
|
|
1175
|
-
// Center points
|
|
1176
|
-
__m256 pa_x_f32x8 = _mm256_sub_ps(a_x_f32x8, centroid_a_x_f32x8);
|
|
1177
|
-
__m256 pa_y_f32x8 = _mm256_sub_ps(a_y_f32x8, centroid_a_y_f32x8);
|
|
1178
|
-
__m256 pa_z_f32x8 = _mm256_sub_ps(a_z_f32x8, centroid_a_z_f32x8);
|
|
1179
|
-
__m256 pb_x_f32x8 = _mm256_sub_ps(b_x_f32x8, centroid_b_x_f32x8);
|
|
1180
|
-
__m256 pb_y_f32x8 = _mm256_sub_ps(b_y_f32x8, centroid_b_y_f32x8);
|
|
1181
|
-
__m256 pb_z_f32x8 = _mm256_sub_ps(b_z_f32x8, centroid_b_z_f32x8);
|
|
1182
|
-
|
|
1183
|
-
// Rotate and scale: ra = scale * R * pa
|
|
1184
|
-
__m256 ra_x_f32x8 = _mm256_fmadd_ps(scaled_rotation_x_z_f32x8, pa_z_f32x8,
|
|
1185
|
-
_mm256_fmadd_ps(scaled_rotation_x_y_f32x8, pa_y_f32x8,
|
|
1186
|
-
_mm256_mul_ps(scaled_rotation_x_x_f32x8, pa_x_f32x8)));
|
|
1187
|
-
__m256 ra_y_f32x8 = _mm256_fmadd_ps(scaled_rotation_y_z_f32x8, pa_z_f32x8,
|
|
1188
|
-
_mm256_fmadd_ps(scaled_rotation_y_y_f32x8, pa_y_f32x8,
|
|
1189
|
-
_mm256_mul_ps(scaled_rotation_y_x_f32x8, pa_x_f32x8)));
|
|
1190
|
-
__m256 ra_z_f32x8 = _mm256_fmadd_ps(scaled_rotation_z_z_f32x8, pa_z_f32x8,
|
|
1191
|
-
_mm256_fmadd_ps(scaled_rotation_z_y_f32x8, pa_y_f32x8,
|
|
1192
|
-
_mm256_mul_ps(scaled_rotation_z_x_f32x8, pa_x_f32x8)));
|
|
1193
|
-
|
|
1194
|
-
// Delta and accumulate
|
|
1195
|
-
__m256 delta_x_f32x8 = _mm256_sub_ps(ra_x_f32x8, pb_x_f32x8);
|
|
1196
|
-
__m256 delta_y_f32x8 = _mm256_sub_ps(ra_y_f32x8, pb_y_f32x8);
|
|
1197
|
-
__m256 delta_z_f32x8 = _mm256_sub_ps(ra_z_f32x8, pb_z_f32x8);
|
|
1198
|
-
|
|
1199
|
-
sum_squared_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_f32x8);
|
|
1200
|
-
sum_squared_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_f32x8);
|
|
1201
|
-
sum_squared_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_f32x8);
|
|
1202
|
-
}
|
|
1203
|
-
|
|
1204
|
-
nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_f32x8);
|
|
1205
|
-
|
|
1206
|
-
// Scalar tail
|
|
1207
|
-
for (; j < n; ++j) {
|
|
1208
|
-
nk_f32_t a_x_f32, a_y_f32, a_z_f32, b_x_f32, b_y_f32, b_z_f32;
|
|
1209
|
-
nk_bf16_to_f32_serial(&a[j * 3 + 0], &a_x_f32);
|
|
1210
|
-
nk_bf16_to_f32_serial(&a[j * 3 + 1], &a_y_f32);
|
|
1211
|
-
nk_bf16_to_f32_serial(&a[j * 3 + 2], &a_z_f32);
|
|
1212
|
-
nk_bf16_to_f32_serial(&b[j * 3 + 0], &b_x_f32);
|
|
1213
|
-
nk_bf16_to_f32_serial(&b[j * 3 + 1], &b_y_f32);
|
|
1214
|
-
nk_bf16_to_f32_serial(&b[j * 3 + 2], &b_z_f32);
|
|
1215
|
-
|
|
1216
|
-
nk_f32_t pa_x = a_x_f32 - centroid_a_x, pa_y = a_y_f32 - centroid_a_y, pa_z = a_z_f32 - centroid_a_z;
|
|
1217
|
-
nk_f32_t pb_x = b_x_f32 - centroid_b_x, pb_y = b_y_f32 - centroid_b_y, pb_z = b_z_f32 - centroid_b_z;
|
|
1218
|
-
nk_f32_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
|
|
1219
|
-
ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
|
|
1220
|
-
ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
1221
|
-
|
|
1222
|
-
nk_f32_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
|
|
1223
|
-
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
1224
|
-
}
|
|
1225
|
-
|
|
1226
|
-
return sum_squared;
|
|
1227
|
-
}
|
|
1228
|
-
|
|
1229
1044
|
NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1230
1045
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1231
|
-
// RMSD uses identity rotation and scale=1.0
|
|
1232
1046
|
if (rotation)
|
|
1233
1047
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1234
1048
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1235
1049
|
if (scale) *scale = 1.0f;
|
|
1050
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
1051
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
1236
1052
|
|
|
1237
1053
|
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
1238
|
-
|
|
1239
|
-
// Accumulators for centroids and squared differences (all in f32)
|
|
1240
|
-
__m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
|
|
1241
|
-
__m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
|
|
1242
1054
|
__m256 sum_squared_x_f32x8 = zeros_f32x8, sum_squared_y_f32x8 = zeros_f32x8, sum_squared_z_f32x8 = zeros_f32x8;
|
|
1243
|
-
|
|
1244
1055
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
1245
1056
|
nk_size_t i = 0;
|
|
1246
1057
|
|
|
1247
|
-
// Main loop processing 8 points at a time
|
|
1248
1058
|
for (; i + 8 <= n; i += 8) {
|
|
1249
1059
|
nk_deinterleave_f16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
|
|
1250
1060
|
nk_deinterleave_f16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
1251
|
-
|
|
1252
|
-
sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
|
|
1253
|
-
sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
|
|
1254
|
-
sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
|
|
1255
|
-
sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
|
|
1256
|
-
sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
|
|
1257
|
-
sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
|
|
1258
|
-
|
|
1259
1061
|
__m256 delta_x_f32x8 = _mm256_sub_ps(a_x_f32x8, b_x_f32x8);
|
|
1260
1062
|
__m256 delta_y_f32x8 = _mm256_sub_ps(a_y_f32x8, b_y_f32x8);
|
|
1261
1063
|
__m256 delta_z_f32x8 = _mm256_sub_ps(a_z_f32x8, b_z_f32x8);
|
|
1262
|
-
|
|
1263
1064
|
sum_squared_x_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_x_f32x8);
|
|
1264
1065
|
sum_squared_y_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_y_f32x8);
|
|
1265
1066
|
sum_squared_z_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_z_f32x8);
|
|
1266
1067
|
}
|
|
1267
1068
|
|
|
1268
|
-
|
|
1269
|
-
|
|
1270
|
-
|
|
1271
|
-
nk_f32_t total_az = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
|
|
1272
|
-
nk_f32_t total_bx = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
|
|
1273
|
-
nk_f32_t total_by = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
|
|
1274
|
-
nk_f32_t total_bz = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
|
|
1275
|
-
nk_f32_t total_sq_x = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8);
|
|
1276
|
-
nk_f32_t total_sq_y = nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8);
|
|
1277
|
-
nk_f32_t total_sq_z = nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
|
|
1278
|
-
|
|
1279
|
-
// Scalar tail
|
|
1069
|
+
nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8) +
|
|
1070
|
+
nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8) +
|
|
1071
|
+
nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
|
|
1280
1072
|
for (; i < n; ++i) {
|
|
1281
1073
|
nk_f32_t ax, ay, az, bx, by, bz;
|
|
1282
1074
|
nk_f16_to_f32_haswell(&a[i * 3 + 0], &ax);
|
|
@@ -1285,91 +1077,41 @@ NK_PUBLIC void nk_rmsd_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size
|
|
|
1285
1077
|
nk_f16_to_f32_haswell(&b[i * 3 + 0], &bx);
|
|
1286
1078
|
nk_f16_to_f32_haswell(&b[i * 3 + 1], &by);
|
|
1287
1079
|
nk_f16_to_f32_haswell(&b[i * 3 + 2], &bz);
|
|
1288
|
-
total_ax += ax;
|
|
1289
|
-
total_ay += ay;
|
|
1290
|
-
total_az += az;
|
|
1291
|
-
total_bx += bx;
|
|
1292
|
-
total_by += by;
|
|
1293
|
-
total_bz += bz;
|
|
1294
1080
|
nk_f32_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
|
|
1295
|
-
|
|
1296
|
-
total_sq_y += delta_y * delta_y;
|
|
1297
|
-
total_sq_z += delta_z * delta_z;
|
|
1081
|
+
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
1298
1082
|
}
|
|
1299
1083
|
|
|
1300
|
-
|
|
1301
|
-
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
1302
|
-
nk_f32_t centroid_a_x = total_ax * inv_n;
|
|
1303
|
-
nk_f32_t centroid_a_y = total_ay * inv_n;
|
|
1304
|
-
nk_f32_t centroid_a_z = total_az * inv_n;
|
|
1305
|
-
nk_f32_t centroid_b_x = total_bx * inv_n;
|
|
1306
|
-
nk_f32_t centroid_b_y = total_by * inv_n;
|
|
1307
|
-
nk_f32_t centroid_b_z = total_bz * inv_n;
|
|
1308
|
-
|
|
1309
|
-
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1310
|
-
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1311
|
-
|
|
1312
|
-
// Compute RMSD
|
|
1313
|
-
nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
1314
|
-
nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
|
|
1315
|
-
nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
|
|
1316
|
-
nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
|
|
1317
|
-
nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
1318
|
-
|
|
1319
|
-
*result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
1084
|
+
*result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
|
|
1320
1085
|
}
|
|
1321
1086
|
|
|
1322
1087
|
NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1323
1088
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1324
|
-
// RMSD uses identity rotation and scale=1.0
|
|
1325
1089
|
if (rotation)
|
|
1326
1090
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1327
1091
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1328
1092
|
if (scale) *scale = 1.0f;
|
|
1093
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
1094
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
1329
1095
|
|
|
1330
1096
|
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
1331
|
-
|
|
1332
|
-
// Accumulators for centroids and squared differences (all in f32)
|
|
1333
|
-
__m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
|
|
1334
|
-
__m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
|
|
1335
1097
|
__m256 sum_squared_x_f32x8 = zeros_f32x8, sum_squared_y_f32x8 = zeros_f32x8, sum_squared_z_f32x8 = zeros_f32x8;
|
|
1336
|
-
|
|
1337
1098
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
1338
1099
|
nk_size_t i = 0;
|
|
1339
1100
|
|
|
1340
|
-
// Main loop processing 8 points at a time
|
|
1341
1101
|
for (; i + 8 <= n; i += 8) {
|
|
1342
1102
|
nk_deinterleave_bf16x8_to_f32x8_haswell_(a + i * 3, &a_x_f32x8, &a_y_f32x8, &a_z_f32x8);
|
|
1343
1103
|
nk_deinterleave_bf16x8_to_f32x8_haswell_(b + i * 3, &b_x_f32x8, &b_y_f32x8, &b_z_f32x8);
|
|
1344
|
-
|
|
1345
|
-
sum_a_x_f32x8 = _mm256_add_ps(sum_a_x_f32x8, a_x_f32x8);
|
|
1346
|
-
sum_a_y_f32x8 = _mm256_add_ps(sum_a_y_f32x8, a_y_f32x8);
|
|
1347
|
-
sum_a_z_f32x8 = _mm256_add_ps(sum_a_z_f32x8, a_z_f32x8);
|
|
1348
|
-
sum_b_x_f32x8 = _mm256_add_ps(sum_b_x_f32x8, b_x_f32x8);
|
|
1349
|
-
sum_b_y_f32x8 = _mm256_add_ps(sum_b_y_f32x8, b_y_f32x8);
|
|
1350
|
-
sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
|
|
1351
|
-
|
|
1352
1104
|
__m256 delta_x_f32x8 = _mm256_sub_ps(a_x_f32x8, b_x_f32x8);
|
|
1353
1105
|
__m256 delta_y_f32x8 = _mm256_sub_ps(a_y_f32x8, b_y_f32x8);
|
|
1354
1106
|
__m256 delta_z_f32x8 = _mm256_sub_ps(a_z_f32x8, b_z_f32x8);
|
|
1355
|
-
|
|
1356
1107
|
sum_squared_x_f32x8 = _mm256_fmadd_ps(delta_x_f32x8, delta_x_f32x8, sum_squared_x_f32x8);
|
|
1357
1108
|
sum_squared_y_f32x8 = _mm256_fmadd_ps(delta_y_f32x8, delta_y_f32x8, sum_squared_y_f32x8);
|
|
1358
1109
|
sum_squared_z_f32x8 = _mm256_fmadd_ps(delta_z_f32x8, delta_z_f32x8, sum_squared_z_f32x8);
|
|
1359
1110
|
}
|
|
1360
1111
|
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
nk_f32_t total_az = nk_reduce_add_f32x8_haswell_(sum_a_z_f32x8);
|
|
1365
|
-
nk_f32_t total_bx = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
|
|
1366
|
-
nk_f32_t total_by = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
|
|
1367
|
-
nk_f32_t total_bz = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
|
|
1368
|
-
nk_f32_t total_sq_x = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8);
|
|
1369
|
-
nk_f32_t total_sq_y = nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8);
|
|
1370
|
-
nk_f32_t total_sq_z = nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
|
|
1371
|
-
|
|
1372
|
-
// Scalar tail
|
|
1112
|
+
nk_f32_t sum_squared = nk_reduce_add_f32x8_haswell_(sum_squared_x_f32x8) +
|
|
1113
|
+
nk_reduce_add_f32x8_haswell_(sum_squared_y_f32x8) +
|
|
1114
|
+
nk_reduce_add_f32x8_haswell_(sum_squared_z_f32x8);
|
|
1373
1115
|
for (; i < n; ++i) {
|
|
1374
1116
|
nk_f32_t ax, ay, az, bx, by, bz;
|
|
1375
1117
|
nk_bf16_to_f32_serial(&a[i * 3 + 0], &ax);
|
|
@@ -1378,43 +1120,16 @@ NK_PUBLIC void nk_rmsd_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_s
|
|
|
1378
1120
|
nk_bf16_to_f32_serial(&b[i * 3 + 0], &bx);
|
|
1379
1121
|
nk_bf16_to_f32_serial(&b[i * 3 + 1], &by);
|
|
1380
1122
|
nk_bf16_to_f32_serial(&b[i * 3 + 2], &bz);
|
|
1381
|
-
total_ax += ax;
|
|
1382
|
-
total_ay += ay;
|
|
1383
|
-
total_az += az;
|
|
1384
|
-
total_bx += bx;
|
|
1385
|
-
total_by += by;
|
|
1386
|
-
total_bz += bz;
|
|
1387
1123
|
nk_f32_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
|
|
1388
|
-
|
|
1389
|
-
total_sq_y += delta_y * delta_y;
|
|
1390
|
-
total_sq_z += delta_z * delta_z;
|
|
1124
|
+
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
1391
1125
|
}
|
|
1392
1126
|
|
|
1393
|
-
|
|
1394
|
-
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
1395
|
-
nk_f32_t centroid_a_x = total_ax * inv_n;
|
|
1396
|
-
nk_f32_t centroid_a_y = total_ay * inv_n;
|
|
1397
|
-
nk_f32_t centroid_a_z = total_az * inv_n;
|
|
1398
|
-
nk_f32_t centroid_b_x = total_bx * inv_n;
|
|
1399
|
-
nk_f32_t centroid_b_y = total_by * inv_n;
|
|
1400
|
-
nk_f32_t centroid_b_z = total_bz * inv_n;
|
|
1401
|
-
|
|
1402
|
-
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1403
|
-
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1404
|
-
|
|
1405
|
-
// Compute RMSD
|
|
1406
|
-
nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
1407
|
-
nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
|
|
1408
|
-
nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
|
|
1409
|
-
nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
|
|
1410
|
-
nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
1411
|
-
|
|
1412
|
-
*result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
1127
|
+
*result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
|
|
1413
1128
|
}
|
|
1414
1129
|
|
|
1415
1130
|
NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1416
1131
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1417
|
-
// Fused single-pass: load f16, convert to f32, compute centroids and
|
|
1132
|
+
// Fused single-pass: load f16, convert to f32, compute centroids, covariance, and ‖a‖²/‖b‖²
|
|
1418
1133
|
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
1419
1134
|
|
|
1420
1135
|
// Accumulators for centroids (f32)
|
|
@@ -1422,9 +1137,10 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1422
1137
|
__m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
|
|
1423
1138
|
|
|
1424
1139
|
// Accumulators for covariance matrix (sum of outer products)
|
|
1425
|
-
__m256
|
|
1426
|
-
__m256
|
|
1427
|
-
__m256
|
|
1140
|
+
__m256 covariance_xx_f32x8 = zeros_f32x8, covariance_xy_f32x8 = zeros_f32x8, covariance_xz_f32x8 = zeros_f32x8;
|
|
1141
|
+
__m256 covariance_yx_f32x8 = zeros_f32x8, covariance_yy_f32x8 = zeros_f32x8, covariance_yz_f32x8 = zeros_f32x8;
|
|
1142
|
+
__m256 covariance_zx_f32x8 = zeros_f32x8, covariance_zy_f32x8 = zeros_f32x8, covariance_zz_f32x8 = zeros_f32x8;
|
|
1143
|
+
__m256 norm_squared_a_f32x8 = zeros_f32x8, norm_squared_b_f32x8 = zeros_f32x8;
|
|
1428
1144
|
|
|
1429
1145
|
nk_size_t i = 0;
|
|
1430
1146
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
@@ -1442,15 +1158,23 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1442
1158
|
sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
|
|
1443
1159
|
|
|
1444
1160
|
// Accumulate outer products
|
|
1445
|
-
|
|
1446
|
-
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1161
|
+
covariance_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, covariance_xx_f32x8);
|
|
1162
|
+
covariance_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, covariance_xy_f32x8);
|
|
1163
|
+
covariance_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, covariance_xz_f32x8);
|
|
1164
|
+
covariance_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, covariance_yx_f32x8);
|
|
1165
|
+
covariance_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, covariance_yy_f32x8);
|
|
1166
|
+
covariance_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, covariance_yz_f32x8);
|
|
1167
|
+
covariance_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, covariance_zx_f32x8);
|
|
1168
|
+
covariance_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, covariance_zy_f32x8);
|
|
1169
|
+
covariance_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, covariance_zz_f32x8);
|
|
1170
|
+
|
|
1171
|
+
// Accumulate ‖a‖² and ‖b‖² for folded SSD
|
|
1172
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, norm_squared_a_f32x8);
|
|
1173
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, norm_squared_a_f32x8);
|
|
1174
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, norm_squared_a_f32x8);
|
|
1175
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_x_f32x8, b_x_f32x8, norm_squared_b_f32x8);
|
|
1176
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_y_f32x8, b_y_f32x8, norm_squared_b_f32x8);
|
|
1177
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_z_f32x8, b_z_f32x8, norm_squared_b_f32x8);
|
|
1454
1178
|
}
|
|
1455
1179
|
|
|
1456
1180
|
// Reduce vector accumulators
|
|
@@ -1461,15 +1185,17 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1461
1185
|
nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
|
|
1462
1186
|
nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
|
|
1463
1187
|
|
|
1464
|
-
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(
|
|
1465
|
-
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(
|
|
1466
|
-
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(
|
|
1467
|
-
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(
|
|
1468
|
-
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(
|
|
1469
|
-
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(
|
|
1470
|
-
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(
|
|
1471
|
-
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(
|
|
1472
|
-
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(
|
|
1188
|
+
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(covariance_xx_f32x8);
|
|
1189
|
+
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(covariance_xy_f32x8);
|
|
1190
|
+
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(covariance_xz_f32x8);
|
|
1191
|
+
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(covariance_yx_f32x8);
|
|
1192
|
+
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(covariance_yy_f32x8);
|
|
1193
|
+
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(covariance_yz_f32x8);
|
|
1194
|
+
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(covariance_zx_f32x8);
|
|
1195
|
+
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(covariance_zy_f32x8);
|
|
1196
|
+
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(covariance_zz_f32x8);
|
|
1197
|
+
nk_f32_t norm_squared_a_sum = nk_reduce_add_f32x8_haswell_(norm_squared_a_f32x8);
|
|
1198
|
+
nk_f32_t norm_squared_b_sum = nk_reduce_add_f32x8_haswell_(norm_squared_b_f32x8);
|
|
1473
1199
|
|
|
1474
1200
|
// Scalar tail
|
|
1475
1201
|
for (; i < n; ++i) {
|
|
@@ -1485,6 +1211,8 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1485
1211
|
covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
|
|
1486
1212
|
covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
|
|
1487
1213
|
covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
|
|
1214
|
+
norm_squared_a_sum += ax * ax + ay * ay + az * az;
|
|
1215
|
+
norm_squared_b_sum += bx * bx + by * by + bz * bz;
|
|
1488
1216
|
}
|
|
1489
1217
|
|
|
1490
1218
|
// Compute centroids
|
|
@@ -1510,52 +1238,84 @@ NK_PUBLIC void nk_kabsch_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_si
|
|
|
1510
1238
|
covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1511
1239
|
covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
1512
1240
|
|
|
1513
|
-
// Compute SVD and optimal rotation
|
|
1514
1241
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
1515
1242
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
|
|
1527
|
-
|
|
1528
|
-
|
|
1529
|
-
|
|
1530
|
-
|
|
1531
|
-
|
|
1532
|
-
|
|
1533
|
-
|
|
1534
|
-
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1243
|
+
|
|
1244
|
+
// Centered ‖a-ā‖², ‖b-b̄‖² via the parallel-axis identity
|
|
1245
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a_sum -
|
|
1246
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1247
|
+
centroid_a_z * centroid_a_z);
|
|
1248
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b_sum -
|
|
1249
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1250
|
+
centroid_b_z * centroid_b_z);
|
|
1251
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
1252
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
1253
|
+
|
|
1254
|
+
// Identity-dominant short-circuit: R = I, trace(R · H) = H[0]+H[4]+H[8]. Skips SVD + two
|
|
1255
|
+
// rotation reconstructions when the inputs are already aligned.
|
|
1256
|
+
nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
1257
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
1258
|
+
cross_covariance[8] * cross_covariance[8];
|
|
1259
|
+
nk_f32_t covariance_offdiagonal_norm_squared =
|
|
1260
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
1261
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
1262
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
1263
|
+
nk_f32_t optimal_rotation[9];
|
|
1264
|
+
nk_f32_t trace_rotation_covariance;
|
|
1265
|
+
if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
|
|
1266
|
+
cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
|
|
1267
|
+
optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
|
|
1268
|
+
optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
|
|
1269
|
+
optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
|
|
1270
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
1271
|
+
}
|
|
1272
|
+
else {
|
|
1273
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1274
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1275
|
+
// R = V * Uᵀ
|
|
1276
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
1277
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
1278
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
1279
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
1280
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
1281
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
1282
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
1283
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
1284
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
1285
|
+
if (nk_det3x3_f32_(optimal_rotation) < 0) {
|
|
1286
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1287
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
1288
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
1289
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
1290
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
1291
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
1292
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
1293
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
1294
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
1295
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
1296
|
+
}
|
|
1297
|
+
trace_rotation_covariance =
|
|
1298
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1299
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1300
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1301
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1302
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1543
1303
|
}
|
|
1544
1304
|
|
|
1545
1305
|
// Output rotation matrix and scale=1.0
|
|
1546
1306
|
if (rotation)
|
|
1547
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
1307
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
1548
1308
|
if (scale) *scale = 1.0f;
|
|
1549
1309
|
|
|
1550
|
-
//
|
|
1551
|
-
nk_f32_t sum_squared =
|
|
1552
|
-
|
|
1310
|
+
// Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
|
|
1311
|
+
nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
|
|
1312
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
1553
1313
|
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
1554
1314
|
}
|
|
1555
1315
|
|
|
1556
1316
|
NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1557
1317
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1558
|
-
// Fused single-pass: load bf16, convert to f32, compute centroids and
|
|
1318
|
+
// Fused single-pass: load bf16, convert to f32, compute centroids, covariance, and ‖a‖²/‖b‖²
|
|
1559
1319
|
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
1560
1320
|
|
|
1561
1321
|
// Accumulators for centroids (f32)
|
|
@@ -1563,9 +1323,10 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1563
1323
|
__m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
|
|
1564
1324
|
|
|
1565
1325
|
// Accumulators for covariance matrix (sum of outer products)
|
|
1566
|
-
__m256
|
|
1567
|
-
__m256
|
|
1568
|
-
__m256
|
|
1326
|
+
__m256 covariance_xx_f32x8 = zeros_f32x8, covariance_xy_f32x8 = zeros_f32x8, covariance_xz_f32x8 = zeros_f32x8;
|
|
1327
|
+
__m256 covariance_yx_f32x8 = zeros_f32x8, covariance_yy_f32x8 = zeros_f32x8, covariance_yz_f32x8 = zeros_f32x8;
|
|
1328
|
+
__m256 covariance_zx_f32x8 = zeros_f32x8, covariance_zy_f32x8 = zeros_f32x8, covariance_zz_f32x8 = zeros_f32x8;
|
|
1329
|
+
__m256 norm_squared_a_f32x8 = zeros_f32x8, norm_squared_b_f32x8 = zeros_f32x8;
|
|
1569
1330
|
|
|
1570
1331
|
nk_size_t i = 0;
|
|
1571
1332
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
@@ -1583,15 +1344,23 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1583
1344
|
sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
|
|
1584
1345
|
|
|
1585
1346
|
// Accumulate outer products
|
|
1586
|
-
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
|
|
1591
|
-
|
|
1592
|
-
|
|
1593
|
-
|
|
1594
|
-
|
|
1347
|
+
covariance_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, covariance_xx_f32x8);
|
|
1348
|
+
covariance_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, covariance_xy_f32x8);
|
|
1349
|
+
covariance_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, covariance_xz_f32x8);
|
|
1350
|
+
covariance_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, covariance_yx_f32x8);
|
|
1351
|
+
covariance_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, covariance_yy_f32x8);
|
|
1352
|
+
covariance_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, covariance_yz_f32x8);
|
|
1353
|
+
covariance_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, covariance_zx_f32x8);
|
|
1354
|
+
covariance_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, covariance_zy_f32x8);
|
|
1355
|
+
covariance_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, covariance_zz_f32x8);
|
|
1356
|
+
|
|
1357
|
+
// Accumulate ‖a‖² and ‖b‖² for folded SSD
|
|
1358
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, norm_squared_a_f32x8);
|
|
1359
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, norm_squared_a_f32x8);
|
|
1360
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, norm_squared_a_f32x8);
|
|
1361
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_x_f32x8, b_x_f32x8, norm_squared_b_f32x8);
|
|
1362
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_y_f32x8, b_y_f32x8, norm_squared_b_f32x8);
|
|
1363
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_z_f32x8, b_z_f32x8, norm_squared_b_f32x8);
|
|
1595
1364
|
}
|
|
1596
1365
|
|
|
1597
1366
|
// Reduce vector accumulators
|
|
@@ -1602,15 +1371,17 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1602
1371
|
nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
|
|
1603
1372
|
nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
|
|
1604
1373
|
|
|
1605
|
-
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(
|
|
1606
|
-
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(
|
|
1607
|
-
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(
|
|
1608
|
-
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(
|
|
1609
|
-
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(
|
|
1610
|
-
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(
|
|
1611
|
-
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(
|
|
1612
|
-
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(
|
|
1613
|
-
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(
|
|
1374
|
+
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(covariance_xx_f32x8);
|
|
1375
|
+
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(covariance_xy_f32x8);
|
|
1376
|
+
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(covariance_xz_f32x8);
|
|
1377
|
+
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(covariance_yx_f32x8);
|
|
1378
|
+
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(covariance_yy_f32x8);
|
|
1379
|
+
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(covariance_yz_f32x8);
|
|
1380
|
+
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(covariance_zx_f32x8);
|
|
1381
|
+
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(covariance_zy_f32x8);
|
|
1382
|
+
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(covariance_zz_f32x8);
|
|
1383
|
+
nk_f32_t norm_squared_a_sum = nk_reduce_add_f32x8_haswell_(norm_squared_a_f32x8);
|
|
1384
|
+
nk_f32_t norm_squared_b_sum = nk_reduce_add_f32x8_haswell_(norm_squared_b_f32x8);
|
|
1614
1385
|
|
|
1615
1386
|
// Scalar tail
|
|
1616
1387
|
for (; i < n; ++i) {
|
|
@@ -1626,6 +1397,8 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1626
1397
|
covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
|
|
1627
1398
|
covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
|
|
1628
1399
|
covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
|
|
1400
|
+
norm_squared_a_sum += ax * ax + ay * ay + az * az;
|
|
1401
|
+
norm_squared_b_sum += bx * bx + by * by + bz * bz;
|
|
1629
1402
|
}
|
|
1630
1403
|
|
|
1631
1404
|
// Compute centroids
|
|
@@ -1651,60 +1424,92 @@ NK_PUBLIC void nk_kabsch_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk
|
|
|
1651
1424
|
covariance_z_y -= (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1652
1425
|
covariance_z_z -= (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
1653
1426
|
|
|
1654
|
-
// Compute SVD and optimal rotation
|
|
1655
1427
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
1656
1428
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
1657
|
-
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
|
|
1662
|
-
|
|
1663
|
-
|
|
1664
|
-
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
|
|
1668
|
-
|
|
1669
|
-
|
|
1670
|
-
|
|
1671
|
-
|
|
1672
|
-
|
|
1673
|
-
|
|
1674
|
-
|
|
1675
|
-
|
|
1676
|
-
|
|
1677
|
-
|
|
1678
|
-
|
|
1679
|
-
|
|
1680
|
-
|
|
1681
|
-
|
|
1682
|
-
|
|
1683
|
-
|
|
1429
|
+
|
|
1430
|
+
// Centered ‖a-ā‖², ‖b-b̄‖² via the parallel-axis identity
|
|
1431
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a_sum -
|
|
1432
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1433
|
+
centroid_a_z * centroid_a_z);
|
|
1434
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b_sum -
|
|
1435
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1436
|
+
centroid_b_z * centroid_b_z);
|
|
1437
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
1438
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
1439
|
+
|
|
1440
|
+
// Identity-dominant short-circuit: R = I, trace(R · H) = H[0]+H[4]+H[8]. Skips SVD + two
|
|
1441
|
+
// rotation reconstructions when the inputs are already aligned.
|
|
1442
|
+
nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
1443
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
1444
|
+
cross_covariance[8] * cross_covariance[8];
|
|
1445
|
+
nk_f32_t covariance_offdiagonal_norm_squared =
|
|
1446
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
1447
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
1448
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
1449
|
+
nk_f32_t optimal_rotation[9];
|
|
1450
|
+
nk_f32_t trace_rotation_covariance;
|
|
1451
|
+
if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
|
|
1452
|
+
cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
|
|
1453
|
+
optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
|
|
1454
|
+
optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
|
|
1455
|
+
optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
|
|
1456
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
1457
|
+
}
|
|
1458
|
+
else {
|
|
1459
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1460
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1461
|
+
// R = V * Uᵀ
|
|
1462
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
1463
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
1464
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
1465
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
1466
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
1467
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
1468
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
1469
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
1470
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
1471
|
+
if (nk_det3x3_f32_(optimal_rotation) < 0) {
|
|
1472
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1473
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
1474
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
1475
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
1476
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
1477
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
1478
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
1479
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
1480
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
1481
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
1482
|
+
}
|
|
1483
|
+
trace_rotation_covariance =
|
|
1484
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1485
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1486
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1487
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1488
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1684
1489
|
}
|
|
1685
1490
|
|
|
1686
1491
|
// Output rotation matrix and scale=1.0
|
|
1687
1492
|
if (rotation)
|
|
1688
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
1493
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
1689
1494
|
if (scale) *scale = 1.0f;
|
|
1690
1495
|
|
|
1691
|
-
//
|
|
1692
|
-
nk_f32_t sum_squared =
|
|
1693
|
-
|
|
1496
|
+
// Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
|
|
1497
|
+
nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
|
|
1498
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
1694
1499
|
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
1695
1500
|
}
|
|
1696
1501
|
|
|
1697
1502
|
NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1698
1503
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1699
|
-
// Fused single-pass: load f16, convert to f32, compute centroids, covariance, and
|
|
1504
|
+
// Fused single-pass: load f16, convert to f32, compute centroids, covariance, and ‖a‖²/‖b‖²
|
|
1700
1505
|
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
1701
1506
|
|
|
1702
1507
|
__m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
|
|
1703
1508
|
__m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
|
|
1704
|
-
__m256
|
|
1705
|
-
__m256
|
|
1706
|
-
__m256
|
|
1707
|
-
__m256
|
|
1509
|
+
__m256 covariance_xx_f32x8 = zeros_f32x8, covariance_xy_f32x8 = zeros_f32x8, covariance_xz_f32x8 = zeros_f32x8;
|
|
1510
|
+
__m256 covariance_yx_f32x8 = zeros_f32x8, covariance_yy_f32x8 = zeros_f32x8, covariance_yz_f32x8 = zeros_f32x8;
|
|
1511
|
+
__m256 covariance_zx_f32x8 = zeros_f32x8, covariance_zy_f32x8 = zeros_f32x8, covariance_zz_f32x8 = zeros_f32x8;
|
|
1512
|
+
__m256 norm_squared_a_f32x8 = zeros_f32x8, norm_squared_b_f32x8 = zeros_f32x8;
|
|
1708
1513
|
|
|
1709
1514
|
nk_size_t i = 0;
|
|
1710
1515
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
@@ -1722,20 +1527,23 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
1722
1527
|
sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
|
|
1723
1528
|
|
|
1724
1529
|
// Accumulate outer products
|
|
1725
|
-
|
|
1726
|
-
|
|
1727
|
-
|
|
1728
|
-
|
|
1729
|
-
|
|
1730
|
-
|
|
1731
|
-
|
|
1732
|
-
|
|
1733
|
-
|
|
1734
|
-
|
|
1735
|
-
// Accumulate
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
|
|
1530
|
+
covariance_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, covariance_xx_f32x8);
|
|
1531
|
+
covariance_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, covariance_xy_f32x8);
|
|
1532
|
+
covariance_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, covariance_xz_f32x8);
|
|
1533
|
+
covariance_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, covariance_yx_f32x8);
|
|
1534
|
+
covariance_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, covariance_yy_f32x8);
|
|
1535
|
+
covariance_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, covariance_yz_f32x8);
|
|
1536
|
+
covariance_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, covariance_zx_f32x8);
|
|
1537
|
+
covariance_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, covariance_zy_f32x8);
|
|
1538
|
+
covariance_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, covariance_zz_f32x8);
|
|
1539
|
+
|
|
1540
|
+
// Accumulate ‖a‖² and ‖b‖² for folded SSD
|
|
1541
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, norm_squared_a_f32x8);
|
|
1542
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, norm_squared_a_f32x8);
|
|
1543
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, norm_squared_a_f32x8);
|
|
1544
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_x_f32x8, b_x_f32x8, norm_squared_b_f32x8);
|
|
1545
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_y_f32x8, b_y_f32x8, norm_squared_b_f32x8);
|
|
1546
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_z_f32x8, b_z_f32x8, norm_squared_b_f32x8);
|
|
1739
1547
|
}
|
|
1740
1548
|
|
|
1741
1549
|
// Reduce vector accumulators
|
|
@@ -1745,16 +1553,17 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
1745
1553
|
nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
|
|
1746
1554
|
nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
|
|
1747
1555
|
nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
|
|
1748
|
-
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(
|
|
1749
|
-
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(
|
|
1750
|
-
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(
|
|
1751
|
-
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(
|
|
1752
|
-
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(
|
|
1753
|
-
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(
|
|
1754
|
-
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(
|
|
1755
|
-
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(
|
|
1756
|
-
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(
|
|
1757
|
-
nk_f32_t
|
|
1556
|
+
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(covariance_xx_f32x8);
|
|
1557
|
+
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(covariance_xy_f32x8);
|
|
1558
|
+
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(covariance_xz_f32x8);
|
|
1559
|
+
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(covariance_yx_f32x8);
|
|
1560
|
+
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(covariance_yy_f32x8);
|
|
1561
|
+
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(covariance_yz_f32x8);
|
|
1562
|
+
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(covariance_zx_f32x8);
|
|
1563
|
+
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(covariance_zy_f32x8);
|
|
1564
|
+
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(covariance_zz_f32x8);
|
|
1565
|
+
nk_f32_t norm_squared_a_sum = nk_reduce_add_f32x8_haswell_(norm_squared_a_f32x8);
|
|
1566
|
+
nk_f32_t norm_squared_b_sum = nk_reduce_add_f32x8_haswell_(norm_squared_b_f32x8);
|
|
1758
1567
|
|
|
1759
1568
|
// Scalar tail
|
|
1760
1569
|
for (; i < n; ++i) {
|
|
@@ -1770,7 +1579,8 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
1770
1579
|
covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
|
|
1771
1580
|
covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
|
|
1772
1581
|
covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
|
|
1773
|
-
|
|
1582
|
+
norm_squared_a_sum += ax * ax + ay * ay + az * az;
|
|
1583
|
+
norm_squared_b_sum += bx * bx + by * by + bz * bz;
|
|
1774
1584
|
}
|
|
1775
1585
|
|
|
1776
1586
|
// Compute centroids
|
|
@@ -1781,10 +1591,6 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
1781
1591
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1782
1592
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1783
1593
|
|
|
1784
|
-
// Compute centered covariance and variance
|
|
1785
|
-
nk_f32_t variance_a = variance_a_sum * inv_n -
|
|
1786
|
-
(centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
|
|
1787
|
-
|
|
1788
1594
|
// Apply centering correction to covariance matrix
|
|
1789
1595
|
covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1790
1596
|
covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
@@ -1799,64 +1605,97 @@ NK_PUBLIC void nk_umeyama_f16_haswell(nk_f16_t const *a, nk_f16_t const *b, nk_s
|
|
|
1799
1605
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
1800
1606
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
1801
1607
|
|
|
1802
|
-
//
|
|
1803
|
-
nk_f32_t
|
|
1804
|
-
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
|
|
1810
|
-
|
|
1811
|
-
|
|
1812
|
-
|
|
1813
|
-
|
|
1814
|
-
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
|
|
1820
|
-
|
|
1821
|
-
nk_f32_t
|
|
1822
|
-
nk_f32_t
|
|
1823
|
-
|
|
1824
|
-
|
|
1825
|
-
|
|
1826
|
-
|
|
1827
|
-
|
|
1828
|
-
|
|
1829
|
-
|
|
1830
|
-
|
|
1831
|
-
|
|
1832
|
-
|
|
1833
|
-
|
|
1834
|
-
|
|
1835
|
-
|
|
1836
|
-
|
|
1608
|
+
// Centered ‖a-ā‖², ‖b-b̄‖² via the parallel-axis identity
|
|
1609
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a_sum -
|
|
1610
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1611
|
+
centroid_a_z * centroid_a_z);
|
|
1612
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b_sum -
|
|
1613
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1614
|
+
centroid_b_z * centroid_b_z);
|
|
1615
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
1616
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
1617
|
+
|
|
1618
|
+
// Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals,
|
|
1619
|
+
// R = I and trace(DS) = trace(H).
|
|
1620
|
+
nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
1621
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
1622
|
+
cross_covariance[8] * cross_covariance[8];
|
|
1623
|
+
nk_f32_t covariance_offdiagonal_norm_squared =
|
|
1624
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
1625
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
1626
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
1627
|
+
nk_f32_t optimal_rotation[9];
|
|
1628
|
+
nk_f32_t applied_scale;
|
|
1629
|
+
nk_f32_t trace_rotation_covariance;
|
|
1630
|
+
if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
|
|
1631
|
+
cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
|
|
1632
|
+
optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
|
|
1633
|
+
optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
|
|
1634
|
+
optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
|
|
1635
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
1636
|
+
applied_scale = centered_norm_squared_a > 0.0f ? trace_rotation_covariance / centered_norm_squared_a : 0.0f;
|
|
1637
|
+
}
|
|
1638
|
+
else {
|
|
1639
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1640
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1641
|
+
// R = V * Uᵀ
|
|
1642
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
1643
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
1644
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
1645
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
1646
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
1647
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
1648
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
1649
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
1650
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
1651
|
+
|
|
1652
|
+
nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
|
|
1653
|
+
nk_f32_t sign_correction = det < 0 ? -1.0f : 1.0f;
|
|
1654
|
+
if (det < 0) {
|
|
1655
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1656
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
1657
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
1658
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
1659
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
1660
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
1661
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
1662
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
1663
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
1664
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
1665
|
+
}
|
|
1666
|
+
nk_f32_t trace_ds = svd_diagonal[0] + svd_diagonal[4] + sign_correction * svd_diagonal[8];
|
|
1667
|
+
applied_scale = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
|
|
1668
|
+
trace_rotation_covariance =
|
|
1669
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1670
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1671
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1672
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1673
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1837
1674
|
}
|
|
1838
1675
|
|
|
1839
|
-
// Output rotation matrix
|
|
1676
|
+
// Output rotation matrix and scale
|
|
1840
1677
|
if (rotation)
|
|
1841
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
1678
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
1679
|
+
if (scale) *scale = applied_scale;
|
|
1842
1680
|
|
|
1843
|
-
//
|
|
1844
|
-
nk_f32_t sum_squared =
|
|
1845
|
-
|
|
1681
|
+
// Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
|
|
1682
|
+
nk_f32_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
|
|
1683
|
+
2.0f * applied_scale * trace_rotation_covariance;
|
|
1684
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
1846
1685
|
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
1847
1686
|
}
|
|
1848
1687
|
|
|
1849
1688
|
NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1850
1689
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1851
|
-
// Fused single-pass: load bf16, convert to f32, compute centroids, covariance, and
|
|
1690
|
+
// Fused single-pass: load bf16, convert to f32, compute centroids, covariance, and ‖a‖²/‖b‖²
|
|
1852
1691
|
__m256 const zeros_f32x8 = _mm256_setzero_ps();
|
|
1853
1692
|
|
|
1854
1693
|
__m256 sum_a_x_f32x8 = zeros_f32x8, sum_a_y_f32x8 = zeros_f32x8, sum_a_z_f32x8 = zeros_f32x8;
|
|
1855
1694
|
__m256 sum_b_x_f32x8 = zeros_f32x8, sum_b_y_f32x8 = zeros_f32x8, sum_b_z_f32x8 = zeros_f32x8;
|
|
1856
|
-
__m256
|
|
1857
|
-
__m256
|
|
1858
|
-
__m256
|
|
1859
|
-
__m256
|
|
1695
|
+
__m256 covariance_xx_f32x8 = zeros_f32x8, covariance_xy_f32x8 = zeros_f32x8, covariance_xz_f32x8 = zeros_f32x8;
|
|
1696
|
+
__m256 covariance_yx_f32x8 = zeros_f32x8, covariance_yy_f32x8 = zeros_f32x8, covariance_yz_f32x8 = zeros_f32x8;
|
|
1697
|
+
__m256 covariance_zx_f32x8 = zeros_f32x8, covariance_zy_f32x8 = zeros_f32x8, covariance_zz_f32x8 = zeros_f32x8;
|
|
1698
|
+
__m256 norm_squared_a_f32x8 = zeros_f32x8, norm_squared_b_f32x8 = zeros_f32x8;
|
|
1860
1699
|
|
|
1861
1700
|
nk_size_t i = 0;
|
|
1862
1701
|
__m256 a_x_f32x8, a_y_f32x8, a_z_f32x8, b_x_f32x8, b_y_f32x8, b_z_f32x8;
|
|
@@ -1874,20 +1713,23 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
1874
1713
|
sum_b_z_f32x8 = _mm256_add_ps(sum_b_z_f32x8, b_z_f32x8);
|
|
1875
1714
|
|
|
1876
1715
|
// Accumulate outer products
|
|
1877
|
-
|
|
1878
|
-
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
|
|
1882
|
-
|
|
1883
|
-
|
|
1884
|
-
|
|
1885
|
-
|
|
1886
|
-
|
|
1887
|
-
// Accumulate
|
|
1888
|
-
|
|
1889
|
-
|
|
1890
|
-
|
|
1716
|
+
covariance_xx_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_x_f32x8, covariance_xx_f32x8);
|
|
1717
|
+
covariance_xy_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_y_f32x8, covariance_xy_f32x8);
|
|
1718
|
+
covariance_xz_f32x8 = _mm256_fmadd_ps(a_x_f32x8, b_z_f32x8, covariance_xz_f32x8);
|
|
1719
|
+
covariance_yx_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_x_f32x8, covariance_yx_f32x8);
|
|
1720
|
+
covariance_yy_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_y_f32x8, covariance_yy_f32x8);
|
|
1721
|
+
covariance_yz_f32x8 = _mm256_fmadd_ps(a_y_f32x8, b_z_f32x8, covariance_yz_f32x8);
|
|
1722
|
+
covariance_zx_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_x_f32x8, covariance_zx_f32x8);
|
|
1723
|
+
covariance_zy_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_y_f32x8, covariance_zy_f32x8);
|
|
1724
|
+
covariance_zz_f32x8 = _mm256_fmadd_ps(a_z_f32x8, b_z_f32x8, covariance_zz_f32x8);
|
|
1725
|
+
|
|
1726
|
+
// Accumulate ‖a‖² and ‖b‖² for folded SSD
|
|
1727
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_x_f32x8, a_x_f32x8, norm_squared_a_f32x8);
|
|
1728
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_y_f32x8, a_y_f32x8, norm_squared_a_f32x8);
|
|
1729
|
+
norm_squared_a_f32x8 = _mm256_fmadd_ps(a_z_f32x8, a_z_f32x8, norm_squared_a_f32x8);
|
|
1730
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_x_f32x8, b_x_f32x8, norm_squared_b_f32x8);
|
|
1731
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_y_f32x8, b_y_f32x8, norm_squared_b_f32x8);
|
|
1732
|
+
norm_squared_b_f32x8 = _mm256_fmadd_ps(b_z_f32x8, b_z_f32x8, norm_squared_b_f32x8);
|
|
1891
1733
|
}
|
|
1892
1734
|
|
|
1893
1735
|
// Reduce vector accumulators
|
|
@@ -1897,16 +1739,17 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
1897
1739
|
nk_f32_t sum_b_x = nk_reduce_add_f32x8_haswell_(sum_b_x_f32x8);
|
|
1898
1740
|
nk_f32_t sum_b_y = nk_reduce_add_f32x8_haswell_(sum_b_y_f32x8);
|
|
1899
1741
|
nk_f32_t sum_b_z = nk_reduce_add_f32x8_haswell_(sum_b_z_f32x8);
|
|
1900
|
-
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(
|
|
1901
|
-
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(
|
|
1902
|
-
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(
|
|
1903
|
-
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(
|
|
1904
|
-
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(
|
|
1905
|
-
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(
|
|
1906
|
-
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(
|
|
1907
|
-
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(
|
|
1908
|
-
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(
|
|
1909
|
-
nk_f32_t
|
|
1742
|
+
nk_f32_t covariance_x_x = nk_reduce_add_f32x8_haswell_(covariance_xx_f32x8);
|
|
1743
|
+
nk_f32_t covariance_x_y = nk_reduce_add_f32x8_haswell_(covariance_xy_f32x8);
|
|
1744
|
+
nk_f32_t covariance_x_z = nk_reduce_add_f32x8_haswell_(covariance_xz_f32x8);
|
|
1745
|
+
nk_f32_t covariance_y_x = nk_reduce_add_f32x8_haswell_(covariance_yx_f32x8);
|
|
1746
|
+
nk_f32_t covariance_y_y = nk_reduce_add_f32x8_haswell_(covariance_yy_f32x8);
|
|
1747
|
+
nk_f32_t covariance_y_z = nk_reduce_add_f32x8_haswell_(covariance_yz_f32x8);
|
|
1748
|
+
nk_f32_t covariance_z_x = nk_reduce_add_f32x8_haswell_(covariance_zx_f32x8);
|
|
1749
|
+
nk_f32_t covariance_z_y = nk_reduce_add_f32x8_haswell_(covariance_zy_f32x8);
|
|
1750
|
+
nk_f32_t covariance_z_z = nk_reduce_add_f32x8_haswell_(covariance_zz_f32x8);
|
|
1751
|
+
nk_f32_t norm_squared_a_sum = nk_reduce_add_f32x8_haswell_(norm_squared_a_f32x8);
|
|
1752
|
+
nk_f32_t norm_squared_b_sum = nk_reduce_add_f32x8_haswell_(norm_squared_b_f32x8);
|
|
1910
1753
|
|
|
1911
1754
|
// Scalar tail
|
|
1912
1755
|
for (; i < n; ++i) {
|
|
@@ -1922,7 +1765,8 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
1922
1765
|
covariance_x_x += ax * bx, covariance_x_y += ax * by, covariance_x_z += ax * bz;
|
|
1923
1766
|
covariance_y_x += ay * bx, covariance_y_y += ay * by, covariance_y_z += ay * bz;
|
|
1924
1767
|
covariance_z_x += az * bx, covariance_z_y += az * by, covariance_z_z += az * bz;
|
|
1925
|
-
|
|
1768
|
+
norm_squared_a_sum += ax * ax + ay * ay + az * az;
|
|
1769
|
+
norm_squared_b_sum += bx * bx + by * by + bz * bz;
|
|
1926
1770
|
}
|
|
1927
1771
|
|
|
1928
1772
|
// Compute centroids
|
|
@@ -1933,10 +1777,6 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
1933
1777
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1934
1778
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1935
1779
|
|
|
1936
|
-
// Compute centered covariance and variance
|
|
1937
|
-
nk_f32_t variance_a = variance_a_sum * inv_n -
|
|
1938
|
-
(centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y + centroid_a_z * centroid_a_z);
|
|
1939
|
-
|
|
1940
1780
|
// Apply centering correction to covariance matrix
|
|
1941
1781
|
covariance_x_x -= (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1942
1782
|
covariance_x_y -= (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
@@ -1951,50 +1791,83 @@ NK_PUBLIC void nk_umeyama_bf16_haswell(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
1951
1791
|
nk_f32_t cross_covariance[9] = {covariance_x_x, covariance_x_y, covariance_x_z, covariance_y_x, covariance_y_y,
|
|
1952
1792
|
covariance_y_z, covariance_z_x, covariance_z_y, covariance_z_z};
|
|
1953
1793
|
|
|
1954
|
-
//
|
|
1955
|
-
nk_f32_t
|
|
1956
|
-
|
|
1957
|
-
|
|
1958
|
-
|
|
1959
|
-
|
|
1960
|
-
|
|
1961
|
-
|
|
1962
|
-
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
|
|
1966
|
-
|
|
1967
|
-
|
|
1968
|
-
|
|
1969
|
-
|
|
1970
|
-
|
|
1971
|
-
|
|
1972
|
-
|
|
1973
|
-
nk_f32_t
|
|
1974
|
-
nk_f32_t
|
|
1975
|
-
|
|
1976
|
-
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1980
|
-
|
|
1981
|
-
|
|
1982
|
-
|
|
1983
|
-
|
|
1984
|
-
|
|
1985
|
-
|
|
1986
|
-
|
|
1987
|
-
|
|
1988
|
-
|
|
1794
|
+
// Centered ‖a-ā‖², ‖b-b̄‖² via the parallel-axis identity
|
|
1795
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a_sum -
|
|
1796
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1797
|
+
centroid_a_z * centroid_a_z);
|
|
1798
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b_sum -
|
|
1799
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1800
|
+
centroid_b_z * centroid_b_z);
|
|
1801
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
1802
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
1803
|
+
|
|
1804
|
+
// Identity-dominant short-circuit: if H is essentially diagonal with positive diagonals,
|
|
1805
|
+
// R = I and trace(DS) = trace(H).
|
|
1806
|
+
nk_f32_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
1807
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
1808
|
+
cross_covariance[8] * cross_covariance[8];
|
|
1809
|
+
nk_f32_t covariance_offdiagonal_norm_squared =
|
|
1810
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
1811
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
1812
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
1813
|
+
nk_f32_t optimal_rotation[9];
|
|
1814
|
+
nk_f32_t applied_scale;
|
|
1815
|
+
nk_f32_t trace_rotation_covariance;
|
|
1816
|
+
if (covariance_offdiagonal_norm_squared < 1e-12f * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0f &&
|
|
1817
|
+
cross_covariance[4] > 0.0f && cross_covariance[8] > 0.0f) {
|
|
1818
|
+
optimal_rotation[0] = 1.0f, optimal_rotation[1] = 0.0f, optimal_rotation[2] = 0.0f;
|
|
1819
|
+
optimal_rotation[3] = 0.0f, optimal_rotation[4] = 1.0f, optimal_rotation[5] = 0.0f;
|
|
1820
|
+
optimal_rotation[6] = 0.0f, optimal_rotation[7] = 0.0f, optimal_rotation[8] = 1.0f;
|
|
1821
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
1822
|
+
applied_scale = centered_norm_squared_a > 0.0f ? trace_rotation_covariance / centered_norm_squared_a : 0.0f;
|
|
1823
|
+
}
|
|
1824
|
+
else {
|
|
1825
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1826
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1827
|
+
// R = V * Uᵀ
|
|
1828
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
1829
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
1830
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
1831
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
1832
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
1833
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
1834
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
1835
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
1836
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
1837
|
+
|
|
1838
|
+
nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
|
|
1839
|
+
nk_f32_t sign_correction = det < 0 ? -1.0f : 1.0f;
|
|
1840
|
+
if (det < 0) {
|
|
1841
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1842
|
+
optimal_rotation[0] = svd_right[0] * svd_left[0] + svd_right[1] * svd_left[1] + svd_right[2] * svd_left[2];
|
|
1843
|
+
optimal_rotation[1] = svd_right[0] * svd_left[3] + svd_right[1] * svd_left[4] + svd_right[2] * svd_left[5];
|
|
1844
|
+
optimal_rotation[2] = svd_right[0] * svd_left[6] + svd_right[1] * svd_left[7] + svd_right[2] * svd_left[8];
|
|
1845
|
+
optimal_rotation[3] = svd_right[3] * svd_left[0] + svd_right[4] * svd_left[1] + svd_right[5] * svd_left[2];
|
|
1846
|
+
optimal_rotation[4] = svd_right[3] * svd_left[3] + svd_right[4] * svd_left[4] + svd_right[5] * svd_left[5];
|
|
1847
|
+
optimal_rotation[5] = svd_right[3] * svd_left[6] + svd_right[4] * svd_left[7] + svd_right[5] * svd_left[8];
|
|
1848
|
+
optimal_rotation[6] = svd_right[6] * svd_left[0] + svd_right[7] * svd_left[1] + svd_right[8] * svd_left[2];
|
|
1849
|
+
optimal_rotation[7] = svd_right[6] * svd_left[3] + svd_right[7] * svd_left[4] + svd_right[8] * svd_left[5];
|
|
1850
|
+
optimal_rotation[8] = svd_right[6] * svd_left[6] + svd_right[7] * svd_left[7] + svd_right[8] * svd_left[8];
|
|
1851
|
+
}
|
|
1852
|
+
nk_f32_t trace_ds = svd_diagonal[0] + svd_diagonal[4] + sign_correction * svd_diagonal[8];
|
|
1853
|
+
applied_scale = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
|
|
1854
|
+
trace_rotation_covariance =
|
|
1855
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1856
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1857
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1858
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1859
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1989
1860
|
}
|
|
1990
1861
|
|
|
1991
|
-
// Output rotation matrix
|
|
1862
|
+
// Output rotation matrix and scale
|
|
1992
1863
|
if (rotation)
|
|
1993
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
1864
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
1865
|
+
if (scale) *scale = applied_scale;
|
|
1994
1866
|
|
|
1995
|
-
//
|
|
1996
|
-
nk_f32_t sum_squared =
|
|
1997
|
-
|
|
1867
|
+
// Folded SSD with scale: c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
|
|
1868
|
+
nk_f32_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
|
|
1869
|
+
2.0f * applied_scale * trace_rotation_covariance;
|
|
1870
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
1998
1871
|
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
1999
1872
|
}
|
|
2000
1873
|
|