numkong 7.5.0 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/binding.gyp +18 -0
- package/c/dispatch_e5m2.c +23 -3
- package/include/numkong/capabilities.h +1 -1
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +434 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +23 -8
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots.h +12 -0
- package/include/numkong/each/serial.h +18 -1
- package/include/numkong/geospatial/serial.h +14 -3
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +204 -162
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +128 -0
- package/include/numkong/spatials/serial.h +18 -1
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials.h +17 -0
- package/include/numkong/tensor.hpp +107 -23
- package/javascript/numkong.c +3 -2
- package/package.json +7 -7
- package/wasm/numkong.wasm +0 -0
|
@@ -14,9 +14,12 @@
|
|
|
14
14
|
* _mm512_permutex2var_ps VPERMT2PS (ZMM, ZMM, ZMM) 3cy @ p5 4cy @ p12
|
|
15
15
|
* _mm512_extractf32x8_ps VEXTRACTF32X8 (YMM, ZMM, I8) 3cy @ p5 1cy @ p0123
|
|
16
16
|
*
|
|
17
|
-
*
|
|
18
|
-
*
|
|
19
|
-
*
|
|
17
|
+
* Most `*_f32` mesh kernels use a 15-lane stride-3 chunk layout: 5 xyz triplets per ZMM (lane 15
|
|
18
|
+
* masked to zero) so the xyz phase is identical across all chunks and no per-chunk deinterleave is
|
|
19
|
+
* needed. The 9 cross-covariance cells come from three accumulators a*b, a*rot1(b), a*rot2(b)
|
|
20
|
+
* demuxed per channel post-loop, where rot1/rot2 are cheap within-triplet permutexvar rotations.
|
|
21
|
+
* `*_f64`, `*_f16`, `*_bf16` kernels still use VPERMT2PS deinterleave (helpers retained below).
|
|
22
|
+
* Dual FMA accumulators on Skylake-X hide the 4cy latency for centroid and covariance computation.
|
|
20
23
|
*/
|
|
21
24
|
#ifndef NK_MESH_SKYLAKE_H
|
|
22
25
|
#define NK_MESH_SKYLAKE_H
|
|
@@ -231,10 +234,6 @@ NK_INTERNAL nk_f64_t nk_reduce_stable_f64x8_skylake_(__m512d values_f64x8) {
|
|
|
231
234
|
return sum + compensation;
|
|
232
235
|
}
|
|
233
236
|
|
|
234
|
-
NK_INTERNAL void nk_rotation_from_svd_f64_skylake_(nk_f64_t const *svd_u, nk_f64_t const *svd_v, nk_f64_t *rotation) {
|
|
235
|
-
nk_rotation_from_svd_f64_serial_(svd_u, svd_v, rotation);
|
|
236
|
-
}
|
|
237
|
-
|
|
238
237
|
NK_INTERNAL void nk_accumulate_square_f64x8_skylake_(__m512d *sum_f64x8, __m512d *compensation_f64x8,
|
|
239
238
|
__m512d values_f64x8) {
|
|
240
239
|
__m512d product_f64x8 = _mm512_mul_pd(values_f64x8, values_f64x8);
|
|
@@ -248,394 +247,181 @@ NK_INTERNAL void nk_accumulate_square_f64x8_skylake_(__m512d *sum_f64x8, __m512d
|
|
|
248
247
|
*compensation_f64x8 = _mm512_add_pd(*compensation_f64x8, _mm512_add_pd(sum_error_f64x8, product_error_f64x8));
|
|
249
248
|
}
|
|
250
249
|
|
|
251
|
-
/*
|
|
252
|
-
*
|
|
253
|
-
*
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
__m512d scaled_rotation_y_z_f64x8 = _mm512_set1_pd(scale * r[5]);
|
|
266
|
-
__m512d scaled_rotation_z_x_f64x8 = _mm512_set1_pd(scale * r[6]);
|
|
267
|
-
__m512d scaled_rotation_z_y_f64x8 = _mm512_set1_pd(scale * r[7]);
|
|
268
|
-
__m512d scaled_rotation_z_z_f64x8 = _mm512_set1_pd(scale * r[8]);
|
|
269
|
-
__m512d centroid_a_x_f64x8 = _mm512_set1_pd(centroid_a_x), centroid_a_y_f64x8 = _mm512_set1_pd(centroid_a_y);
|
|
270
|
-
__m512d centroid_a_z_f64x8 = _mm512_set1_pd(centroid_a_z), centroid_b_x_f64x8 = _mm512_set1_pd(centroid_b_x);
|
|
271
|
-
__m512d centroid_b_y_f64x8 = _mm512_set1_pd(centroid_b_y), centroid_b_z_f64x8 = _mm512_set1_pd(centroid_b_z);
|
|
272
|
-
__m512d sum_squared_f64x8 = _mm512_setzero_pd();
|
|
273
|
-
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
274
|
-
nk_size_t index = 0;
|
|
275
|
-
|
|
276
|
-
for (; index + 16 <= n; index += 16) {
|
|
277
|
-
nk_deinterleave_f32x16_skylake_(a + index * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16),
|
|
278
|
-
nk_deinterleave_f32x16_skylake_(b + index * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
279
|
-
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
280
|
-
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
281
|
-
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
282
|
-
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
283
|
-
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
284
|
-
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
285
|
-
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
286
|
-
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
287
|
-
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
288
|
-
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
289
|
-
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
290
|
-
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
291
|
-
|
|
292
|
-
__m512d centered_a_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, centroid_a_x_f64x8);
|
|
293
|
-
__m512d centered_a_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, centroid_a_x_f64x8);
|
|
294
|
-
__m512d centered_a_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, centroid_a_y_f64x8);
|
|
295
|
-
__m512d centered_a_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, centroid_a_y_f64x8);
|
|
296
|
-
__m512d centered_a_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, centroid_a_z_f64x8);
|
|
297
|
-
__m512d centered_a_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, centroid_a_z_f64x8);
|
|
298
|
-
__m512d centered_b_x_low_f64x8 = _mm512_sub_pd(b_x_low_f64x8, centroid_b_x_f64x8);
|
|
299
|
-
__m512d centered_b_x_high_f64x8 = _mm512_sub_pd(b_x_high_f64x8, centroid_b_x_f64x8);
|
|
300
|
-
__m512d centered_b_y_low_f64x8 = _mm512_sub_pd(b_y_low_f64x8, centroid_b_y_f64x8);
|
|
301
|
-
__m512d centered_b_y_high_f64x8 = _mm512_sub_pd(b_y_high_f64x8, centroid_b_y_f64x8);
|
|
302
|
-
__m512d centered_b_z_low_f64x8 = _mm512_sub_pd(b_z_low_f64x8, centroid_b_z_f64x8);
|
|
303
|
-
__m512d centered_b_z_high_f64x8 = _mm512_sub_pd(b_z_high_f64x8, centroid_b_z_f64x8);
|
|
304
|
-
|
|
305
|
-
__m512d rotated_a_x_low_f64x8 = _mm512_fmadd_pd(
|
|
306
|
-
scaled_rotation_x_z_f64x8, centered_a_z_low_f64x8,
|
|
307
|
-
_mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_low_f64x8,
|
|
308
|
-
_mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_low_f64x8)));
|
|
309
|
-
__m512d rotated_a_x_high_f64x8 = _mm512_fmadd_pd(
|
|
310
|
-
scaled_rotation_x_z_f64x8, centered_a_z_high_f64x8,
|
|
311
|
-
_mm512_fmadd_pd(scaled_rotation_x_y_f64x8, centered_a_y_high_f64x8,
|
|
312
|
-
_mm512_mul_pd(scaled_rotation_x_x_f64x8, centered_a_x_high_f64x8)));
|
|
313
|
-
__m512d rotated_a_y_low_f64x8 = _mm512_fmadd_pd(
|
|
314
|
-
scaled_rotation_y_z_f64x8, centered_a_z_low_f64x8,
|
|
315
|
-
_mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_low_f64x8,
|
|
316
|
-
_mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_low_f64x8)));
|
|
317
|
-
__m512d rotated_a_y_high_f64x8 = _mm512_fmadd_pd(
|
|
318
|
-
scaled_rotation_y_z_f64x8, centered_a_z_high_f64x8,
|
|
319
|
-
_mm512_fmadd_pd(scaled_rotation_y_y_f64x8, centered_a_y_high_f64x8,
|
|
320
|
-
_mm512_mul_pd(scaled_rotation_y_x_f64x8, centered_a_x_high_f64x8)));
|
|
321
|
-
__m512d rotated_a_z_low_f64x8 = _mm512_fmadd_pd(
|
|
322
|
-
scaled_rotation_z_z_f64x8, centered_a_z_low_f64x8,
|
|
323
|
-
_mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_low_f64x8,
|
|
324
|
-
_mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_low_f64x8)));
|
|
325
|
-
__m512d rotated_a_z_high_f64x8 = _mm512_fmadd_pd(
|
|
326
|
-
scaled_rotation_z_z_f64x8, centered_a_z_high_f64x8,
|
|
327
|
-
_mm512_fmadd_pd(scaled_rotation_z_y_f64x8, centered_a_y_high_f64x8,
|
|
328
|
-
_mm512_mul_pd(scaled_rotation_z_x_f64x8, centered_a_x_high_f64x8)));
|
|
329
|
-
|
|
330
|
-
__m512d delta_x_low_f64x8 = _mm512_sub_pd(rotated_a_x_low_f64x8, centered_b_x_low_f64x8);
|
|
331
|
-
__m512d delta_x_high_f64x8 = _mm512_sub_pd(rotated_a_x_high_f64x8, centered_b_x_high_f64x8);
|
|
332
|
-
__m512d delta_y_low_f64x8 = _mm512_sub_pd(rotated_a_y_low_f64x8, centered_b_y_low_f64x8);
|
|
333
|
-
__m512d delta_y_high_f64x8 = _mm512_sub_pd(rotated_a_y_high_f64x8, centered_b_y_high_f64x8);
|
|
334
|
-
__m512d delta_z_low_f64x8 = _mm512_sub_pd(rotated_a_z_low_f64x8, centered_b_z_low_f64x8);
|
|
335
|
-
__m512d delta_z_high_f64x8 = _mm512_sub_pd(rotated_a_z_high_f64x8, centered_b_z_high_f64x8);
|
|
336
|
-
|
|
337
|
-
__m512d batch_sum_squared_f64x8 = _mm512_add_pd(_mm512_mul_pd(delta_x_low_f64x8, delta_x_low_f64x8),
|
|
338
|
-
_mm512_mul_pd(delta_x_high_f64x8, delta_x_high_f64x8));
|
|
339
|
-
batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, batch_sum_squared_f64x8);
|
|
340
|
-
batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, batch_sum_squared_f64x8);
|
|
341
|
-
batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, batch_sum_squared_f64x8);
|
|
342
|
-
batch_sum_squared_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, batch_sum_squared_f64x8);
|
|
343
|
-
sum_squared_f64x8 = _mm512_add_pd(sum_squared_f64x8, batch_sum_squared_f64x8);
|
|
344
|
-
}
|
|
345
|
-
|
|
346
|
-
nk_f64_t sum_squared = _mm512_reduce_add_pd(sum_squared_f64x8);
|
|
347
|
-
for (; index < n; ++index) {
|
|
348
|
-
nk_f64_t centered_a_x = (nk_f64_t)a[index * 3 + 0] - centroid_a_x,
|
|
349
|
-
centered_a_y = (nk_f64_t)a[index * 3 + 1] - centroid_a_y,
|
|
350
|
-
centered_a_z = (nk_f64_t)a[index * 3 + 2] - centroid_a_z;
|
|
351
|
-
nk_f64_t centered_b_x = (nk_f64_t)b[index * 3 + 0] - centroid_b_x,
|
|
352
|
-
centered_b_y = (nk_f64_t)b[index * 3 + 1] - centroid_b_y,
|
|
353
|
-
centered_b_z = (nk_f64_t)b[index * 3 + 2] - centroid_b_z;
|
|
354
|
-
nk_f64_t rotated_a_x = scale * (r[0] * centered_a_x + r[1] * centered_a_y + r[2] * centered_a_z),
|
|
355
|
-
rotated_a_y = scale * (r[3] * centered_a_x + r[4] * centered_a_y + r[5] * centered_a_z),
|
|
356
|
-
rotated_a_z = scale * (r[6] * centered_a_x + r[7] * centered_a_y + r[8] * centered_a_z);
|
|
357
|
-
nk_f64_t delta_x = rotated_a_x - centered_b_x, delta_y = rotated_a_y - centered_b_y,
|
|
358
|
-
delta_z = rotated_a_z - centered_b_z;
|
|
359
|
-
sum_squared += delta_x * delta_x + delta_y * delta_y + delta_z * delta_z;
|
|
360
|
-
}
|
|
361
|
-
|
|
362
|
-
return sum_squared;
|
|
363
|
-
}
|
|
364
|
-
|
|
365
|
-
/* Compute sum of squared distances for f64 after applying rotation (and optional scale).
|
|
366
|
-
* Rotation matrix, scale and data are all f64 for full precision.
|
|
367
|
-
*/
|
|
368
|
-
NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_skylake_(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n,
|
|
369
|
-
nk_f64_t const *r, nk_f64_t scale, nk_f64_t centroid_a_x,
|
|
370
|
-
nk_f64_t centroid_a_y, nk_f64_t centroid_a_z,
|
|
371
|
-
nk_f64_t centroid_b_x, nk_f64_t centroid_b_y,
|
|
372
|
-
nk_f64_t centroid_b_z) {
|
|
373
|
-
// Broadcast scaled rotation matrix elements
|
|
374
|
-
__m512d scaled_rotation_x_x_f64x8 = _mm512_set1_pd(scale * r[0]);
|
|
375
|
-
__m512d scaled_rotation_x_y_f64x8 = _mm512_set1_pd(scale * r[1]);
|
|
376
|
-
__m512d scaled_rotation_x_z_f64x8 = _mm512_set1_pd(scale * r[2]);
|
|
377
|
-
__m512d scaled_rotation_y_x_f64x8 = _mm512_set1_pd(scale * r[3]);
|
|
378
|
-
__m512d scaled_rotation_y_y_f64x8 = _mm512_set1_pd(scale * r[4]);
|
|
379
|
-
__m512d scaled_rotation_y_z_f64x8 = _mm512_set1_pd(scale * r[5]);
|
|
380
|
-
__m512d scaled_rotation_z_x_f64x8 = _mm512_set1_pd(scale * r[6]);
|
|
381
|
-
__m512d scaled_rotation_z_y_f64x8 = _mm512_set1_pd(scale * r[7]);
|
|
382
|
-
__m512d scaled_rotation_z_z_f64x8 = _mm512_set1_pd(scale * r[8]);
|
|
383
|
-
|
|
384
|
-
// Broadcast centroids
|
|
385
|
-
__m512d centroid_a_x_f64x8 = _mm512_set1_pd(centroid_a_x);
|
|
386
|
-
__m512d centroid_a_y_f64x8 = _mm512_set1_pd(centroid_a_y);
|
|
387
|
-
__m512d centroid_a_z_f64x8 = _mm512_set1_pd(centroid_a_z);
|
|
388
|
-
__m512d centroid_b_x_f64x8 = _mm512_set1_pd(centroid_b_x);
|
|
389
|
-
__m512d centroid_b_y_f64x8 = _mm512_set1_pd(centroid_b_y);
|
|
390
|
-
__m512d centroid_b_z_f64x8 = _mm512_set1_pd(centroid_b_z);
|
|
391
|
-
|
|
392
|
-
__m512d sum_squared_f64x8 = _mm512_setzero_pd();
|
|
393
|
-
__m512d sum_squared_compensation_f64x8 = _mm512_setzero_pd();
|
|
394
|
-
__m512d a_x_f64x8, a_y_f64x8, a_z_f64x8, b_x_f64x8, b_y_f64x8, b_z_f64x8;
|
|
395
|
-
nk_size_t j = 0;
|
|
396
|
-
|
|
397
|
-
for (; j + 8 <= n; j += 8) {
|
|
398
|
-
nk_deinterleave_f64x8_skylake_(a + j * 3, &a_x_f64x8, &a_y_f64x8, &a_z_f64x8);
|
|
399
|
-
nk_deinterleave_f64x8_skylake_(b + j * 3, &b_x_f64x8, &b_y_f64x8, &b_z_f64x8);
|
|
400
|
-
|
|
401
|
-
// Center points
|
|
402
|
-
__m512d pa_x_f64x8 = _mm512_sub_pd(a_x_f64x8, centroid_a_x_f64x8);
|
|
403
|
-
__m512d pa_y_f64x8 = _mm512_sub_pd(a_y_f64x8, centroid_a_y_f64x8);
|
|
404
|
-
__m512d pa_z_f64x8 = _mm512_sub_pd(a_z_f64x8, centroid_a_z_f64x8);
|
|
405
|
-
__m512d pb_x_f64x8 = _mm512_sub_pd(b_x_f64x8, centroid_b_x_f64x8);
|
|
406
|
-
__m512d pb_y_f64x8 = _mm512_sub_pd(b_y_f64x8, centroid_b_y_f64x8);
|
|
407
|
-
__m512d pb_z_f64x8 = _mm512_sub_pd(b_z_f64x8, centroid_b_z_f64x8);
|
|
408
|
-
|
|
409
|
-
// Rotate and scale: ra = scale * R * pa
|
|
410
|
-
__m512d ra_x_f64x8 = _mm512_fmadd_pd(scaled_rotation_x_z_f64x8, pa_z_f64x8,
|
|
411
|
-
_mm512_fmadd_pd(scaled_rotation_x_y_f64x8, pa_y_f64x8,
|
|
412
|
-
_mm512_mul_pd(scaled_rotation_x_x_f64x8, pa_x_f64x8)));
|
|
413
|
-
__m512d ra_y_f64x8 = _mm512_fmadd_pd(scaled_rotation_y_z_f64x8, pa_z_f64x8,
|
|
414
|
-
_mm512_fmadd_pd(scaled_rotation_y_y_f64x8, pa_y_f64x8,
|
|
415
|
-
_mm512_mul_pd(scaled_rotation_y_x_f64x8, pa_x_f64x8)));
|
|
416
|
-
__m512d ra_z_f64x8 = _mm512_fmadd_pd(scaled_rotation_z_z_f64x8, pa_z_f64x8,
|
|
417
|
-
_mm512_fmadd_pd(scaled_rotation_z_y_f64x8, pa_y_f64x8,
|
|
418
|
-
_mm512_mul_pd(scaled_rotation_z_x_f64x8, pa_x_f64x8)));
|
|
419
|
-
|
|
420
|
-
// Delta and accumulate
|
|
421
|
-
__m512d delta_x_f64x8 = _mm512_sub_pd(ra_x_f64x8, pb_x_f64x8);
|
|
422
|
-
__m512d delta_y_f64x8 = _mm512_sub_pd(ra_y_f64x8, pb_y_f64x8);
|
|
423
|
-
__m512d delta_z_f64x8 = _mm512_sub_pd(ra_z_f64x8, pb_z_f64x8);
|
|
424
|
-
|
|
425
|
-
nk_accumulate_square_f64x8_skylake_(&sum_squared_f64x8, &sum_squared_compensation_f64x8, delta_x_f64x8);
|
|
426
|
-
nk_accumulate_square_f64x8_skylake_(&sum_squared_f64x8, &sum_squared_compensation_f64x8, delta_y_f64x8);
|
|
427
|
-
nk_accumulate_square_f64x8_skylake_(&sum_squared_f64x8, &sum_squared_compensation_f64x8, delta_z_f64x8);
|
|
428
|
-
}
|
|
429
|
-
|
|
430
|
-
nk_f64_t sum_squared = nk_dot_stable_sum_f64x8_skylake_(sum_squared_f64x8, sum_squared_compensation_f64x8);
|
|
431
|
-
nk_f64_t sum_squared_compensation = 0.0;
|
|
432
|
-
|
|
433
|
-
// Scalar tail
|
|
434
|
-
for (; j < n; ++j) {
|
|
435
|
-
nk_f64_t pa_x = a[j * 3 + 0] - centroid_a_x, pa_y = a[j * 3 + 1] - centroid_a_y,
|
|
436
|
-
pa_z = a[j * 3 + 2] - centroid_a_z;
|
|
437
|
-
nk_f64_t pb_x = b[j * 3 + 0] - centroid_b_x, pb_y = b[j * 3 + 1] - centroid_b_y,
|
|
438
|
-
pb_z = b[j * 3 + 2] - centroid_b_z;
|
|
439
|
-
|
|
440
|
-
nk_f64_t ra_x = scale * (r[0] * pa_x + r[1] * pa_y + r[2] * pa_z),
|
|
441
|
-
ra_y = scale * (r[3] * pa_x + r[4] * pa_y + r[5] * pa_z),
|
|
442
|
-
ra_z = scale * (r[6] * pa_x + r[7] * pa_y + r[8] * pa_z);
|
|
443
|
-
|
|
444
|
-
nk_f64_t delta_x = ra_x - pb_x, delta_y = ra_y - pb_y, delta_z = ra_z - pb_z;
|
|
445
|
-
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_x);
|
|
446
|
-
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_y);
|
|
447
|
-
nk_accumulate_square_f64_(&sum_squared, &sum_squared_compensation, delta_z);
|
|
448
|
-
}
|
|
449
|
-
|
|
450
|
-
return sum_squared + sum_squared_compensation;
|
|
451
|
-
}
|
|
452
|
-
|
|
453
|
-
/* Compute sum of squared distances for f16 data after applying rotation (and optional scale).
|
|
454
|
-
* Loads f16, converts to f32 for computation. Rotation matrix, scale, and centroids are f32.
|
|
250
|
+
/* Single-pass streaming statistics over an f32 xyz point-cloud pair.
|
|
251
|
+
* Processes 5 xyz triplets per chunk (15 fp32 lanes, lane 15 masked to zero) so the stride-3
|
|
252
|
+
* phase is identical across all chunks and no deinterleave is needed. All accumulators are f64.
|
|
253
|
+
* Outputs via pointers:
|
|
254
|
+
* sum_a_out[3] / sum_b_out[3] - per-channel Sum(a), Sum(b)
|
|
255
|
+
* raw_covarianceariance_out[9] - row-major uncentered Sum(a_j * b_k)
|
|
256
|
+
* norm_squared_a_out / norm_squared_b_out - Sum(||a||^2), Sum(||b||^2) across all three channels
|
|
257
|
+
*
|
|
258
|
+
* The 9 H-cells come from three product accumulators prod_{diag,rot1,rot2} demuxed post-loop
|
|
259
|
+
* by a-channel. Rotations of b happen in fp64 via permutex2var_pd on the already-widened
|
|
260
|
+
* halves: widening the rotated fp32 vector would add two extra cvtps_pd per chunk, which we
|
|
261
|
+
* skip. Post-loop, each (accumulator, channel) pair is gathered into a single 8-lane vector
|
|
262
|
+
* via one maskz-permutex2var_pd and reduced once — 17 horizontal reductions total (the
|
|
263
|
+
* theoretical minimum for 17 scalar outputs) instead of 32 masked ones.
|
|
455
264
|
*/
|
|
456
|
-
NK_INTERNAL
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
480
|
-
nk_size_t j = 0;
|
|
481
|
-
|
|
482
|
-
for (; j + 16 <= n; j += 16) {
|
|
483
|
-
nk_deinterleave_f16x16_to_f32x16_skylake_(a + j * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
484
|
-
nk_deinterleave_f16x16_to_f32x16_skylake_(b + j * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
485
|
-
|
|
486
|
-
__m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
|
|
487
|
-
__m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
|
|
488
|
-
__m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
|
|
489
|
-
__m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
|
|
490
|
-
__m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
|
|
491
|
-
__m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
|
|
492
|
-
|
|
493
|
-
__m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
|
|
494
|
-
_mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
|
|
495
|
-
_mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
|
|
496
|
-
__m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
|
|
497
|
-
_mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
|
|
498
|
-
_mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
|
|
499
|
-
__m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
|
|
500
|
-
_mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
|
|
501
|
-
_mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
|
|
502
|
-
|
|
503
|
-
__m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
|
|
504
|
-
__m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
|
|
505
|
-
__m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
|
|
506
|
-
|
|
507
|
-
sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
|
|
508
|
-
sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
|
|
509
|
-
sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
|
|
510
|
-
}
|
|
511
|
-
|
|
512
|
-
// Tail: deinterleave remaining points into zero-initialized vectors
|
|
513
|
-
if (j < n) {
|
|
514
|
-
nk_size_t tail = n - j;
|
|
515
|
-
nk_deinterleave_f16_tail_to_f32x16_skylake_(a + j * 3, tail, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
516
|
-
nk_deinterleave_f16_tail_to_f32x16_skylake_(b + j * 3, tail, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
517
|
-
|
|
518
|
-
__m512 pa_x_f32x16 = _mm512_sub_ps(a_x_f32x16, centroid_a_x_f32x16);
|
|
519
|
-
__m512 pa_y_f32x16 = _mm512_sub_ps(a_y_f32x16, centroid_a_y_f32x16);
|
|
520
|
-
__m512 pa_z_f32x16 = _mm512_sub_ps(a_z_f32x16, centroid_a_z_f32x16);
|
|
521
|
-
__m512 pb_x_f32x16 = _mm512_sub_ps(b_x_f32x16, centroid_b_x_f32x16);
|
|
522
|
-
__m512 pb_y_f32x16 = _mm512_sub_ps(b_y_f32x16, centroid_b_y_f32x16);
|
|
523
|
-
__m512 pb_z_f32x16 = _mm512_sub_ps(b_z_f32x16, centroid_b_z_f32x16);
|
|
524
|
-
|
|
525
|
-
__m512 ra_x_f32x16 = _mm512_fmadd_ps(scaled_rotation_x_z_f32x16, pa_z_f32x16,
|
|
526
|
-
_mm512_fmadd_ps(scaled_rotation_x_y_f32x16, pa_y_f32x16,
|
|
527
|
-
_mm512_mul_ps(scaled_rotation_x_x_f32x16, pa_x_f32x16)));
|
|
528
|
-
__m512 ra_y_f32x16 = _mm512_fmadd_ps(scaled_rotation_y_z_f32x16, pa_z_f32x16,
|
|
529
|
-
_mm512_fmadd_ps(scaled_rotation_y_y_f32x16, pa_y_f32x16,
|
|
530
|
-
_mm512_mul_ps(scaled_rotation_y_x_f32x16, pa_x_f32x16)));
|
|
531
|
-
__m512 ra_z_f32x16 = _mm512_fmadd_ps(scaled_rotation_z_z_f32x16, pa_z_f32x16,
|
|
532
|
-
_mm512_fmadd_ps(scaled_rotation_z_y_f32x16, pa_y_f32x16,
|
|
533
|
-
_mm512_mul_ps(scaled_rotation_z_x_f32x16, pa_x_f32x16)));
|
|
534
|
-
|
|
535
|
-
__m512 delta_x_f32x16 = _mm512_sub_ps(ra_x_f32x16, pb_x_f32x16);
|
|
536
|
-
__m512 delta_y_f32x16 = _mm512_sub_ps(ra_y_f32x16, pb_y_f32x16);
|
|
537
|
-
__m512 delta_z_f32x16 = _mm512_sub_ps(ra_z_f32x16, pb_z_f32x16);
|
|
538
|
-
|
|
539
|
-
sum_squared_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_f32x16);
|
|
540
|
-
sum_squared_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_f32x16);
|
|
541
|
-
sum_squared_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_f32x16);
|
|
542
|
-
}
|
|
265
|
+
NK_INTERNAL void nk_mesh_streaming_stats_f32_skylake_( //
|
|
266
|
+
nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f64_t *sum_a_out, nk_f64_t *sum_b_out,
|
|
267
|
+
nk_f64_t *raw_covarianceariance_out, nk_f64_t *norm_squared_a_out, nk_f64_t *norm_squared_b_out) {
|
|
268
|
+
|
|
269
|
+
// Within-triplet rotation indices for fp64 permutex2var across (b_low, b_high) as the two
|
|
270
|
+
// sources. Indices 0..7 pull from b_low, 8..15 pull from b_high. Derived from the fp32
|
|
271
|
+
// rotation pattern {1,2,0,4,5,3,7,8,6,10,11,9,13,14,12,15} (rot1) and {2,0,1,5,3,4,8,6,
|
|
272
|
+
// 7,11,9,10,14,12,13,15} (rot2) split at fp32 lane 8.
|
|
273
|
+
__m512i const idx_rotation_1_low_i64x8 = _mm512_setr_epi64(1, 2, 0, 4, 5, 3, 7, 8);
|
|
274
|
+
__m512i const idx_rotation_1_high_i64x8 = _mm512_setr_epi64(6, 10, 11, 9, 13, 14, 12, 15);
|
|
275
|
+
__m512i const idx_rotation_2_low_i64x8 = _mm512_setr_epi64(2, 0, 1, 5, 3, 4, 8, 6);
|
|
276
|
+
__m512i const idx_rotation_2_high_i64x8 = _mm512_setr_epi64(7, 11, 9, 10, 14, 12, 13, 15);
|
|
277
|
+
|
|
278
|
+
// Per-channel gather indices packing the 5 contributing fp64 lanes (across both halves)
|
|
279
|
+
// into lanes 0..4 of the output, with lanes 5..7 zeroed by maskz so the subsequent
|
|
280
|
+
// _mm512_reduce_add_pd is exact without needing a mask-reduce variant.
|
|
281
|
+
// x -> low {0,3,6} + high {1,4} = idx [0, 3, 6, 9, 12, _, _, _]
|
|
282
|
+
// y -> low {1,4,7} + high {2,5} = idx [1, 4, 7, 10, 13, _, _, _]
|
|
283
|
+
// z -> low {2,5} + high {0,3,6} = idx [2, 5, 8, 11, 14, _, _, _]
|
|
284
|
+
__m512i const idx_channel_x_i64x8 = _mm512_setr_epi64(0, 3, 6, 9, 12, 0, 0, 0);
|
|
285
|
+
__m512i const idx_channel_y_i64x8 = _mm512_setr_epi64(1, 4, 7, 10, 13, 0, 0, 0);
|
|
286
|
+
__m512i const idx_channel_z_i64x8 = _mm512_setr_epi64(2, 5, 8, 11, 14, 0, 0, 0);
|
|
287
|
+
__mmask8 const channel_lanes_mask = 0x1F;
|
|
543
288
|
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
nk_f32_t centroid_a_y, nk_f32_t centroid_a_z,
|
|
553
|
-
nk_f32_t centroid_b_x, nk_f32_t centroid_b_y,
|
|
554
|
-
nk_f32_t centroid_b_z) {
|
|
555
|
-
__m512 scaled_rotation_x_x_f32x16 = _mm512_set1_ps(scale * r[0]);
|
|
556
|
-
__m512 scaled_rotation_x_y_f32x16 = _mm512_set1_ps(scale * r[1]);
|
|
557
|
-
__m512 scaled_rotation_x_z_f32x16 = _mm512_set1_ps(scale * r[2]);
|
|
558
|
-
__m512 scaled_rotation_y_x_f32x16 = _mm512_set1_ps(scale * r[3]);
|
|
559
|
-
__m512 scaled_rotation_y_y_f32x16 = _mm512_set1_ps(scale * r[4]);
|
|
560
|
-
__m512 scaled_rotation_y_z_f32x16 = _mm512_set1_ps(scale * r[5]);
|
|
561
|
-
__m512 scaled_rotation_z_x_f32x16 = _mm512_set1_ps(scale * r[6]);
|
|
562
|
-
__m512 scaled_rotation_z_y_f32x16 = _mm512_set1_ps(scale * r[7]);
|
|
563
|
-
__m512 scaled_rotation_z_z_f32x16 = _mm512_set1_ps(scale * r[8]);
|
|
564
|
-
|
|
565
|
-
__m512 centroid_a_x_f32x16 = _mm512_set1_ps(centroid_a_x);
|
|
566
|
-
__m512 centroid_a_y_f32x16 = _mm512_set1_ps(centroid_a_y);
|
|
567
|
-
__m512 centroid_a_z_f32x16 = _mm512_set1_ps(centroid_a_z);
|
|
568
|
-
__m512 centroid_b_x_f32x16 = _mm512_set1_ps(centroid_b_x);
|
|
569
|
-
__m512 centroid_b_y_f32x16 = _mm512_set1_ps(centroid_b_y);
|
|
570
|
-
__m512 centroid_b_z_f32x16 = _mm512_set1_ps(centroid_b_z);
|
|
289
|
+
__m512d const zeros_f64x8 = _mm512_setzero_pd();
|
|
290
|
+
__m512d sum_a_low_f64x8 = zeros_f64x8, sum_a_high_f64x8 = zeros_f64x8;
|
|
291
|
+
__m512d sum_b_low_f64x8 = zeros_f64x8, sum_b_high_f64x8 = zeros_f64x8;
|
|
292
|
+
__m512d norm_squared_a_low_f64x8 = zeros_f64x8, norm_squared_a_high_f64x8 = zeros_f64x8;
|
|
293
|
+
__m512d norm_squared_b_low_f64x8 = zeros_f64x8, norm_squared_b_high_f64x8 = zeros_f64x8;
|
|
294
|
+
__m512d product_diagonal_low_f64x8 = zeros_f64x8, product_diagonal_high_f64x8 = zeros_f64x8;
|
|
295
|
+
__m512d product_rotation_1_low_f64x8 = zeros_f64x8, product_rotation_1_high_f64x8 = zeros_f64x8;
|
|
296
|
+
__m512d product_rotation_2_low_f64x8 = zeros_f64x8, product_rotation_2_high_f64x8 = zeros_f64x8;
|
|
571
297
|
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
298
|
+
nk_size_t index = 0;
|
|
299
|
+
// Main loop: 5 points (15 fp32) per chunk, lane 15 zeroed by mask 0x7FFF.
|
|
300
|
+
for (; index + 5 <= n; index += 5) {
|
|
301
|
+
__m512 a_f32x16 = _mm512_maskz_loadu_ps(0x7FFF, a + index * 3);
|
|
302
|
+
__m512 b_f32x16 = _mm512_maskz_loadu_ps(0x7FFF, b + index * 3);
|
|
303
|
+
|
|
304
|
+
__m512d a_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_f32x16));
|
|
305
|
+
__m512d a_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_f32x16, 1));
|
|
306
|
+
__m512d b_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_f32x16));
|
|
307
|
+
__m512d b_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_f32x16, 1));
|
|
308
|
+
|
|
309
|
+
__m512d b_rot1_low_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_1_low_i64x8, b_high_f64x8);
|
|
310
|
+
__m512d b_rot1_high_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_1_high_i64x8, b_high_f64x8);
|
|
311
|
+
__m512d b_rot2_low_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_2_low_i64x8, b_high_f64x8);
|
|
312
|
+
__m512d b_rot2_high_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_2_high_i64x8, b_high_f64x8);
|
|
313
|
+
|
|
314
|
+
sum_a_low_f64x8 = _mm512_add_pd(sum_a_low_f64x8, a_low_f64x8);
|
|
315
|
+
sum_a_high_f64x8 = _mm512_add_pd(sum_a_high_f64x8, a_high_f64x8);
|
|
316
|
+
sum_b_low_f64x8 = _mm512_add_pd(sum_b_low_f64x8, b_low_f64x8);
|
|
317
|
+
sum_b_high_f64x8 = _mm512_add_pd(sum_b_high_f64x8, b_high_f64x8);
|
|
318
|
+
|
|
319
|
+
norm_squared_a_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, a_low_f64x8, norm_squared_a_low_f64x8);
|
|
320
|
+
norm_squared_a_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, a_high_f64x8, norm_squared_a_high_f64x8);
|
|
321
|
+
norm_squared_b_low_f64x8 = _mm512_fmadd_pd(b_low_f64x8, b_low_f64x8, norm_squared_b_low_f64x8);
|
|
322
|
+
norm_squared_b_high_f64x8 = _mm512_fmadd_pd(b_high_f64x8, b_high_f64x8, norm_squared_b_high_f64x8);
|
|
323
|
+
|
|
324
|
+
product_diagonal_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_low_f64x8, product_diagonal_low_f64x8);
|
|
325
|
+
product_diagonal_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_high_f64x8, product_diagonal_high_f64x8);
|
|
326
|
+
product_rotation_1_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_rot1_low_f64x8, product_rotation_1_low_f64x8);
|
|
327
|
+
product_rotation_1_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_rot1_high_f64x8, product_rotation_1_high_f64x8);
|
|
328
|
+
product_rotation_2_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_rot2_low_f64x8, product_rotation_2_low_f64x8);
|
|
329
|
+
product_rotation_2_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_rot2_high_f64x8, product_rotation_2_high_f64x8);
|
|
604
330
|
}
|
|
605
331
|
|
|
606
|
-
// Tail:
|
|
607
|
-
if (
|
|
608
|
-
nk_size_t
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
332
|
+
// Tail: 1..4 points (3..12 fp32) via narrower mask; identical body.
|
|
333
|
+
if (index < n) {
|
|
334
|
+
nk_size_t tail_floats = (n - index) * 3;
|
|
335
|
+
__mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, tail_floats);
|
|
336
|
+
__m512 a_f32x16 = _mm512_maskz_loadu_ps(tail_mask, a + index * 3);
|
|
337
|
+
__m512 b_f32x16 = _mm512_maskz_loadu_ps(tail_mask, b + index * 3);
|
|
338
|
+
|
|
339
|
+
__m512d a_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_f32x16));
|
|
340
|
+
__m512d a_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_f32x16, 1));
|
|
341
|
+
__m512d b_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_f32x16));
|
|
342
|
+
__m512d b_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_f32x16, 1));
|
|
343
|
+
|
|
344
|
+
__m512d b_rot1_low_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_1_low_i64x8, b_high_f64x8);
|
|
345
|
+
__m512d b_rot1_high_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_1_high_i64x8, b_high_f64x8);
|
|
346
|
+
__m512d b_rot2_low_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_2_low_i64x8, b_high_f64x8);
|
|
347
|
+
__m512d b_rot2_high_f64x8 = _mm512_permutex2var_pd(b_low_f64x8, idx_rotation_2_high_i64x8, b_high_f64x8);
|
|
348
|
+
|
|
349
|
+
sum_a_low_f64x8 = _mm512_add_pd(sum_a_low_f64x8, a_low_f64x8);
|
|
350
|
+
sum_a_high_f64x8 = _mm512_add_pd(sum_a_high_f64x8, a_high_f64x8);
|
|
351
|
+
sum_b_low_f64x8 = _mm512_add_pd(sum_b_low_f64x8, b_low_f64x8);
|
|
352
|
+
sum_b_high_f64x8 = _mm512_add_pd(sum_b_high_f64x8, b_high_f64x8);
|
|
353
|
+
|
|
354
|
+
norm_squared_a_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, a_low_f64x8, norm_squared_a_low_f64x8);
|
|
355
|
+
norm_squared_a_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, a_high_f64x8, norm_squared_a_high_f64x8);
|
|
356
|
+
norm_squared_b_low_f64x8 = _mm512_fmadd_pd(b_low_f64x8, b_low_f64x8, norm_squared_b_low_f64x8);
|
|
357
|
+
norm_squared_b_high_f64x8 = _mm512_fmadd_pd(b_high_f64x8, b_high_f64x8, norm_squared_b_high_f64x8);
|
|
358
|
+
|
|
359
|
+
product_diagonal_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_low_f64x8, product_diagonal_low_f64x8);
|
|
360
|
+
product_diagonal_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_high_f64x8, product_diagonal_high_f64x8);
|
|
361
|
+
product_rotation_1_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_rot1_low_f64x8, product_rotation_1_low_f64x8);
|
|
362
|
+
product_rotation_1_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_rot1_high_f64x8, product_rotation_1_high_f64x8);
|
|
363
|
+
product_rotation_2_low_f64x8 = _mm512_fmadd_pd(a_low_f64x8, b_rot2_low_f64x8, product_rotation_2_low_f64x8);
|
|
364
|
+
product_rotation_2_high_f64x8 = _mm512_fmadd_pd(a_high_f64x8, b_rot2_high_f64x8, product_rotation_2_high_f64x8);
|
|
636
365
|
}
|
|
637
366
|
|
|
638
|
-
|
|
367
|
+
// Post-loop: gather each (accumulator, a-channel) pair into a single 8-lane vector via one
|
|
368
|
+
// maskz-permutex2var_pd across (low, high) halves, then one _mm512_reduce_add_pd per scalar
|
|
369
|
+
// output. 17 reductions total (6 sums + 9 H cells + 2 norms) = the scalar-output floor.
|
|
370
|
+
|
|
371
|
+
__m512d sum_a_x_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_a_low_f64x8, idx_channel_x_i64x8,
|
|
372
|
+
sum_a_high_f64x8);
|
|
373
|
+
__m512d sum_a_y_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_a_low_f64x8, idx_channel_y_i64x8,
|
|
374
|
+
sum_a_high_f64x8);
|
|
375
|
+
__m512d sum_a_z_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_a_low_f64x8, idx_channel_z_i64x8,
|
|
376
|
+
sum_a_high_f64x8);
|
|
377
|
+
sum_a_out[0] = _mm512_reduce_add_pd(sum_a_x_f64x8);
|
|
378
|
+
sum_a_out[1] = _mm512_reduce_add_pd(sum_a_y_f64x8);
|
|
379
|
+
sum_a_out[2] = _mm512_reduce_add_pd(sum_a_z_f64x8);
|
|
380
|
+
|
|
381
|
+
__m512d sum_b_x_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_b_low_f64x8, idx_channel_x_i64x8,
|
|
382
|
+
sum_b_high_f64x8);
|
|
383
|
+
__m512d sum_b_y_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_b_low_f64x8, idx_channel_y_i64x8,
|
|
384
|
+
sum_b_high_f64x8);
|
|
385
|
+
__m512d sum_b_z_f64x8 = _mm512_maskz_permutex2var_pd(channel_lanes_mask, sum_b_low_f64x8, idx_channel_z_i64x8,
|
|
386
|
+
sum_b_high_f64x8);
|
|
387
|
+
sum_b_out[0] = _mm512_reduce_add_pd(sum_b_x_f64x8);
|
|
388
|
+
sum_b_out[1] = _mm512_reduce_add_pd(sum_b_y_f64x8);
|
|
389
|
+
sum_b_out[2] = _mm512_reduce_add_pd(sum_b_z_f64x8);
|
|
390
|
+
|
|
391
|
+
// H cells: a-channel picks which demux index vector applies; prod-vector picks which b-channel the
|
|
392
|
+
// product pairs a with (diag -> same, rot1 -> +1, rot2 -> +2 mod 3).
|
|
393
|
+
__m512d product_diagonal_x_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
394
|
+
channel_lanes_mask, product_diagonal_low_f64x8, idx_channel_x_i64x8, product_diagonal_high_f64x8);
|
|
395
|
+
__m512d product_diagonal_y_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
396
|
+
channel_lanes_mask, product_diagonal_low_f64x8, idx_channel_y_i64x8, product_diagonal_high_f64x8);
|
|
397
|
+
__m512d product_diagonal_z_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
398
|
+
channel_lanes_mask, product_diagonal_low_f64x8, idx_channel_z_i64x8, product_diagonal_high_f64x8);
|
|
399
|
+
__m512d product_rotation_1_x_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
400
|
+
channel_lanes_mask, product_rotation_1_low_f64x8, idx_channel_x_i64x8, product_rotation_1_high_f64x8);
|
|
401
|
+
__m512d product_rotation_1_y_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
402
|
+
channel_lanes_mask, product_rotation_1_low_f64x8, idx_channel_y_i64x8, product_rotation_1_high_f64x8);
|
|
403
|
+
__m512d product_rotation_1_z_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
404
|
+
channel_lanes_mask, product_rotation_1_low_f64x8, idx_channel_z_i64x8, product_rotation_1_high_f64x8);
|
|
405
|
+
__m512d product_rotation_2_x_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
406
|
+
channel_lanes_mask, product_rotation_2_low_f64x8, idx_channel_x_i64x8, product_rotation_2_high_f64x8);
|
|
407
|
+
__m512d product_rotation_2_y_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
408
|
+
channel_lanes_mask, product_rotation_2_low_f64x8, idx_channel_y_i64x8, product_rotation_2_high_f64x8);
|
|
409
|
+
__m512d product_rotation_2_z_f64x8 = _mm512_maskz_permutex2var_pd( //
|
|
410
|
+
channel_lanes_mask, product_rotation_2_low_f64x8, idx_channel_z_i64x8, product_rotation_2_high_f64x8);
|
|
411
|
+
|
|
412
|
+
raw_covarianceariance_out[0] = _mm512_reduce_add_pd(product_diagonal_x_f64x8); // H[x,x]
|
|
413
|
+
raw_covarianceariance_out[1] = _mm512_reduce_add_pd(product_rotation_1_x_f64x8); // H[x,y]
|
|
414
|
+
raw_covarianceariance_out[2] = _mm512_reduce_add_pd(product_rotation_2_x_f64x8); // H[x,z]
|
|
415
|
+
raw_covarianceariance_out[3] = _mm512_reduce_add_pd(product_rotation_2_y_f64x8); // H[y,x]
|
|
416
|
+
raw_covarianceariance_out[4] = _mm512_reduce_add_pd(product_diagonal_y_f64x8); // H[y,y]
|
|
417
|
+
raw_covarianceariance_out[5] = _mm512_reduce_add_pd(product_rotation_1_y_f64x8); // H[y,z]
|
|
418
|
+
raw_covarianceariance_out[6] = _mm512_reduce_add_pd(product_rotation_1_z_f64x8); // H[z,x]
|
|
419
|
+
raw_covarianceariance_out[7] = _mm512_reduce_add_pd(product_rotation_2_z_f64x8); // H[z,y]
|
|
420
|
+
raw_covarianceariance_out[8] = _mm512_reduce_add_pd(product_diagonal_z_f64x8); // H[z,z]
|
|
421
|
+
|
|
422
|
+
// Norms collapse all three channels, no demux.
|
|
423
|
+
*norm_squared_a_out = _mm512_reduce_add_pd(_mm512_add_pd(norm_squared_a_low_f64x8, norm_squared_a_high_f64x8));
|
|
424
|
+
*norm_squared_b_out = _mm512_reduce_add_pd(_mm512_add_pd(norm_squared_b_low_f64x8, norm_squared_b_high_f64x8));
|
|
639
425
|
}
|
|
640
426
|
|
|
641
427
|
NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
@@ -647,277 +433,55 @@ NK_PUBLIC void nk_rmsd_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size
|
|
|
647
433
|
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
648
434
|
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
649
435
|
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
// Main loop with 2x unrolling (32 points per iteration)
|
|
656
|
-
for (; i + 32 <= n; i += 32) {
|
|
657
|
-
// Iteration 0: points i..i+15
|
|
658
|
-
nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
659
|
-
nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
660
|
-
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
661
|
-
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
662
|
-
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
663
|
-
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
664
|
-
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
665
|
-
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
666
|
-
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
667
|
-
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
668
|
-
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
669
|
-
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
670
|
-
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
671
|
-
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
672
|
-
|
|
673
|
-
__m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
674
|
-
__m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
675
|
-
__m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
676
|
-
__m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
|
|
677
|
-
__m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
|
|
678
|
-
__m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
|
|
679
|
-
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
|
|
680
|
-
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
|
|
681
|
-
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
|
|
682
|
-
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
|
|
683
|
-
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
|
|
684
|
-
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
|
|
685
|
-
|
|
686
|
-
// Iteration 1: points i+16..i+31
|
|
687
|
-
nk_deinterleave_f32x16_skylake_(a + (i + 16) * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
688
|
-
nk_deinterleave_f32x16_skylake_(b + (i + 16) * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
689
|
-
a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
690
|
-
a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
691
|
-
a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
692
|
-
a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
693
|
-
a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
694
|
-
a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
695
|
-
b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
696
|
-
b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
697
|
-
b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
698
|
-
b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
699
|
-
b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
700
|
-
b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
701
|
-
|
|
702
|
-
delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
703
|
-
delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
704
|
-
delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
705
|
-
delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
|
|
706
|
-
delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
|
|
707
|
-
delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
|
|
708
|
-
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
|
|
709
|
-
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
|
|
710
|
-
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
|
|
711
|
-
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
|
|
712
|
-
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
|
|
713
|
-
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
|
|
714
|
-
}
|
|
436
|
+
// 15-lane stride-3 chunks: 5 points (15 fp32) per iteration, lane 15 zeroed by mask.
|
|
437
|
+
// Identity rotation + zero centroid (per commit 1a83ab4f) make this a single (a-b)^2 sum.
|
|
438
|
+
__m512d sum_squared_low_f64x8 = _mm512_setzero_pd();
|
|
439
|
+
__m512d sum_squared_high_f64x8 = _mm512_setzero_pd();
|
|
440
|
+
nk_size_t index = 0;
|
|
715
441
|
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
__m512d
|
|
721
|
-
__m512d
|
|
722
|
-
__m512d
|
|
723
|
-
__m512d
|
|
724
|
-
__m512d
|
|
725
|
-
__m512d
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
729
|
-
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
730
|
-
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
731
|
-
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
732
|
-
|
|
733
|
-
__m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
734
|
-
__m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
735
|
-
__m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
736
|
-
__m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
|
|
737
|
-
__m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
|
|
738
|
-
__m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
|
|
739
|
-
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
|
|
740
|
-
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
|
|
741
|
-
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
|
|
742
|
-
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
|
|
743
|
-
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
|
|
744
|
-
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
|
|
442
|
+
for (; index + 5 <= n; index += 5) {
|
|
443
|
+
__m512 a_f32x16 = _mm512_maskz_loadu_ps(0x7FFF, a + index * 3);
|
|
444
|
+
__m512 b_f32x16 = _mm512_maskz_loadu_ps(0x7FFF, b + index * 3);
|
|
445
|
+
// Widen before subtracting: fp32 subtraction catastrophically cancels when a ~ b.
|
|
446
|
+
__m512d a_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_f32x16));
|
|
447
|
+
__m512d a_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_f32x16, 1));
|
|
448
|
+
__m512d b_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_f32x16));
|
|
449
|
+
__m512d b_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_f32x16, 1));
|
|
450
|
+
__m512d delta_low_f64x8 = _mm512_sub_pd(a_low_f64x8, b_low_f64x8);
|
|
451
|
+
__m512d delta_high_f64x8 = _mm512_sub_pd(a_high_f64x8, b_high_f64x8);
|
|
452
|
+
sum_squared_low_f64x8 = _mm512_fmadd_pd(delta_low_f64x8, delta_low_f64x8, sum_squared_low_f64x8);
|
|
453
|
+
sum_squared_high_f64x8 = _mm512_fmadd_pd(delta_high_f64x8, delta_high_f64x8, sum_squared_high_f64x8);
|
|
745
454
|
}
|
|
746
455
|
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
|
|
761
|
-
b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
|
|
762
|
-
|
|
763
|
-
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
764
|
-
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
765
|
-
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
766
|
-
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
767
|
-
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
768
|
-
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
769
|
-
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
770
|
-
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
771
|
-
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
772
|
-
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
773
|
-
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
774
|
-
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
775
|
-
|
|
776
|
-
__m512d delta_x_low_f64x8 = _mm512_sub_pd(a_x_low_f64x8, b_x_low_f64x8);
|
|
777
|
-
__m512d delta_x_high_f64x8 = _mm512_sub_pd(a_x_high_f64x8, b_x_high_f64x8);
|
|
778
|
-
__m512d delta_y_low_f64x8 = _mm512_sub_pd(a_y_low_f64x8, b_y_low_f64x8);
|
|
779
|
-
__m512d delta_y_high_f64x8 = _mm512_sub_pd(a_y_high_f64x8, b_y_high_f64x8);
|
|
780
|
-
__m512d delta_z_low_f64x8 = _mm512_sub_pd(a_z_low_f64x8, b_z_low_f64x8);
|
|
781
|
-
__m512d delta_z_high_f64x8 = _mm512_sub_pd(a_z_high_f64x8, b_z_high_f64x8);
|
|
782
|
-
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_low_f64x8, delta_x_low_f64x8, sum_squared_x_f64x8);
|
|
783
|
-
sum_squared_x_f64x8 = _mm512_fmadd_pd(delta_x_high_f64x8, delta_x_high_f64x8, sum_squared_x_f64x8);
|
|
784
|
-
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_low_f64x8, delta_y_low_f64x8, sum_squared_y_f64x8);
|
|
785
|
-
sum_squared_y_f64x8 = _mm512_fmadd_pd(delta_y_high_f64x8, delta_y_high_f64x8, sum_squared_y_f64x8);
|
|
786
|
-
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_low_f64x8, delta_z_low_f64x8, sum_squared_z_f64x8);
|
|
787
|
-
sum_squared_z_f64x8 = _mm512_fmadd_pd(delta_z_high_f64x8, delta_z_high_f64x8, sum_squared_z_f64x8);
|
|
456
|
+
if (index < n) {
|
|
457
|
+
nk_size_t tail_floats = (n - index) * 3;
|
|
458
|
+
__mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, tail_floats);
|
|
459
|
+
__m512 a_f32x16 = _mm512_maskz_loadu_ps(tail_mask, a + index * 3);
|
|
460
|
+
__m512 b_f32x16 = _mm512_maskz_loadu_ps(tail_mask, b + index * 3);
|
|
461
|
+
__m512d a_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_f32x16));
|
|
462
|
+
__m512d a_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_f32x16, 1));
|
|
463
|
+
__m512d b_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_f32x16));
|
|
464
|
+
__m512d b_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_f32x16, 1));
|
|
465
|
+
__m512d delta_low_f64x8 = _mm512_sub_pd(a_low_f64x8, b_low_f64x8);
|
|
466
|
+
__m512d delta_high_f64x8 = _mm512_sub_pd(a_high_f64x8, b_high_f64x8);
|
|
467
|
+
sum_squared_low_f64x8 = _mm512_fmadd_pd(delta_low_f64x8, delta_low_f64x8, sum_squared_low_f64x8);
|
|
468
|
+
sum_squared_high_f64x8 = _mm512_fmadd_pd(delta_high_f64x8, delta_high_f64x8, sum_squared_high_f64x8);
|
|
788
469
|
}
|
|
789
470
|
|
|
790
|
-
nk_f64_t
|
|
791
|
-
|
|
792
|
-
nk_f64_t total_sq_z = _mm512_reduce_add_pd(sum_squared_z_f64x8);
|
|
793
|
-
*result = nk_f64_sqrt_haswell((total_sq_x + total_sq_y + total_sq_z) / (nk_f64_t)n);
|
|
471
|
+
nk_f64_t sum_squared = _mm512_reduce_add_pd(_mm512_add_pd(sum_squared_low_f64x8, sum_squared_high_f64x8));
|
|
472
|
+
*result = nk_f64_sqrt_haswell(sum_squared / (nk_f64_t)n);
|
|
794
473
|
}
|
|
795
474
|
|
|
796
475
|
NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
797
476
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
|
|
798
|
-
//
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
__m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
|
|
802
|
-
__m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
|
|
803
|
-
__m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
|
|
804
|
-
__m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
|
|
805
|
-
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
806
|
-
nk_size_t i = 0;
|
|
807
|
-
|
|
808
|
-
for (; i + 16 <= n; i += 16) {
|
|
809
|
-
nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
810
|
-
nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
811
|
-
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
812
|
-
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
813
|
-
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
814
|
-
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
815
|
-
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
816
|
-
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
817
|
-
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
818
|
-
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
819
|
-
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
820
|
-
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
821
|
-
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
822
|
-
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
823
|
-
|
|
824
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
825
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
826
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
827
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
828
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
829
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
830
|
-
|
|
831
|
-
cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
|
|
832
|
-
_mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
|
|
833
|
-
cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
|
|
834
|
-
_mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
|
|
835
|
-
cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
|
|
836
|
-
_mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
|
|
837
|
-
cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
|
|
838
|
-
_mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
|
|
839
|
-
cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
|
|
840
|
-
_mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
|
|
841
|
-
cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
|
|
842
|
-
_mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
|
|
843
|
-
cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
|
|
844
|
-
_mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
|
|
845
|
-
cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
|
|
846
|
-
_mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
|
|
847
|
-
cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
|
|
848
|
-
_mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
|
|
849
|
-
}
|
|
850
|
-
|
|
851
|
-
// Tail: use masked gather for remaining < 16 points
|
|
852
|
-
if (i < n) {
|
|
853
|
-
nk_size_t tail = n - i;
|
|
854
|
-
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
|
|
855
|
-
__m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
|
|
856
|
-
__m512 zeros_f32x16 = _mm512_setzero_ps();
|
|
857
|
-
nk_f32_t const *a_tail = a + i * 3;
|
|
858
|
-
nk_f32_t const *b_tail = b + i * 3;
|
|
859
|
-
|
|
860
|
-
a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
|
|
861
|
-
a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
|
|
862
|
-
a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
|
|
863
|
-
b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
|
|
864
|
-
b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
|
|
865
|
-
b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
|
|
866
|
-
|
|
867
|
-
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
868
|
-
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
869
|
-
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
870
|
-
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
871
|
-
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
872
|
-
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
873
|
-
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
874
|
-
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
875
|
-
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
876
|
-
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
877
|
-
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
878
|
-
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
879
|
-
|
|
880
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
881
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
882
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
883
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
884
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
885
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
886
|
-
|
|
887
|
-
cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
|
|
888
|
-
_mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
|
|
889
|
-
cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
|
|
890
|
-
_mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
|
|
891
|
-
cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
|
|
892
|
-
_mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
|
|
893
|
-
cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
|
|
894
|
-
_mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
|
|
895
|
-
cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
|
|
896
|
-
_mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
|
|
897
|
-
cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
|
|
898
|
-
_mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
|
|
899
|
-
cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
|
|
900
|
-
_mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
|
|
901
|
-
cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
|
|
902
|
-
_mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
|
|
903
|
-
cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
|
|
904
|
-
_mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
|
|
905
|
-
}
|
|
477
|
+
// Single pass over (a, b) via streaming-stats helper — no deinterleave, no second SSD pass.
|
|
478
|
+
nk_f64_t sum_a[3], sum_b[3], raw_covariance[9], norm_squared_a, norm_squared_b;
|
|
479
|
+
nk_mesh_streaming_stats_f32_skylake_(a, b, n, sum_a, sum_b, raw_covariance, &norm_squared_a, &norm_squared_b);
|
|
906
480
|
|
|
907
|
-
nk_f64_t
|
|
908
|
-
|
|
909
|
-
nk_f64_t
|
|
910
|
-
|
|
911
|
-
nk_f64_t covariance_x_x = _mm512_reduce_add_pd(cov_xx_f64x8), covariance_x_y = _mm512_reduce_add_pd(cov_xy_f64x8),
|
|
912
|
-
covariance_x_z = _mm512_reduce_add_pd(cov_xz_f64x8);
|
|
913
|
-
nk_f64_t covariance_y_x = _mm512_reduce_add_pd(cov_yx_f64x8), covariance_y_y = _mm512_reduce_add_pd(cov_yy_f64x8),
|
|
914
|
-
covariance_y_z = _mm512_reduce_add_pd(cov_yz_f64x8);
|
|
915
|
-
nk_f64_t covariance_z_x = _mm512_reduce_add_pd(cov_zx_f64x8), covariance_z_y = _mm512_reduce_add_pd(cov_zy_f64x8),
|
|
916
|
-
covariance_z_z = _mm512_reduce_add_pd(cov_zz_f64x8);
|
|
917
|
-
|
|
918
|
-
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
919
|
-
nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
920
|
-
nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
481
|
+
nk_f64_t n_f64 = (nk_f64_t)n;
|
|
482
|
+
nk_f64_t inv_n = 1.0 / n_f64;
|
|
483
|
+
nk_f64_t centroid_a_x = sum_a[0] * inv_n, centroid_a_y = sum_a[1] * inv_n, centroid_a_z = sum_a[2] * inv_n;
|
|
484
|
+
nk_f64_t centroid_b_x = sum_b[0] * inv_n, centroid_b_y = sum_b[1] * inv_n, centroid_b_z = sum_b[2] * inv_n;
|
|
921
485
|
if (a_centroid)
|
|
922
486
|
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
923
487
|
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
@@ -926,32 +490,69 @@ NK_PUBLIC void nk_kabsch_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_si
|
|
|
926
490
|
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
927
491
|
if (scale) *scale = 1.0f;
|
|
928
492
|
|
|
929
|
-
|
|
493
|
+
// Parallel-axis correction: H_centered[j,k] = Sum(a_j * b_k) - n * centroid_a[j] * centroid_b[k].
|
|
930
494
|
nk_f64_t cross_covariance[9];
|
|
931
|
-
cross_covariance[0] =
|
|
932
|
-
cross_covariance[1] =
|
|
933
|
-
cross_covariance[2] =
|
|
934
|
-
cross_covariance[3] =
|
|
935
|
-
cross_covariance[4] =
|
|
936
|
-
cross_covariance[5] =
|
|
937
|
-
cross_covariance[6] =
|
|
938
|
-
cross_covariance[7] =
|
|
939
|
-
cross_covariance[8] =
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
495
|
+
cross_covariance[0] = raw_covariance[0] - n_f64 * centroid_a_x * centroid_b_x;
|
|
496
|
+
cross_covariance[1] = raw_covariance[1] - n_f64 * centroid_a_x * centroid_b_y;
|
|
497
|
+
cross_covariance[2] = raw_covariance[2] - n_f64 * centroid_a_x * centroid_b_z;
|
|
498
|
+
cross_covariance[3] = raw_covariance[3] - n_f64 * centroid_a_y * centroid_b_x;
|
|
499
|
+
cross_covariance[4] = raw_covariance[4] - n_f64 * centroid_a_y * centroid_b_y;
|
|
500
|
+
cross_covariance[5] = raw_covariance[5] - n_f64 * centroid_a_y * centroid_b_z;
|
|
501
|
+
cross_covariance[6] = raw_covariance[6] - n_f64 * centroid_a_z * centroid_b_x;
|
|
502
|
+
cross_covariance[7] = raw_covariance[7] - n_f64 * centroid_a_z * centroid_b_y;
|
|
503
|
+
cross_covariance[8] = raw_covariance[8] - n_f64 * centroid_a_z * centroid_b_z;
|
|
504
|
+
|
|
505
|
+
// Identity-dominant short-circuit: skip SVD + rotation_from_svd when H is near-diagonal
|
|
506
|
+
// positive-definite. `optimal_rotation` is set to identity and trace(R * H) collapses to H[0]+H[4]+H[8].
|
|
507
|
+
// Saves ~500 cycles on aligned/pre-registered inputs; negligible cost when inputs are random
|
|
508
|
+
// (branch is well-predicted in practice).
|
|
509
|
+
nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
510
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
511
|
+
cross_covariance[8] * cross_covariance[8];
|
|
512
|
+
nk_f64_t covariance_offdiagonal_norm_squared =
|
|
513
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
514
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
515
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
516
|
+
nk_f64_t optimal_rotation[9];
|
|
517
|
+
nk_f64_t trace_rotation_covariance;
|
|
518
|
+
if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
|
|
519
|
+
cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
|
|
520
|
+
optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
|
|
521
|
+
optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
|
|
522
|
+
optimal_rotation[8] = 1;
|
|
523
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
524
|
+
}
|
|
525
|
+
else {
|
|
526
|
+
nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
527
|
+
nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
528
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
529
|
+
if (nk_det3x3_f64_(optimal_rotation) < 0) {
|
|
530
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
531
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
532
|
+
}
|
|
533
|
+
trace_rotation_covariance =
|
|
534
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
535
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
536
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
537
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
538
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
948
539
|
}
|
|
949
540
|
if (rotation)
|
|
950
|
-
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
541
|
+
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
|
|
542
|
+
|
|
543
|
+
// Folded SSD via trace identity: SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
|
|
544
|
+
nk_f64_t centered_norm_squared_a = norm_squared_a -
|
|
545
|
+
n_f64 * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
546
|
+
centroid_a_z * centroid_a_z);
|
|
547
|
+
nk_f64_t centered_norm_squared_b = norm_squared_b -
|
|
548
|
+
n_f64 * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
549
|
+
centroid_b_z * centroid_b_z);
|
|
550
|
+
if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
|
|
551
|
+
if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
|
|
552
|
+
|
|
553
|
+
nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
|
|
554
|
+
if (sum_squared < 0.0) sum_squared = 0.0;
|
|
555
|
+
*result = nk_f64_sqrt_haswell(sum_squared / n_f64);
|
|
955
556
|
}
|
|
956
557
|
|
|
957
558
|
NK_PUBLIC void nk_rmsd_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
@@ -1064,14 +665,15 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
1064
665
|
__m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
|
|
1065
666
|
|
|
1066
667
|
// Accumulators for covariance matrix (sum of outer products)
|
|
1067
|
-
__m512d
|
|
1068
|
-
__m512d
|
|
1069
|
-
__m512d
|
|
668
|
+
__m512d covariance_xx_f64x8 = zeros_f64x8, covariance_xy_f64x8 = zeros_f64x8, covariance_xz_f64x8 = zeros_f64x8;
|
|
669
|
+
__m512d covariance_yx_f64x8 = zeros_f64x8, covariance_yy_f64x8 = zeros_f64x8, covariance_yz_f64x8 = zeros_f64x8;
|
|
670
|
+
__m512d covariance_zx_f64x8 = zeros_f64x8, covariance_zy_f64x8 = zeros_f64x8, covariance_zz_f64x8 = zeros_f64x8;
|
|
671
|
+
__m512d norm_squared_a_f64x8 = zeros_f64x8, norm_squared_b_f64x8 = zeros_f64x8;
|
|
1070
672
|
|
|
1071
673
|
nk_size_t i = 0;
|
|
1072
674
|
__m512d a_x_f64x8, a_y_f64x8, a_z_f64x8, b_x_f64x8, b_y_f64x8, b_z_f64x8;
|
|
1073
675
|
|
|
1074
|
-
// Fused single-pass: accumulate sums
|
|
676
|
+
// Fused single-pass: accumulate sums, outer products, and norms^2 together
|
|
1075
677
|
for (; i + 8 <= n; i += 8) {
|
|
1076
678
|
nk_deinterleave_f64x8_skylake_(a + i * 3, &a_x_f64x8, &a_y_f64x8, &a_z_f64x8);
|
|
1077
679
|
nk_deinterleave_f64x8_skylake_(b + i * 3, &b_x_f64x8, &b_y_f64x8, &b_z_f64x8);
|
|
@@ -1085,15 +687,21 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
1085
687
|
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);
|
|
1086
688
|
|
|
1087
689
|
// Accumulate outer products (raw, not centered)
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
690
|
+
covariance_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, covariance_xx_f64x8),
|
|
691
|
+
covariance_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, covariance_xy_f64x8),
|
|
692
|
+
covariance_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, covariance_xz_f64x8);
|
|
693
|
+
covariance_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, covariance_yx_f64x8),
|
|
694
|
+
covariance_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, covariance_yy_f64x8),
|
|
695
|
+
covariance_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, covariance_yz_f64x8);
|
|
696
|
+
covariance_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, covariance_zx_f64x8),
|
|
697
|
+
covariance_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, covariance_zy_f64x8),
|
|
698
|
+
covariance_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, covariance_zz_f64x8);
|
|
699
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, norm_squared_a_f64x8);
|
|
700
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, norm_squared_a_f64x8);
|
|
701
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, norm_squared_a_f64x8);
|
|
702
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_x_f64x8, b_x_f64x8, norm_squared_b_f64x8);
|
|
703
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_y_f64x8, b_y_f64x8, norm_squared_b_f64x8);
|
|
704
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_z_f64x8, b_z_f64x8, norm_squared_b_f64x8);
|
|
1097
705
|
}
|
|
1098
706
|
|
|
1099
707
|
// Tail: masked gather for remaining points
|
|
@@ -1117,15 +725,21 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
1117
725
|
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8),
|
|
1118
726
|
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);
|
|
1119
727
|
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
|
|
1125
|
-
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
728
|
+
covariance_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, covariance_xx_f64x8),
|
|
729
|
+
covariance_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, covariance_xy_f64x8),
|
|
730
|
+
covariance_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, covariance_xz_f64x8);
|
|
731
|
+
covariance_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, covariance_yx_f64x8),
|
|
732
|
+
covariance_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, covariance_yy_f64x8),
|
|
733
|
+
covariance_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, covariance_yz_f64x8);
|
|
734
|
+
covariance_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, covariance_zx_f64x8),
|
|
735
|
+
covariance_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, covariance_zy_f64x8),
|
|
736
|
+
covariance_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, covariance_zz_f64x8);
|
|
737
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, norm_squared_a_f64x8);
|
|
738
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, norm_squared_a_f64x8);
|
|
739
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, norm_squared_a_f64x8);
|
|
740
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_x_f64x8, b_x_f64x8, norm_squared_b_f64x8);
|
|
741
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_y_f64x8, b_y_f64x8, norm_squared_b_f64x8);
|
|
742
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_z_f64x8, b_z_f64x8, norm_squared_b_f64x8);
|
|
1129
743
|
i = n;
|
|
1130
744
|
}
|
|
1131
745
|
|
|
@@ -1137,15 +751,19 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
1137
751
|
nk_f64_t sum_b_x = nk_reduce_stable_f64x8_skylake_(sum_b_x_f64x8), sum_b_x_compensation = 0.0;
|
|
1138
752
|
nk_f64_t sum_b_y = nk_reduce_stable_f64x8_skylake_(sum_b_y_f64x8), sum_b_y_compensation = 0.0;
|
|
1139
753
|
nk_f64_t sum_b_z = nk_reduce_stable_f64x8_skylake_(sum_b_z_f64x8), sum_b_z_compensation = 0.0;
|
|
1140
|
-
nk_f64_t covariance_x_x = nk_reduce_stable_f64x8_skylake_(
|
|
1141
|
-
nk_f64_t covariance_x_y = nk_reduce_stable_f64x8_skylake_(
|
|
1142
|
-
nk_f64_t covariance_x_z = nk_reduce_stable_f64x8_skylake_(
|
|
1143
|
-
nk_f64_t covariance_y_x = nk_reduce_stable_f64x8_skylake_(
|
|
1144
|
-
nk_f64_t covariance_y_y = nk_reduce_stable_f64x8_skylake_(
|
|
1145
|
-
nk_f64_t covariance_y_z = nk_reduce_stable_f64x8_skylake_(
|
|
1146
|
-
nk_f64_t covariance_z_x = nk_reduce_stable_f64x8_skylake_(
|
|
1147
|
-
nk_f64_t covariance_z_y = nk_reduce_stable_f64x8_skylake_(
|
|
1148
|
-
nk_f64_t covariance_z_z = nk_reduce_stable_f64x8_skylake_(
|
|
754
|
+
nk_f64_t covariance_x_x = nk_reduce_stable_f64x8_skylake_(covariance_xx_f64x8), covariance_x_x_compensation = 0.0;
|
|
755
|
+
nk_f64_t covariance_x_y = nk_reduce_stable_f64x8_skylake_(covariance_xy_f64x8), covariance_x_y_compensation = 0.0;
|
|
756
|
+
nk_f64_t covariance_x_z = nk_reduce_stable_f64x8_skylake_(covariance_xz_f64x8), covariance_x_z_compensation = 0.0;
|
|
757
|
+
nk_f64_t covariance_y_x = nk_reduce_stable_f64x8_skylake_(covariance_yx_f64x8), covariance_y_x_compensation = 0.0;
|
|
758
|
+
nk_f64_t covariance_y_y = nk_reduce_stable_f64x8_skylake_(covariance_yy_f64x8), covariance_y_y_compensation = 0.0;
|
|
759
|
+
nk_f64_t covariance_y_z = nk_reduce_stable_f64x8_skylake_(covariance_yz_f64x8), covariance_y_z_compensation = 0.0;
|
|
760
|
+
nk_f64_t covariance_z_x = nk_reduce_stable_f64x8_skylake_(covariance_zx_f64x8), covariance_z_x_compensation = 0.0;
|
|
761
|
+
nk_f64_t covariance_z_y = nk_reduce_stable_f64x8_skylake_(covariance_zy_f64x8), covariance_z_y_compensation = 0.0;
|
|
762
|
+
nk_f64_t covariance_z_z = nk_reduce_stable_f64x8_skylake_(covariance_zz_f64x8), covariance_z_z_compensation = 0.0;
|
|
763
|
+
nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x8_skylake_(norm_squared_a_f64x8),
|
|
764
|
+
norm_squared_a_compensation = 0.0;
|
|
765
|
+
nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x8_skylake_(norm_squared_b_f64x8),
|
|
766
|
+
norm_squared_b_compensation = 0.0;
|
|
1149
767
|
|
|
1150
768
|
for (; i < n; ++i) {
|
|
1151
769
|
nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
|
|
@@ -1165,6 +783,12 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
1165
783
|
nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
|
|
1166
784
|
nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
|
|
1167
785
|
nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
|
|
786
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
|
|
787
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
|
|
788
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
|
|
789
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
|
|
790
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
|
|
791
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
|
|
1168
792
|
}
|
|
1169
793
|
|
|
1170
794
|
sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
|
|
@@ -1175,6 +799,8 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
1175
799
|
covariance_y_z += covariance_y_z_compensation;
|
|
1176
800
|
covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
|
|
1177
801
|
covariance_z_z += covariance_z_z_compensation;
|
|
802
|
+
norm_squared_a_sum += norm_squared_a_compensation;
|
|
803
|
+
norm_squared_b_sum += norm_squared_b_compensation;
|
|
1178
804
|
|
|
1179
805
|
nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
1180
806
|
nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
@@ -1194,177 +820,70 @@ NK_PUBLIC void nk_kabsch_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_si
|
|
|
1194
820
|
cross_covariance[7] = covariance_z_y - sum_a_z * sum_b_y * inv_n;
|
|
1195
821
|
cross_covariance[8] = covariance_z_z - sum_a_z * sum_b_z * inv_n;
|
|
1196
822
|
|
|
1197
|
-
//
|
|
1198
|
-
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1202
|
-
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1206
|
-
|
|
1207
|
-
|
|
823
|
+
// Identity-dominant short-circuit: if H_centered is near-diagonal positive-definite,
|
|
824
|
+
// R = I and trace(R * H) = H[0] + H[4] + H[8]. Saves ~500 cycles on aligned inputs.
|
|
825
|
+
nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
826
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
827
|
+
cross_covariance[8] * cross_covariance[8];
|
|
828
|
+
nk_f64_t covariance_offdiagonal_norm_squared =
|
|
829
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
830
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
831
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
832
|
+
nk_f64_t optimal_rotation[9];
|
|
833
|
+
nk_f64_t trace_rotation_covariance;
|
|
834
|
+
if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
|
|
835
|
+
cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
|
|
836
|
+
optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
|
|
837
|
+
optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
|
|
838
|
+
optimal_rotation[8] = 1;
|
|
839
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
840
|
+
}
|
|
841
|
+
else {
|
|
842
|
+
nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
843
|
+
nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
844
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
845
|
+
if (nk_det3x3_f64_(optimal_rotation) < 0) {
|
|
846
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
847
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
848
|
+
}
|
|
849
|
+
trace_rotation_covariance =
|
|
850
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
851
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
852
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
853
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
854
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1208
855
|
}
|
|
1209
856
|
|
|
1210
857
|
// Output rotation matrix and scale=1.0.
|
|
1211
858
|
if (rotation)
|
|
1212
|
-
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)
|
|
859
|
+
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)optimal_rotation[j];
|
|
1213
860
|
if (scale) *scale = 1.0;
|
|
1214
861
|
|
|
1215
|
-
//
|
|
1216
|
-
|
|
1217
|
-
|
|
862
|
+
// Folded SSD via trace identity - no second pass over the buffers:
|
|
863
|
+
// SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered).
|
|
864
|
+
nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
|
|
865
|
+
(nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
866
|
+
centroid_a_z * centroid_a_z);
|
|
867
|
+
nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
|
|
868
|
+
(nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
869
|
+
centroid_b_z * centroid_b_z);
|
|
870
|
+
if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
|
|
871
|
+
if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
|
|
872
|
+
nk_f64_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0 * trace_rotation_covariance;
|
|
873
|
+
if (sum_squared < 0.0) sum_squared = 0.0;
|
|
1218
874
|
*result = nk_f64_sqrt_haswell(sum_squared * inv_n);
|
|
1219
875
|
}
|
|
1220
876
|
|
|
1221
877
|
NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1222
878
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
|
|
1223
|
-
//
|
|
1224
|
-
|
|
1225
|
-
|
|
1226
|
-
__m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
|
|
1227
|
-
__m512d cov_xx_f64x8 = zeros_f64x8, cov_xy_f64x8 = zeros_f64x8, cov_xz_f64x8 = zeros_f64x8;
|
|
1228
|
-
__m512d cov_yx_f64x8 = zeros_f64x8, cov_yy_f64x8 = zeros_f64x8, cov_yz_f64x8 = zeros_f64x8;
|
|
1229
|
-
__m512d cov_zx_f64x8 = zeros_f64x8, cov_zy_f64x8 = zeros_f64x8, cov_zz_f64x8 = zeros_f64x8;
|
|
1230
|
-
__m512d variance_a_f64x8 = zeros_f64x8;
|
|
1231
|
-
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
1232
|
-
nk_size_t i = 0;
|
|
1233
|
-
|
|
1234
|
-
for (; i + 16 <= n; i += 16) {
|
|
1235
|
-
nk_deinterleave_f32x16_skylake_(a + i * 3, &a_x_f32x16, &a_y_f32x16, &a_z_f32x16);
|
|
1236
|
-
nk_deinterleave_f32x16_skylake_(b + i * 3, &b_x_f32x16, &b_y_f32x16, &b_z_f32x16);
|
|
1237
|
-
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
1238
|
-
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
1239
|
-
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
1240
|
-
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
1241
|
-
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
1242
|
-
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
1243
|
-
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
1244
|
-
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
1245
|
-
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
1246
|
-
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
1247
|
-
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
1248
|
-
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
1249
|
-
|
|
1250
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
1251
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
1252
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
1253
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
1254
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
1255
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
1256
|
-
|
|
1257
|
-
cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
|
|
1258
|
-
_mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
|
|
1259
|
-
cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
|
|
1260
|
-
_mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
|
|
1261
|
-
cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
|
|
1262
|
-
_mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
|
|
1263
|
-
cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
|
|
1264
|
-
_mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
|
|
1265
|
-
cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
|
|
1266
|
-
_mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
|
|
1267
|
-
cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
|
|
1268
|
-
_mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
|
|
1269
|
-
cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
|
|
1270
|
-
_mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
|
|
1271
|
-
cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
|
|
1272
|
-
_mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
|
|
1273
|
-
cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
|
|
1274
|
-
_mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
|
|
1275
|
-
|
|
1276
|
-
variance_a_f64x8 = _mm512_add_pd(
|
|
1277
|
-
variance_a_f64x8,
|
|
1278
|
-
_mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, a_x_low_f64x8), _mm512_mul_pd(a_x_high_f64x8, a_x_high_f64x8)));
|
|
1279
|
-
variance_a_f64x8 = _mm512_add_pd(
|
|
1280
|
-
variance_a_f64x8,
|
|
1281
|
-
_mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, a_y_low_f64x8), _mm512_mul_pd(a_y_high_f64x8, a_y_high_f64x8)));
|
|
1282
|
-
variance_a_f64x8 = _mm512_add_pd(
|
|
1283
|
-
variance_a_f64x8,
|
|
1284
|
-
_mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, a_z_low_f64x8), _mm512_mul_pd(a_z_high_f64x8, a_z_high_f64x8)));
|
|
1285
|
-
}
|
|
1286
|
-
|
|
1287
|
-
// Tail: use masked gather for remaining < 16 points
|
|
1288
|
-
if (i < n) {
|
|
1289
|
-
nk_size_t tail = n - i;
|
|
1290
|
-
__mmask16 mask = (__mmask16)_bzhi_u32(0xFFFF, tail);
|
|
1291
|
-
__m512i const gather_idx_i32x16 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 33, 36, 39, 42, 45);
|
|
1292
|
-
__m512 zeros_f32x16 = _mm512_setzero_ps();
|
|
1293
|
-
nk_f32_t const *a_tail = a + i * 3;
|
|
1294
|
-
nk_f32_t const *b_tail = b + i * 3;
|
|
1295
|
-
|
|
1296
|
-
a_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 0, 4);
|
|
1297
|
-
a_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 1, 4);
|
|
1298
|
-
a_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, a_tail + 2, 4);
|
|
1299
|
-
b_x_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 0, 4);
|
|
1300
|
-
b_y_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 1, 4);
|
|
1301
|
-
b_z_f32x16 = _mm512_mask_i32gather_ps(zeros_f32x16, mask, gather_idx_i32x16, b_tail + 2, 4);
|
|
1302
|
-
|
|
1303
|
-
__m512d a_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_x_f32x16));
|
|
1304
|
-
__m512d a_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_x_f32x16, 1));
|
|
1305
|
-
__m512d a_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_y_f32x16));
|
|
1306
|
-
__m512d a_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_y_f32x16, 1));
|
|
1307
|
-
__m512d a_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(a_z_f32x16));
|
|
1308
|
-
__m512d a_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(a_z_f32x16, 1));
|
|
1309
|
-
__m512d b_x_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_x_f32x16));
|
|
1310
|
-
__m512d b_x_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_x_f32x16, 1));
|
|
1311
|
-
__m512d b_y_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_y_f32x16));
|
|
1312
|
-
__m512d b_y_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_y_f32x16, 1));
|
|
1313
|
-
__m512d b_z_low_f64x8 = _mm512_cvtps_pd(_mm512_castps512_ps256(b_z_f32x16));
|
|
1314
|
-
__m512d b_z_high_f64x8 = _mm512_cvtps_pd(_mm512_extractf32x8_ps(b_z_f32x16, 1));
|
|
1315
|
-
|
|
1316
|
-
sum_a_x_f64x8 = _mm512_add_pd(sum_a_x_f64x8, _mm512_add_pd(a_x_low_f64x8, a_x_high_f64x8));
|
|
1317
|
-
sum_a_y_f64x8 = _mm512_add_pd(sum_a_y_f64x8, _mm512_add_pd(a_y_low_f64x8, a_y_high_f64x8));
|
|
1318
|
-
sum_a_z_f64x8 = _mm512_add_pd(sum_a_z_f64x8, _mm512_add_pd(a_z_low_f64x8, a_z_high_f64x8));
|
|
1319
|
-
sum_b_x_f64x8 = _mm512_add_pd(sum_b_x_f64x8, _mm512_add_pd(b_x_low_f64x8, b_x_high_f64x8));
|
|
1320
|
-
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, _mm512_add_pd(b_y_low_f64x8, b_y_high_f64x8));
|
|
1321
|
-
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, _mm512_add_pd(b_z_low_f64x8, b_z_high_f64x8));
|
|
1322
|
-
|
|
1323
|
-
cov_xx_f64x8 = _mm512_add_pd(cov_xx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_x_low_f64x8),
|
|
1324
|
-
_mm512_mul_pd(a_x_high_f64x8, b_x_high_f64x8)));
|
|
1325
|
-
cov_xy_f64x8 = _mm512_add_pd(cov_xy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_y_low_f64x8),
|
|
1326
|
-
_mm512_mul_pd(a_x_high_f64x8, b_y_high_f64x8)));
|
|
1327
|
-
cov_xz_f64x8 = _mm512_add_pd(cov_xz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, b_z_low_f64x8),
|
|
1328
|
-
_mm512_mul_pd(a_x_high_f64x8, b_z_high_f64x8)));
|
|
1329
|
-
cov_yx_f64x8 = _mm512_add_pd(cov_yx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_x_low_f64x8),
|
|
1330
|
-
_mm512_mul_pd(a_y_high_f64x8, b_x_high_f64x8)));
|
|
1331
|
-
cov_yy_f64x8 = _mm512_add_pd(cov_yy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_y_low_f64x8),
|
|
1332
|
-
_mm512_mul_pd(a_y_high_f64x8, b_y_high_f64x8)));
|
|
1333
|
-
cov_yz_f64x8 = _mm512_add_pd(cov_yz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, b_z_low_f64x8),
|
|
1334
|
-
_mm512_mul_pd(a_y_high_f64x8, b_z_high_f64x8)));
|
|
1335
|
-
cov_zx_f64x8 = _mm512_add_pd(cov_zx_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_x_low_f64x8),
|
|
1336
|
-
_mm512_mul_pd(a_z_high_f64x8, b_x_high_f64x8)));
|
|
1337
|
-
cov_zy_f64x8 = _mm512_add_pd(cov_zy_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_y_low_f64x8),
|
|
1338
|
-
_mm512_mul_pd(a_z_high_f64x8, b_y_high_f64x8)));
|
|
1339
|
-
cov_zz_f64x8 = _mm512_add_pd(cov_zz_f64x8, _mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, b_z_low_f64x8),
|
|
1340
|
-
_mm512_mul_pd(a_z_high_f64x8, b_z_high_f64x8)));
|
|
1341
|
-
|
|
1342
|
-
variance_a_f64x8 = _mm512_add_pd(
|
|
1343
|
-
variance_a_f64x8,
|
|
1344
|
-
_mm512_add_pd(_mm512_mul_pd(a_x_low_f64x8, a_x_low_f64x8), _mm512_mul_pd(a_x_high_f64x8, a_x_high_f64x8)));
|
|
1345
|
-
variance_a_f64x8 = _mm512_add_pd(
|
|
1346
|
-
variance_a_f64x8,
|
|
1347
|
-
_mm512_add_pd(_mm512_mul_pd(a_y_low_f64x8, a_y_low_f64x8), _mm512_mul_pd(a_y_high_f64x8, a_y_high_f64x8)));
|
|
1348
|
-
variance_a_f64x8 = _mm512_add_pd(
|
|
1349
|
-
variance_a_f64x8,
|
|
1350
|
-
_mm512_add_pd(_mm512_mul_pd(a_z_low_f64x8, a_z_low_f64x8), _mm512_mul_pd(a_z_high_f64x8, a_z_high_f64x8)));
|
|
1351
|
-
}
|
|
879
|
+
// Single pass over (a, b) via streaming-stats helper — no deinterleave, no second SSD pass.
|
|
880
|
+
nk_f64_t sum_a[3], sum_b[3], raw_covariance[9], norm_squared_a, norm_squared_b;
|
|
881
|
+
nk_mesh_streaming_stats_f32_skylake_(a, b, n, sum_a, sum_b, raw_covariance, &norm_squared_a, &norm_squared_b);
|
|
1352
882
|
|
|
1353
|
-
nk_f64_t
|
|
1354
|
-
|
|
1355
|
-
nk_f64_t
|
|
1356
|
-
|
|
1357
|
-
nk_f64_t covariance_x_x = _mm512_reduce_add_pd(cov_xx_f64x8), covariance_x_y = _mm512_reduce_add_pd(cov_xy_f64x8),
|
|
1358
|
-
covariance_x_z = _mm512_reduce_add_pd(cov_xz_f64x8);
|
|
1359
|
-
nk_f64_t covariance_y_x = _mm512_reduce_add_pd(cov_yx_f64x8), covariance_y_y = _mm512_reduce_add_pd(cov_yy_f64x8),
|
|
1360
|
-
covariance_y_z = _mm512_reduce_add_pd(cov_yz_f64x8);
|
|
1361
|
-
nk_f64_t covariance_z_x = _mm512_reduce_add_pd(cov_zx_f64x8), covariance_z_y = _mm512_reduce_add_pd(cov_zy_f64x8),
|
|
1362
|
-
covariance_z_z = _mm512_reduce_add_pd(cov_zz_f64x8);
|
|
1363
|
-
nk_f64_t variance_a_sum = _mm512_reduce_add_pd(variance_a_f64x8);
|
|
1364
|
-
|
|
1365
|
-
nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
|
|
1366
|
-
nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
1367
|
-
nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
883
|
+
nk_f64_t n_f64 = (nk_f64_t)n;
|
|
884
|
+
nk_f64_t inv_n = 1.0 / n_f64;
|
|
885
|
+
nk_f64_t centroid_a_x = sum_a[0] * inv_n, centroid_a_y = sum_a[1] * inv_n, centroid_a_z = sum_a[2] * inv_n;
|
|
886
|
+
nk_f64_t centroid_b_x = sum_b[0] * inv_n, centroid_b_y = sum_b[1] * inv_n, centroid_b_z = sum_b[2] * inv_n;
|
|
1368
887
|
if (a_centroid)
|
|
1369
888
|
a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
|
|
1370
889
|
a_centroid[2] = (nk_f32_t)centroid_a_z;
|
|
@@ -1372,49 +891,81 @@ NK_PUBLIC void nk_umeyama_f32_skylake(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
1372
891
|
b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
|
|
1373
892
|
b_centroid[2] = (nk_f32_t)centroid_b_z;
|
|
1374
893
|
|
|
1375
|
-
//
|
|
1376
|
-
nk_f64_t
|
|
1377
|
-
|
|
894
|
+
// Centered norms and centered covariance via parallel-axis identity.
|
|
895
|
+
nk_f64_t centered_norm_squared_a = norm_squared_a -
|
|
896
|
+
n_f64 * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
897
|
+
centroid_a_z * centroid_a_z);
|
|
898
|
+
nk_f64_t centered_norm_squared_b = norm_squared_b -
|
|
899
|
+
n_f64 * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
900
|
+
centroid_b_z * centroid_b_z);
|
|
901
|
+
if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
|
|
902
|
+
if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
|
|
903
|
+
nk_f64_t variance_a = centered_norm_squared_a * inv_n;
|
|
1378
904
|
|
|
1379
|
-
// Compute centered covariance matrix: Hᵢⱼ = Σ(aᵢ×bⱼ) - Σaᵢ × Σbⱼ / n
|
|
1380
|
-
nk_f64_t n_f64 = (nk_f64_t)n;
|
|
1381
905
|
nk_f64_t cross_covariance[9];
|
|
1382
|
-
cross_covariance[0] =
|
|
1383
|
-
cross_covariance[1] =
|
|
1384
|
-
cross_covariance[2] =
|
|
1385
|
-
cross_covariance[3] =
|
|
1386
|
-
cross_covariance[4] =
|
|
1387
|
-
cross_covariance[5] =
|
|
1388
|
-
cross_covariance[6] =
|
|
1389
|
-
cross_covariance[7] =
|
|
1390
|
-
cross_covariance[8] =
|
|
1391
|
-
|
|
1392
|
-
//
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
|
|
1399
|
-
|
|
1400
|
-
|
|
1401
|
-
|
|
1402
|
-
nk_f64_t
|
|
1403
|
-
nk_f64_t applied_scale
|
|
1404
|
-
|
|
1405
|
-
|
|
1406
|
-
|
|
1407
|
-
|
|
1408
|
-
|
|
1409
|
-
|
|
906
|
+
cross_covariance[0] = raw_covariance[0] - n_f64 * centroid_a_x * centroid_b_x;
|
|
907
|
+
cross_covariance[1] = raw_covariance[1] - n_f64 * centroid_a_x * centroid_b_y;
|
|
908
|
+
cross_covariance[2] = raw_covariance[2] - n_f64 * centroid_a_x * centroid_b_z;
|
|
909
|
+
cross_covariance[3] = raw_covariance[3] - n_f64 * centroid_a_y * centroid_b_x;
|
|
910
|
+
cross_covariance[4] = raw_covariance[4] - n_f64 * centroid_a_y * centroid_b_y;
|
|
911
|
+
cross_covariance[5] = raw_covariance[5] - n_f64 * centroid_a_y * centroid_b_z;
|
|
912
|
+
cross_covariance[6] = raw_covariance[6] - n_f64 * centroid_a_z * centroid_b_x;
|
|
913
|
+
cross_covariance[7] = raw_covariance[7] - n_f64 * centroid_a_z * centroid_b_y;
|
|
914
|
+
cross_covariance[8] = raw_covariance[8] - n_f64 * centroid_a_z * centroid_b_z;
|
|
915
|
+
|
|
916
|
+
// Identity-dominant short-circuit: when H_centered is near-diagonal positive-definite, R = I
|
|
917
|
+
// and trace(R * H) collapses to H[0]+H[4]+H[8]. Also d3 = +1, so trace_ds = sum of diagonal,
|
|
918
|
+
// and applied_scale = trace_ds / (n * variance_a). Skips SVD + two rotation_from_svd calls.
|
|
919
|
+
nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
920
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
921
|
+
cross_covariance[8] * cross_covariance[8];
|
|
922
|
+
nk_f64_t covariance_offdiagonal_norm_squared =
|
|
923
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
924
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
925
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
926
|
+
nk_f64_t optimal_rotation[9];
|
|
927
|
+
nk_f64_t applied_scale;
|
|
928
|
+
nk_f64_t trace_rotation_covariance;
|
|
929
|
+
if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
|
|
930
|
+
cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
|
|
931
|
+
optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
|
|
932
|
+
optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
|
|
933
|
+
optimal_rotation[8] = 1;
|
|
934
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
935
|
+
applied_scale = trace_rotation_covariance / (n_f64 * variance_a);
|
|
1410
936
|
}
|
|
1411
|
-
|
|
937
|
+
else {
|
|
938
|
+
nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
939
|
+
nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
940
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
941
|
+
|
|
942
|
+
// Scale factor: c = trace(D · S) / (n * variance_a), with reflection sign via d3.
|
|
943
|
+
nk_f64_t det = nk_det3x3_f64_(optimal_rotation);
|
|
944
|
+
nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
|
|
945
|
+
nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_diagonal[0], 1.0, svd_diagonal[4], 1.0, svd_diagonal[8], d3);
|
|
946
|
+
applied_scale = trace_ds / (n_f64 * variance_a);
|
|
947
|
+
|
|
948
|
+
if (det < 0) {
|
|
949
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
950
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
951
|
+
}
|
|
952
|
+
trace_rotation_covariance =
|
|
953
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
954
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
955
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
956
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
957
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
958
|
+
}
|
|
959
|
+
if (scale) *scale = (nk_f32_t)applied_scale;
|
|
1412
960
|
if (rotation)
|
|
1413
|
-
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)
|
|
1414
|
-
|
|
1415
|
-
|
|
1416
|
-
|
|
1417
|
-
|
|
961
|
+
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f32_t)optimal_rotation[j];
|
|
962
|
+
|
|
963
|
+
// Folded SSD with scale: sum(|| s*R*(a-abar) - (b-bbar) ||^2)
|
|
964
|
+
// = s²·‖a-ā‖² + ‖b-b̄‖² − 2s·trace(R · H_centered).
|
|
965
|
+
nk_f64_t sum_squared = applied_scale * applied_scale * centered_norm_squared_a + centered_norm_squared_b -
|
|
966
|
+
2.0 * applied_scale * trace_rotation_covariance;
|
|
967
|
+
if (sum_squared < 0.0) sum_squared = 0.0;
|
|
968
|
+
*result = nk_f64_sqrt_haswell(sum_squared / n_f64);
|
|
1418
969
|
}
|
|
1419
970
|
|
|
1420
971
|
NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
|
|
@@ -1425,10 +976,10 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1425
976
|
|
|
1426
977
|
__m512d sum_a_x_f64x8 = zeros_f64x8, sum_a_y_f64x8 = zeros_f64x8, sum_a_z_f64x8 = zeros_f64x8;
|
|
1427
978
|
__m512d sum_b_x_f64x8 = zeros_f64x8, sum_b_y_f64x8 = zeros_f64x8, sum_b_z_f64x8 = zeros_f64x8;
|
|
1428
|
-
__m512d
|
|
1429
|
-
__m512d
|
|
1430
|
-
__m512d
|
|
1431
|
-
__m512d
|
|
979
|
+
__m512d covariance_xx_f64x8 = zeros_f64x8, covariance_xy_f64x8 = zeros_f64x8, covariance_xz_f64x8 = zeros_f64x8;
|
|
980
|
+
__m512d covariance_yx_f64x8 = zeros_f64x8, covariance_yy_f64x8 = zeros_f64x8, covariance_yz_f64x8 = zeros_f64x8;
|
|
981
|
+
__m512d covariance_zx_f64x8 = zeros_f64x8, covariance_zy_f64x8 = zeros_f64x8, covariance_zz_f64x8 = zeros_f64x8;
|
|
982
|
+
__m512d norm_squared_a_f64x8 = zeros_f64x8, norm_squared_b_f64x8 = zeros_f64x8;
|
|
1432
983
|
|
|
1433
984
|
nk_size_t i = 0;
|
|
1434
985
|
__m512d a_x_f64x8, a_y_f64x8, a_z_f64x8, b_x_f64x8, b_y_f64x8, b_z_f64x8;
|
|
@@ -1444,18 +995,21 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1444
995
|
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8);
|
|
1445
996
|
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);
|
|
1446
997
|
|
|
1447
|
-
|
|
1448
|
-
|
|
1449
|
-
|
|
1450
|
-
|
|
1451
|
-
|
|
1452
|
-
|
|
1453
|
-
|
|
1454
|
-
|
|
1455
|
-
|
|
1456
|
-
|
|
1457
|
-
|
|
1458
|
-
|
|
998
|
+
covariance_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, covariance_xx_f64x8),
|
|
999
|
+
covariance_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, covariance_xy_f64x8);
|
|
1000
|
+
covariance_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, covariance_xz_f64x8);
|
|
1001
|
+
covariance_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, covariance_yx_f64x8),
|
|
1002
|
+
covariance_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, covariance_yy_f64x8);
|
|
1003
|
+
covariance_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, covariance_yz_f64x8);
|
|
1004
|
+
covariance_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, covariance_zx_f64x8),
|
|
1005
|
+
covariance_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, covariance_zy_f64x8);
|
|
1006
|
+
covariance_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, covariance_zz_f64x8);
|
|
1007
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, norm_squared_a_f64x8);
|
|
1008
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, norm_squared_a_f64x8);
|
|
1009
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, norm_squared_a_f64x8);
|
|
1010
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_x_f64x8, b_x_f64x8, norm_squared_b_f64x8);
|
|
1011
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_y_f64x8, b_y_f64x8, norm_squared_b_f64x8);
|
|
1012
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_z_f64x8, b_z_f64x8, norm_squared_b_f64x8);
|
|
1459
1013
|
}
|
|
1460
1014
|
|
|
1461
1015
|
if (i < n) {
|
|
@@ -1478,18 +1032,21 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1478
1032
|
sum_b_y_f64x8 = _mm512_add_pd(sum_b_y_f64x8, b_y_f64x8);
|
|
1479
1033
|
sum_b_z_f64x8 = _mm512_add_pd(sum_b_z_f64x8, b_z_f64x8);
|
|
1480
1034
|
|
|
1481
|
-
|
|
1482
|
-
|
|
1483
|
-
|
|
1484
|
-
|
|
1485
|
-
|
|
1486
|
-
|
|
1487
|
-
|
|
1488
|
-
|
|
1489
|
-
|
|
1490
|
-
|
|
1491
|
-
|
|
1492
|
-
|
|
1035
|
+
covariance_xx_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_x_f64x8, covariance_xx_f64x8),
|
|
1036
|
+
covariance_xy_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_y_f64x8, covariance_xy_f64x8);
|
|
1037
|
+
covariance_xz_f64x8 = _mm512_fmadd_pd(a_x_f64x8, b_z_f64x8, covariance_xz_f64x8);
|
|
1038
|
+
covariance_yx_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_x_f64x8, covariance_yx_f64x8),
|
|
1039
|
+
covariance_yy_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_y_f64x8, covariance_yy_f64x8);
|
|
1040
|
+
covariance_yz_f64x8 = _mm512_fmadd_pd(a_y_f64x8, b_z_f64x8, covariance_yz_f64x8);
|
|
1041
|
+
covariance_zx_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_x_f64x8, covariance_zx_f64x8),
|
|
1042
|
+
covariance_zy_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_y_f64x8, covariance_zy_f64x8);
|
|
1043
|
+
covariance_zz_f64x8 = _mm512_fmadd_pd(a_z_f64x8, b_z_f64x8, covariance_zz_f64x8);
|
|
1044
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_x_f64x8, a_x_f64x8, norm_squared_a_f64x8);
|
|
1045
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_y_f64x8, a_y_f64x8, norm_squared_a_f64x8);
|
|
1046
|
+
norm_squared_a_f64x8 = _mm512_fmadd_pd(a_z_f64x8, a_z_f64x8, norm_squared_a_f64x8);
|
|
1047
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_x_f64x8, b_x_f64x8, norm_squared_b_f64x8);
|
|
1048
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_y_f64x8, b_y_f64x8, norm_squared_b_f64x8);
|
|
1049
|
+
norm_squared_b_f64x8 = _mm512_fmadd_pd(b_z_f64x8, b_z_f64x8, norm_squared_b_f64x8);
|
|
1493
1050
|
i = n;
|
|
1494
1051
|
}
|
|
1495
1052
|
|
|
@@ -1501,16 +1058,19 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1501
1058
|
nk_f64_t sum_b_x = nk_reduce_stable_f64x8_skylake_(sum_b_x_f64x8), sum_b_x_compensation = 0.0;
|
|
1502
1059
|
nk_f64_t sum_b_y = nk_reduce_stable_f64x8_skylake_(sum_b_y_f64x8), sum_b_y_compensation = 0.0;
|
|
1503
1060
|
nk_f64_t sum_b_z = nk_reduce_stable_f64x8_skylake_(sum_b_z_f64x8), sum_b_z_compensation = 0.0;
|
|
1504
|
-
nk_f64_t covariance_x_x = nk_reduce_stable_f64x8_skylake_(
|
|
1505
|
-
nk_f64_t covariance_x_y = nk_reduce_stable_f64x8_skylake_(
|
|
1506
|
-
nk_f64_t covariance_x_z = nk_reduce_stable_f64x8_skylake_(
|
|
1507
|
-
nk_f64_t covariance_y_x = nk_reduce_stable_f64x8_skylake_(
|
|
1508
|
-
nk_f64_t covariance_y_y = nk_reduce_stable_f64x8_skylake_(
|
|
1509
|
-
nk_f64_t covariance_y_z = nk_reduce_stable_f64x8_skylake_(
|
|
1510
|
-
nk_f64_t covariance_z_x = nk_reduce_stable_f64x8_skylake_(
|
|
1511
|
-
nk_f64_t covariance_z_y = nk_reduce_stable_f64x8_skylake_(
|
|
1512
|
-
nk_f64_t covariance_z_z = nk_reduce_stable_f64x8_skylake_(
|
|
1513
|
-
nk_f64_t
|
|
1061
|
+
nk_f64_t covariance_x_x = nk_reduce_stable_f64x8_skylake_(covariance_xx_f64x8), covariance_x_x_compensation = 0.0;
|
|
1062
|
+
nk_f64_t covariance_x_y = nk_reduce_stable_f64x8_skylake_(covariance_xy_f64x8), covariance_x_y_compensation = 0.0;
|
|
1063
|
+
nk_f64_t covariance_x_z = nk_reduce_stable_f64x8_skylake_(covariance_xz_f64x8), covariance_x_z_compensation = 0.0;
|
|
1064
|
+
nk_f64_t covariance_y_x = nk_reduce_stable_f64x8_skylake_(covariance_yx_f64x8), covariance_y_x_compensation = 0.0;
|
|
1065
|
+
nk_f64_t covariance_y_y = nk_reduce_stable_f64x8_skylake_(covariance_yy_f64x8), covariance_y_y_compensation = 0.0;
|
|
1066
|
+
nk_f64_t covariance_y_z = nk_reduce_stable_f64x8_skylake_(covariance_yz_f64x8), covariance_y_z_compensation = 0.0;
|
|
1067
|
+
nk_f64_t covariance_z_x = nk_reduce_stable_f64x8_skylake_(covariance_zx_f64x8), covariance_z_x_compensation = 0.0;
|
|
1068
|
+
nk_f64_t covariance_z_y = nk_reduce_stable_f64x8_skylake_(covariance_zy_f64x8), covariance_z_y_compensation = 0.0;
|
|
1069
|
+
nk_f64_t covariance_z_z = nk_reduce_stable_f64x8_skylake_(covariance_zz_f64x8), covariance_z_z_compensation = 0.0;
|
|
1070
|
+
nk_f64_t norm_squared_a_sum = nk_reduce_stable_f64x8_skylake_(norm_squared_a_f64x8),
|
|
1071
|
+
norm_squared_a_compensation = 0.0;
|
|
1072
|
+
nk_f64_t norm_squared_b_sum = nk_reduce_stable_f64x8_skylake_(norm_squared_b_f64x8),
|
|
1073
|
+
norm_squared_b_compensation = 0.0;
|
|
1514
1074
|
|
|
1515
1075
|
for (; i < n; ++i) {
|
|
1516
1076
|
nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
|
|
@@ -1530,9 +1090,12 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1530
1090
|
nk_accumulate_product_f64_(&covariance_z_x, &covariance_z_x_compensation, az, bx);
|
|
1531
1091
|
nk_accumulate_product_f64_(&covariance_z_y, &covariance_z_y_compensation, az, by);
|
|
1532
1092
|
nk_accumulate_product_f64_(&covariance_z_z, &covariance_z_z_compensation, az, bz);
|
|
1533
|
-
nk_accumulate_square_f64_(&
|
|
1534
|
-
nk_accumulate_square_f64_(&
|
|
1535
|
-
nk_accumulate_square_f64_(&
|
|
1093
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ax);
|
|
1094
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, ay);
|
|
1095
|
+
nk_accumulate_square_f64_(&norm_squared_a_sum, &norm_squared_a_compensation, az);
|
|
1096
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bx);
|
|
1097
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, by);
|
|
1098
|
+
nk_accumulate_square_f64_(&norm_squared_b_sum, &norm_squared_b_compensation, bz);
|
|
1536
1099
|
}
|
|
1537
1100
|
|
|
1538
1101
|
sum_a_x += sum_a_x_compensation, sum_a_y += sum_a_y_compensation, sum_a_z += sum_a_z_compensation;
|
|
@@ -1543,7 +1106,8 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1543
1106
|
covariance_y_z += covariance_y_z_compensation;
|
|
1544
1107
|
covariance_z_x += covariance_z_x_compensation, covariance_z_y += covariance_z_y_compensation,
|
|
1545
1108
|
covariance_z_z += covariance_z_z_compensation;
|
|
1546
|
-
|
|
1109
|
+
norm_squared_a_sum += norm_squared_a_compensation;
|
|
1110
|
+
norm_squared_b_sum += norm_squared_b_compensation;
|
|
1547
1111
|
|
|
1548
1112
|
nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
1549
1113
|
nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
@@ -1551,9 +1115,15 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1551
1115
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1552
1116
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1553
1117
|
|
|
1554
|
-
//
|
|
1555
|
-
nk_f64_t
|
|
1556
|
-
|
|
1118
|
+
// Centered norm squared via parallel-axis identity (clamped for numerical safety).
|
|
1119
|
+
nk_f64_t centered_norm_squared_a = norm_squared_a_sum -
|
|
1120
|
+
(nk_f64_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1121
|
+
centroid_a_z * centroid_a_z);
|
|
1122
|
+
nk_f64_t centered_norm_squared_b = norm_squared_b_sum -
|
|
1123
|
+
(nk_f64_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1124
|
+
centroid_b_z * centroid_b_z);
|
|
1125
|
+
if (centered_norm_squared_a < 0.0) centered_norm_squared_a = 0.0;
|
|
1126
|
+
if (centered_norm_squared_b < 0.0) centered_norm_squared_b = 0.0;
|
|
1557
1127
|
|
|
1558
1128
|
// Compute centered covariance matrix: Hᵢⱼ = Σ(aᵢ×bⱼ) - Σaᵢ × Σbⱼ / n.
|
|
1559
1129
|
nk_f64_t cross_covariance[9];
|
|
@@ -1568,32 +1138,58 @@ NK_PUBLIC void nk_umeyama_f64_skylake(nk_f64_t const *a, nk_f64_t const *b, nk_s
|
|
|
1568
1138
|
cross_covariance[8] = covariance_z_z - sum_a_z * sum_b_z * inv_n;
|
|
1569
1139
|
|
|
1570
1140
|
// SVD using f64 for full precision
|
|
1571
|
-
|
|
1572
|
-
|
|
1573
|
-
|
|
1574
|
-
nk_f64_t
|
|
1575
|
-
|
|
1576
|
-
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
|
|
1581
|
-
nk_f64_t
|
|
1582
|
-
|
|
1583
|
-
|
|
1584
|
-
|
|
1585
|
-
|
|
1586
|
-
|
|
1587
|
-
|
|
1141
|
+
// Identity-dominant short-circuit: when H_centered is near-diagonal positive-definite,
|
|
1142
|
+
// R = I, trace(R * H) = H[0]+H[4]+H[8] (also == trace_ds with d3=+1), and the scale
|
|
1143
|
+
// derivation collapses. Skips SVD + two rotation_from_svd calls.
|
|
1144
|
+
nk_f64_t covariance_diagonal_norm_squared = cross_covariance[0] * cross_covariance[0] +
|
|
1145
|
+
cross_covariance[4] * cross_covariance[4] +
|
|
1146
|
+
cross_covariance[8] * cross_covariance[8];
|
|
1147
|
+
nk_f64_t covariance_offdiagonal_norm_squared =
|
|
1148
|
+
cross_covariance[1] * cross_covariance[1] + cross_covariance[2] * cross_covariance[2] +
|
|
1149
|
+
cross_covariance[3] * cross_covariance[3] + cross_covariance[5] * cross_covariance[5] +
|
|
1150
|
+
cross_covariance[6] * cross_covariance[6] + cross_covariance[7] * cross_covariance[7];
|
|
1151
|
+
nk_f64_t optimal_rotation[9];
|
|
1152
|
+
nk_f64_t c;
|
|
1153
|
+
nk_f64_t trace_rotation_covariance;
|
|
1154
|
+
if (covariance_offdiagonal_norm_squared < 1e-20 * covariance_diagonal_norm_squared && cross_covariance[0] > 0.0 &&
|
|
1155
|
+
cross_covariance[4] > 0.0 && cross_covariance[8] > 0.0) {
|
|
1156
|
+
optimal_rotation[0] = 1, optimal_rotation[1] = 0, optimal_rotation[2] = 0, optimal_rotation[3] = 0,
|
|
1157
|
+
optimal_rotation[4] = 1, optimal_rotation[5] = 0, optimal_rotation[6] = 0, optimal_rotation[7] = 0,
|
|
1158
|
+
optimal_rotation[8] = 1;
|
|
1159
|
+
trace_rotation_covariance = cross_covariance[0] + cross_covariance[4] + cross_covariance[8];
|
|
1160
|
+
c = centered_norm_squared_a > 0.0 ? trace_rotation_covariance / centered_norm_squared_a : 0.0;
|
|
1588
1161
|
}
|
|
1589
|
-
|
|
1590
|
-
|
|
1162
|
+
else {
|
|
1163
|
+
nk_f64_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1164
|
+
nk_svd3x3_f64_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1165
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
1166
|
+
|
|
1167
|
+
// Scale factor: c = trace(D · S) / (n * variance(a)), with reflection sign via d3.
|
|
1168
|
+
nk_f64_t det = nk_det3x3_f64_(optimal_rotation);
|
|
1169
|
+
nk_f64_t d3 = det < 0 ? -1.0 : 1.0;
|
|
1170
|
+
nk_f64_t trace_ds = nk_sum_three_products_f64_(svd_diagonal[0], 1.0, svd_diagonal[4], 1.0, svd_diagonal[8], d3);
|
|
1171
|
+
c = centered_norm_squared_a > 0.0 ? trace_ds / centered_norm_squared_a : 0.0;
|
|
1172
|
+
|
|
1173
|
+
if (det < 0) {
|
|
1174
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1175
|
+
nk_rotation_from_svd_f64_serial_(svd_left, svd_right, optimal_rotation);
|
|
1176
|
+
}
|
|
1177
|
+
trace_rotation_covariance =
|
|
1178
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1179
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1180
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1181
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1182
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1183
|
+
}
|
|
1184
|
+
if (scale) *scale = c;
|
|
1591
1185
|
if (rotation)
|
|
1592
|
-
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)
|
|
1186
|
+
for (int j = 0; j < 9; ++j) rotation[j] = (nk_f64_t)optimal_rotation[j];
|
|
1593
1187
|
|
|
1594
|
-
//
|
|
1595
|
-
|
|
1596
|
-
|
|
1188
|
+
// Folded SSD with scale: Sum(|| c*R*(a-abar) - (b-bbar) ||^2)
|
|
1189
|
+
// = c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
|
|
1190
|
+
nk_f64_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
|
|
1191
|
+
2.0 * c * trace_rotation_covariance;
|
|
1192
|
+
if (sum_squared < 0.0) sum_squared = 0.0;
|
|
1597
1193
|
*result = nk_f64_sqrt_haswell(sum_squared * inv_n);
|
|
1598
1194
|
}
|
|
1599
1195
|
|
|
@@ -1603,85 +1199,34 @@ NK_PUBLIC void nk_rmsd_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size
|
|
|
1603
1199
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1604
1200
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1605
1201
|
if (scale) *scale = 1.0f;
|
|
1202
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
1203
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
1606
1204
|
|
|
1607
|
-
|
|
1608
|
-
__m512
|
|
1609
|
-
|
|
1610
|
-
__m512 sum_squared_x_f32x16 = zeros_f32x16, sum_squared_y_f32x16 = zeros_f32x16;
|
|
1611
|
-
__m512 sum_squared_z_f32x16 = zeros_f32x16;
|
|
1612
|
-
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
1613
|
-
nk_size_t i = 0;
|
|
1205
|
+
// 15-lane stride-3 layout: mask at the f16 level so the last chunk stays in-bounds.
|
|
1206
|
+
__m512 sum_squared_f32x16 = _mm512_setzero_ps();
|
|
1207
|
+
nk_size_t index = 0;
|
|
1614
1208
|
|
|
1615
|
-
for (;
|
|
1616
|
-
|
|
1617
|
-
|
|
1618
|
-
|
|
1619
|
-
|
|
1620
|
-
|
|
1621
|
-
|
|
1622
|
-
sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
|
|
1623
|
-
sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
|
|
1624
|
-
sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
|
|
1625
|
-
|
|
1626
|
-
__m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
|
|
1627
|
-
__m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
|
|
1628
|
-
__m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
|
|
1629
|
-
|
|
1630
|
-
sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
|
|
1631
|
-
sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
|
|
1632
|
-
sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
|
|
1209
|
+
for (; index + 5 <= n; index += 5) {
|
|
1210
|
+
__m256i a_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
|
|
1211
|
+
__m256i b_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
|
|
1212
|
+
__m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
|
|
1213
|
+
__m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
|
|
1214
|
+
__m512 delta_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
|
|
1215
|
+
sum_squared_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, sum_squared_f32x16);
|
|
1633
1216
|
}
|
|
1634
1217
|
|
|
1635
|
-
|
|
1636
|
-
|
|
1637
|
-
|
|
1638
|
-
|
|
1639
|
-
|
|
1640
|
-
|
|
1641
|
-
|
|
1642
|
-
|
|
1643
|
-
sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
|
|
1644
|
-
sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
|
|
1645
|
-
sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
|
|
1646
|
-
sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
|
|
1647
|
-
|
|
1648
|
-
__m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
|
|
1649
|
-
__m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
|
|
1650
|
-
__m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
|
|
1651
|
-
|
|
1652
|
-
sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
|
|
1653
|
-
sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
|
|
1654
|
-
sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
|
|
1218
|
+
if (index < n) {
|
|
1219
|
+
__mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
|
|
1220
|
+
__m256i a_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
|
|
1221
|
+
__m256i b_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
|
|
1222
|
+
__m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
|
|
1223
|
+
__m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
|
|
1224
|
+
__m512 delta_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
|
|
1225
|
+
sum_squared_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, sum_squared_f32x16);
|
|
1655
1226
|
}
|
|
1656
1227
|
|
|
1657
|
-
nk_f32_t
|
|
1658
|
-
|
|
1659
|
-
nk_f32_t total_az = _mm512_reduce_add_ps(sum_a_z_f32x16);
|
|
1660
|
-
nk_f32_t total_bx = _mm512_reduce_add_ps(sum_b_x_f32x16);
|
|
1661
|
-
nk_f32_t total_by = _mm512_reduce_add_ps(sum_b_y_f32x16);
|
|
1662
|
-
nk_f32_t total_bz = _mm512_reduce_add_ps(sum_b_z_f32x16);
|
|
1663
|
-
nk_f32_t total_sq_x = _mm512_reduce_add_ps(sum_squared_x_f32x16);
|
|
1664
|
-
nk_f32_t total_sq_y = _mm512_reduce_add_ps(sum_squared_y_f32x16);
|
|
1665
|
-
nk_f32_t total_sq_z = _mm512_reduce_add_ps(sum_squared_z_f32x16);
|
|
1666
|
-
|
|
1667
|
-
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
1668
|
-
nk_f32_t centroid_a_x = total_ax * inv_n;
|
|
1669
|
-
nk_f32_t centroid_a_y = total_ay * inv_n;
|
|
1670
|
-
nk_f32_t centroid_a_z = total_az * inv_n;
|
|
1671
|
-
nk_f32_t centroid_b_x = total_bx * inv_n;
|
|
1672
|
-
nk_f32_t centroid_b_y = total_by * inv_n;
|
|
1673
|
-
nk_f32_t centroid_b_z = total_bz * inv_n;
|
|
1674
|
-
|
|
1675
|
-
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1676
|
-
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1677
|
-
|
|
1678
|
-
nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
1679
|
-
nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
|
|
1680
|
-
nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
|
|
1681
|
-
nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
|
|
1682
|
-
nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
1683
|
-
|
|
1684
|
-
*result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
1228
|
+
nk_f32_t sum_squared = _mm512_reduce_add_ps(sum_squared_f32x16);
|
|
1229
|
+
*result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
|
|
1685
1230
|
}
|
|
1686
1231
|
|
|
1687
1232
|
NK_PUBLIC void nk_rmsd_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
@@ -1690,642 +1235,541 @@ NK_PUBLIC void nk_rmsd_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_s
|
|
|
1690
1235
|
rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
|
|
1691
1236
|
rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
|
|
1692
1237
|
if (scale) *scale = 1.0f;
|
|
1238
|
+
if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
|
|
1239
|
+
if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
|
|
1693
1240
|
|
|
1694
|
-
|
|
1695
|
-
__m512
|
|
1696
|
-
|
|
1697
|
-
__m512 sum_squared_x_f32x16 = zeros_f32x16, sum_squared_y_f32x16 = zeros_f32x16;
|
|
1698
|
-
__m512 sum_squared_z_f32x16 = zeros_f32x16;
|
|
1699
|
-
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
1700
|
-
nk_size_t i = 0;
|
|
1241
|
+
// 15-lane stride-3 layout: mask at the bf16 level so the last chunk stays in-bounds.
|
|
1242
|
+
__m512 sum_squared_f32x16 = _mm512_setzero_ps();
|
|
1243
|
+
nk_size_t index = 0;
|
|
1701
1244
|
|
|
1702
|
-
for (;
|
|
1703
|
-
|
|
1704
|
-
|
|
1705
|
-
|
|
1706
|
-
|
|
1707
|
-
|
|
1708
|
-
|
|
1709
|
-
sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
|
|
1710
|
-
sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
|
|
1711
|
-
sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
|
|
1712
|
-
|
|
1713
|
-
__m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
|
|
1714
|
-
__m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
|
|
1715
|
-
__m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
|
|
1716
|
-
|
|
1717
|
-
sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
|
|
1718
|
-
sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
|
|
1719
|
-
sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
|
|
1245
|
+
for (; index + 5 <= n; index += 5) {
|
|
1246
|
+
__m256i a_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
|
|
1247
|
+
__m256i b_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
|
|
1248
|
+
__m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
|
|
1249
|
+
__m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
|
|
1250
|
+
__m512 delta_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
|
|
1251
|
+
sum_squared_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, sum_squared_f32x16);
|
|
1720
1252
|
}
|
|
1721
1253
|
|
|
1722
|
-
|
|
1723
|
-
|
|
1724
|
-
|
|
1725
|
-
|
|
1726
|
-
|
|
1727
|
-
|
|
1728
|
-
|
|
1729
|
-
|
|
1730
|
-
sum_a_z_f32x16 = _mm512_add_ps(sum_a_z_f32x16, a_z_f32x16);
|
|
1731
|
-
sum_b_x_f32x16 = _mm512_add_ps(sum_b_x_f32x16, b_x_f32x16);
|
|
1732
|
-
sum_b_y_f32x16 = _mm512_add_ps(sum_b_y_f32x16, b_y_f32x16);
|
|
1733
|
-
sum_b_z_f32x16 = _mm512_add_ps(sum_b_z_f32x16, b_z_f32x16);
|
|
1734
|
-
|
|
1735
|
-
__m512 delta_x_f32x16 = _mm512_sub_ps(a_x_f32x16, b_x_f32x16);
|
|
1736
|
-
__m512 delta_y_f32x16 = _mm512_sub_ps(a_y_f32x16, b_y_f32x16);
|
|
1737
|
-
__m512 delta_z_f32x16 = _mm512_sub_ps(a_z_f32x16, b_z_f32x16);
|
|
1738
|
-
|
|
1739
|
-
sum_squared_x_f32x16 = _mm512_fmadd_ps(delta_x_f32x16, delta_x_f32x16, sum_squared_x_f32x16);
|
|
1740
|
-
sum_squared_y_f32x16 = _mm512_fmadd_ps(delta_y_f32x16, delta_y_f32x16, sum_squared_y_f32x16);
|
|
1741
|
-
sum_squared_z_f32x16 = _mm512_fmadd_ps(delta_z_f32x16, delta_z_f32x16, sum_squared_z_f32x16);
|
|
1254
|
+
if (index < n) {
|
|
1255
|
+
__mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
|
|
1256
|
+
__m256i a_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
|
|
1257
|
+
__m256i b_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
|
|
1258
|
+
__m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
|
|
1259
|
+
__m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
|
|
1260
|
+
__m512 delta_f32x16 = _mm512_sub_ps(a_f32x16, b_f32x16);
|
|
1261
|
+
sum_squared_f32x16 = _mm512_fmadd_ps(delta_f32x16, delta_f32x16, sum_squared_f32x16);
|
|
1742
1262
|
}
|
|
1743
1263
|
|
|
1744
|
-
nk_f32_t
|
|
1745
|
-
|
|
1746
|
-
nk_f32_t total_az = _mm512_reduce_add_ps(sum_a_z_f32x16);
|
|
1747
|
-
nk_f32_t total_bx = _mm512_reduce_add_ps(sum_b_x_f32x16);
|
|
1748
|
-
nk_f32_t total_by = _mm512_reduce_add_ps(sum_b_y_f32x16);
|
|
1749
|
-
nk_f32_t total_bz = _mm512_reduce_add_ps(sum_b_z_f32x16);
|
|
1750
|
-
nk_f32_t total_sq_x = _mm512_reduce_add_ps(sum_squared_x_f32x16);
|
|
1751
|
-
nk_f32_t total_sq_y = _mm512_reduce_add_ps(sum_squared_y_f32x16);
|
|
1752
|
-
nk_f32_t total_sq_z = _mm512_reduce_add_ps(sum_squared_z_f32x16);
|
|
1753
|
-
|
|
1754
|
-
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
1755
|
-
nk_f32_t centroid_a_x = total_ax * inv_n;
|
|
1756
|
-
nk_f32_t centroid_a_y = total_ay * inv_n;
|
|
1757
|
-
nk_f32_t centroid_a_z = total_az * inv_n;
|
|
1758
|
-
nk_f32_t centroid_b_x = total_bx * inv_n;
|
|
1759
|
-
nk_f32_t centroid_b_y = total_by * inv_n;
|
|
1760
|
-
nk_f32_t centroid_b_z = total_bz * inv_n;
|
|
1761
|
-
|
|
1762
|
-
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1763
|
-
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1764
|
-
|
|
1765
|
-
nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
|
|
1766
|
-
nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
|
|
1767
|
-
nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
|
|
1768
|
-
nk_f32_t sum_squared = total_sq_x + total_sq_y + total_sq_z;
|
|
1769
|
-
nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
|
|
1770
|
-
|
|
1771
|
-
*result = nk_f32_sqrt_haswell(sum_squared * inv_n - mean_diff_sq);
|
|
1264
|
+
nk_f32_t sum_squared = _mm512_reduce_add_ps(sum_squared_f32x16);
|
|
1265
|
+
*result = nk_f32_sqrt_haswell(sum_squared / (nk_f32_t)n);
|
|
1772
1266
|
}
|
|
1773
1267
|
|
|
1774
1268
|
NK_PUBLIC void nk_kabsch_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1775
1269
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1776
|
-
|
|
1777
|
-
|
|
1778
|
-
|
|
1779
|
-
|
|
1780
|
-
|
|
1781
|
-
__m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
|
|
1782
|
-
__m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
|
|
1270
|
+
// 15-lane stride-3 layout: one masked epi16 load + widen gives {a_f32x16, b_f32x16} with
|
|
1271
|
+
// channel phase [x,y,z, x,y,z, x,y,z, x,y,z, x,y,z, _] constant across all chunks. The 9
|
|
1272
|
+
// H-cells come from three product accumulators a*b, a*rot1(b), a*rot2(b) demuxed per channel.
|
|
1273
|
+
__m512i const idx_rotation_1_i32x16 = _mm512_setr_epi32(1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9, 13, 14, 12, 15);
|
|
1274
|
+
__m512i const idx_rotation_2_i32x16 = _mm512_setr_epi32(2, 0, 1, 5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 15);
|
|
1783
1275
|
|
|
1784
|
-
|
|
1785
|
-
__m512
|
|
1276
|
+
__m512 const zeros_f32x16 = _mm512_setzero_ps();
|
|
1277
|
+
__m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
|
|
1278
|
+
__m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
|
|
1279
|
+
__m512 product_diagonal_f32x16 = zeros_f32x16;
|
|
1280
|
+
__m512 product_rotation_1_f32x16 = zeros_f32x16;
|
|
1281
|
+
__m512 product_rotation_2_f32x16 = zeros_f32x16;
|
|
1786
1282
|
|
|
1787
|
-
|
|
1788
|
-
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
|
|
1793
|
-
|
|
1794
|
-
|
|
1795
|
-
|
|
1796
|
-
|
|
1797
|
-
|
|
1798
|
-
|
|
1799
|
-
|
|
1800
|
-
|
|
1801
|
-
|
|
1802
|
-
cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
|
|
1803
|
-
cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
|
|
1804
|
-
cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
|
|
1805
|
-
cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
|
|
1806
|
-
cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
|
|
1283
|
+
nk_size_t index = 0;
|
|
1284
|
+
for (; index + 5 <= n; index += 5) {
|
|
1285
|
+
__m256i a_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
|
|
1286
|
+
__m256i b_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
|
|
1287
|
+
__m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
|
|
1288
|
+
__m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
|
|
1289
|
+
__m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
|
|
1290
|
+
__m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
|
|
1291
|
+
sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
|
|
1292
|
+
sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
|
|
1293
|
+
norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
|
|
1294
|
+
norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
|
|
1295
|
+
product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
|
|
1296
|
+
product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
|
|
1297
|
+
product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
|
|
1807
1298
|
}
|
|
1808
1299
|
|
|
1809
|
-
|
|
1810
|
-
|
|
1811
|
-
|
|
1812
|
-
|
|
1813
|
-
|
|
1814
|
-
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
|
|
1820
|
-
|
|
1821
|
-
|
|
1822
|
-
|
|
1823
|
-
|
|
1824
|
-
cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
|
|
1825
|
-
cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
|
|
1826
|
-
cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
|
|
1827
|
-
cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
|
|
1828
|
-
cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
|
|
1829
|
-
cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
|
|
1830
|
-
cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
|
|
1300
|
+
if (index < n) {
|
|
1301
|
+
__mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
|
|
1302
|
+
__m256i a_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
|
|
1303
|
+
__m256i b_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
|
|
1304
|
+
__m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
|
|
1305
|
+
__m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
|
|
1306
|
+
__m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
|
|
1307
|
+
__m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
|
|
1308
|
+
sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
|
|
1309
|
+
sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
|
|
1310
|
+
norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
|
|
1311
|
+
norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
|
|
1312
|
+
product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
|
|
1313
|
+
product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
|
|
1314
|
+
product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
|
|
1831
1315
|
}
|
|
1832
1316
|
|
|
1833
|
-
|
|
1834
|
-
|
|
1835
|
-
|
|
1836
|
-
|
|
1837
|
-
|
|
1838
|
-
nk_f32_t
|
|
1839
|
-
nk_f32_t
|
|
1840
|
-
nk_f32_t
|
|
1841
|
-
nk_f32_t
|
|
1842
|
-
nk_f32_t
|
|
1843
|
-
nk_f32_t
|
|
1844
|
-
nk_f32_t
|
|
1845
|
-
nk_f32_t
|
|
1846
|
-
|
|
1847
|
-
nk_f32_t
|
|
1317
|
+
// Per-channel demux via mask-reduce on the fp32 accumulators (lane i carries channel i%3).
|
|
1318
|
+
__mmask16 const mask_channel_x_f32 = 0x1249; // lanes {0, 3, 6, 9, 12}
|
|
1319
|
+
__mmask16 const mask_channel_y_f32 = 0x2492; // lanes {1, 4, 7, 10, 13}
|
|
1320
|
+
__mmask16 const mask_channel_z_f32 = 0x4924; // lanes {2, 5, 8, 11, 14}
|
|
1321
|
+
|
|
1322
|
+
nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
|
|
1323
|
+
nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
|
|
1324
|
+
nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
|
|
1325
|
+
nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
|
|
1326
|
+
nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
|
|
1327
|
+
nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
|
|
1328
|
+
nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
|
|
1329
|
+
nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
|
|
1330
|
+
|
|
1331
|
+
nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
|
|
1332
|
+
nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
|
|
1333
|
+
nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
|
|
1334
|
+
nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
|
|
1335
|
+
nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
|
|
1336
|
+
nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
|
|
1337
|
+
nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
|
|
1338
|
+
nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
|
|
1339
|
+
nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);
|
|
1848
1340
|
|
|
1849
1341
|
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
1850
1342
|
nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
1851
1343
|
nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
1852
|
-
|
|
1853
1344
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1854
1345
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1855
1346
|
|
|
1856
|
-
|
|
1857
|
-
|
|
1858
|
-
|
|
1859
|
-
|
|
1860
|
-
|
|
1861
|
-
|
|
1862
|
-
|
|
1863
|
-
|
|
1864
|
-
|
|
1865
|
-
|
|
1866
|
-
|
|
1867
|
-
|
|
1868
|
-
nk_f32_t
|
|
1869
|
-
nk_svd3x3_f32_(cross_covariance,
|
|
1870
|
-
|
|
1871
|
-
|
|
1872
|
-
|
|
1873
|
-
|
|
1874
|
-
|
|
1875
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
1876
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
1877
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
1878
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
1879
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
1880
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1881
|
-
|
|
1882
|
-
if (nk_det3x3_f32_(r) < 0) {
|
|
1883
|
-
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
1884
|
-
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
1885
|
-
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
1886
|
-
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
1887
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
1888
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
1889
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
1890
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
1891
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
1892
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1347
|
+
// Parallel-axis correction.
|
|
1348
|
+
nk_f32_t cross_covariance[9];
|
|
1349
|
+
cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1350
|
+
cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
1351
|
+
cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
|
|
1352
|
+
cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
|
|
1353
|
+
cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
|
|
1354
|
+
cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
|
|
1355
|
+
cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
|
|
1356
|
+
cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1357
|
+
cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
1358
|
+
|
|
1359
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1360
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1361
|
+
nk_f32_t optimal_rotation[9];
|
|
1362
|
+
nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
|
|
1363
|
+
if (nk_det3x3_f32_(optimal_rotation) < 0) {
|
|
1364
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1365
|
+
nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
|
|
1893
1366
|
}
|
|
1894
|
-
|
|
1895
1367
|
if (rotation)
|
|
1896
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
1368
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
1897
1369
|
if (scale) *scale = 1.0f;
|
|
1898
1370
|
|
|
1899
|
-
|
|
1900
|
-
|
|
1371
|
+
// Folded SSD via trace identity:
|
|
1372
|
+
// SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered)
|
|
1373
|
+
// trace(R · H_centered) = Σⱼₖ R[j,k] · H[k,j] (note transpose on H).
|
|
1374
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a -
|
|
1375
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1376
|
+
centroid_a_z * centroid_a_z);
|
|
1377
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b -
|
|
1378
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1379
|
+
centroid_b_z * centroid_b_z);
|
|
1380
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
1381
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
1382
|
+
nk_f32_t trace_rotation_covariance =
|
|
1383
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1384
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1385
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1386
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1387
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1388
|
+
nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
|
|
1389
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
1901
1390
|
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
1902
1391
|
}
|
|
1903
1392
|
|
|
1904
1393
|
NK_PUBLIC void nk_kabsch_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
1905
1394
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
1906
|
-
|
|
1907
|
-
|
|
1908
|
-
|
|
1909
|
-
|
|
1910
|
-
|
|
1911
|
-
__m512 cov_yx_f32x16 = zeros_f32x16, cov_yy_f32x16 = zeros_f32x16, cov_yz_f32x16 = zeros_f32x16;
|
|
1912
|
-
__m512 cov_zx_f32x16 = zeros_f32x16, cov_zy_f32x16 = zeros_f32x16, cov_zz_f32x16 = zeros_f32x16;
|
|
1395
|
+
// 15-lane stride-3 layout: one masked epi16 load + widen gives {a_f32x16, b_f32x16} with
|
|
1396
|
+
// channel phase [x,y,z, x,y,z, x,y,z, x,y,z, x,y,z, _] constant across all chunks. The 9
|
|
1397
|
+
// H-cells come from three product accumulators a*b, a*rot1(b), a*rot2(b) demuxed per channel.
|
|
1398
|
+
__m512i const idx_rotation_1_i32x16 = _mm512_setr_epi32(1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9, 13, 14, 12, 15);
|
|
1399
|
+
__m512i const idx_rotation_2_i32x16 = _mm512_setr_epi32(2, 0, 1, 5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 15);
|
|
1913
1400
|
|
|
1914
|
-
|
|
1915
|
-
__m512
|
|
1401
|
+
__m512 const zeros_f32x16 = _mm512_setzero_ps();
|
|
1402
|
+
__m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
|
|
1403
|
+
__m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
|
|
1404
|
+
__m512 product_diagonal_f32x16 = zeros_f32x16;
|
|
1405
|
+
__m512 product_rotation_1_f32x16 = zeros_f32x16;
|
|
1406
|
+
__m512 product_rotation_2_f32x16 = zeros_f32x16;
|
|
1916
1407
|
|
|
1917
|
-
|
|
1918
|
-
|
|
1919
|
-
|
|
1920
|
-
|
|
1921
|
-
|
|
1922
|
-
|
|
1923
|
-
|
|
1924
|
-
|
|
1925
|
-
|
|
1926
|
-
|
|
1927
|
-
|
|
1928
|
-
|
|
1929
|
-
|
|
1930
|
-
|
|
1931
|
-
|
|
1932
|
-
cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
|
|
1933
|
-
cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
|
|
1934
|
-
cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
|
|
1935
|
-
cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
|
|
1936
|
-
cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
|
|
1408
|
+
nk_size_t index = 0;
|
|
1409
|
+
for (; index + 5 <= n; index += 5) {
|
|
1410
|
+
__m256i a_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
|
|
1411
|
+
__m256i b_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
|
|
1412
|
+
__m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
|
|
1413
|
+
__m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
|
|
1414
|
+
__m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
|
|
1415
|
+
__m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
|
|
1416
|
+
sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
|
|
1417
|
+
sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
|
|
1418
|
+
norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
|
|
1419
|
+
norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
|
|
1420
|
+
product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
|
|
1421
|
+
product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
|
|
1422
|
+
product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
|
|
1937
1423
|
}
|
|
1938
1424
|
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
|
|
1942
|
-
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
|
|
1947
|
-
|
|
1948
|
-
|
|
1949
|
-
|
|
1950
|
-
|
|
1951
|
-
|
|
1952
|
-
|
|
1953
|
-
|
|
1954
|
-
cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
|
|
1955
|
-
cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
|
|
1956
|
-
cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
|
|
1957
|
-
cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
|
|
1958
|
-
cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
|
|
1959
|
-
cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
|
|
1960
|
-
cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
|
|
1425
|
+
if (index < n) {
|
|
1426
|
+
__mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
|
|
1427
|
+
__m256i a_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
|
|
1428
|
+
__m256i b_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
|
|
1429
|
+
__m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
|
|
1430
|
+
__m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
|
|
1431
|
+
__m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
|
|
1432
|
+
__m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
|
|
1433
|
+
sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
|
|
1434
|
+
sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
|
|
1435
|
+
norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
|
|
1436
|
+
norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
|
|
1437
|
+
product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
|
|
1438
|
+
product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
|
|
1439
|
+
product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
|
|
1961
1440
|
}
|
|
1962
1441
|
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
|
|
1966
|
-
|
|
1967
|
-
|
|
1968
|
-
nk_f32_t
|
|
1969
|
-
nk_f32_t
|
|
1970
|
-
nk_f32_t
|
|
1971
|
-
nk_f32_t
|
|
1972
|
-
nk_f32_t
|
|
1973
|
-
nk_f32_t
|
|
1974
|
-
nk_f32_t
|
|
1975
|
-
nk_f32_t
|
|
1976
|
-
|
|
1977
|
-
nk_f32_t
|
|
1442
|
+
// Per-channel demux via mask-reduce on the fp32 accumulators (lane i carries channel i%3).
|
|
1443
|
+
__mmask16 const mask_channel_x_f32 = 0x1249; // lanes {0, 3, 6, 9, 12}
|
|
1444
|
+
__mmask16 const mask_channel_y_f32 = 0x2492; // lanes {1, 4, 7, 10, 13}
|
|
1445
|
+
__mmask16 const mask_channel_z_f32 = 0x4924; // lanes {2, 5, 8, 11, 14}
|
|
1446
|
+
|
|
1447
|
+
nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
|
|
1448
|
+
nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
|
|
1449
|
+
nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
|
|
1450
|
+
nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
|
|
1451
|
+
nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
|
|
1452
|
+
nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
|
|
1453
|
+
nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
|
|
1454
|
+
nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
|
|
1455
|
+
|
|
1456
|
+
nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
|
|
1457
|
+
nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
|
|
1458
|
+
nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
|
|
1459
|
+
nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
|
|
1460
|
+
nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
|
|
1461
|
+
nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
|
|
1462
|
+
nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
|
|
1463
|
+
nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
|
|
1464
|
+
nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);
|
|
1978
1465
|
|
|
1979
1466
|
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
1980
1467
|
nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
1981
1468
|
nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
1982
|
-
|
|
1983
1469
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
1984
1470
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
1985
1471
|
|
|
1986
|
-
|
|
1987
|
-
|
|
1988
|
-
|
|
1989
|
-
|
|
1990
|
-
|
|
1991
|
-
|
|
1992
|
-
|
|
1993
|
-
|
|
1994
|
-
|
|
1995
|
-
|
|
1996
|
-
|
|
1997
|
-
|
|
1998
|
-
nk_f32_t
|
|
1999
|
-
nk_svd3x3_f32_(cross_covariance,
|
|
2000
|
-
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2004
|
-
|
|
2005
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
2006
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
2007
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
2008
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
2009
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
2010
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
2011
|
-
|
|
2012
|
-
if (nk_det3x3_f32_(r) < 0) {
|
|
2013
|
-
svd_v[2] = -svd_v[2], svd_v[5] = -svd_v[5], svd_v[8] = -svd_v[8];
|
|
2014
|
-
r[0] = svd_v[0] * svd_u[0] + svd_v[1] * svd_u[1] + svd_v[2] * svd_u[2];
|
|
2015
|
-
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
2016
|
-
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
2017
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
2018
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
2019
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
2020
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
2021
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
2022
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1472
|
+
// Parallel-axis correction.
|
|
1473
|
+
nk_f32_t cross_covariance[9];
|
|
1474
|
+
cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1475
|
+
cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
1476
|
+
cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
|
|
1477
|
+
cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
|
|
1478
|
+
cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
|
|
1479
|
+
cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
|
|
1480
|
+
cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
|
|
1481
|
+
cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1482
|
+
cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
1483
|
+
|
|
1484
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1485
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1486
|
+
nk_f32_t optimal_rotation[9];
|
|
1487
|
+
nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
|
|
1488
|
+
if (nk_det3x3_f32_(optimal_rotation) < 0) {
|
|
1489
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1490
|
+
nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
|
|
2023
1491
|
}
|
|
2024
|
-
|
|
2025
1492
|
if (rotation)
|
|
2026
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
1493
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
2027
1494
|
if (scale) *scale = 1.0f;
|
|
2028
1495
|
|
|
2029
|
-
|
|
2030
|
-
|
|
1496
|
+
// Folded SSD via trace identity:
|
|
1497
|
+
// SSD = ‖a-ā‖² + ‖b-b̄‖² − 2·trace(R · H_centered)
|
|
1498
|
+
// trace(R · H_centered) = Σⱼₖ R[j,k] · H[k,j] (note transpose on H).
|
|
1499
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a -
|
|
1500
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1501
|
+
centroid_a_z * centroid_a_z);
|
|
1502
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b -
|
|
1503
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1504
|
+
centroid_b_z * centroid_b_z);
|
|
1505
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
1506
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
1507
|
+
nk_f32_t trace_rotation_covariance =
|
|
1508
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1509
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1510
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1511
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1512
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1513
|
+
nk_f32_t sum_squared = centered_norm_squared_a + centered_norm_squared_b - 2.0f * trace_rotation_covariance;
|
|
1514
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
2031
1515
|
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
2032
1516
|
}
|
|
2033
1517
|
|
|
2034
1518
|
NK_PUBLIC void nk_umeyama_f16_skylake(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
2035
1519
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
2036
|
-
|
|
1520
|
+
// Same 15-lane streaming-stats pattern as kabsch_f16_skylake; adds the Umeyama scale.
|
|
1521
|
+
__m512i const idx_rotation_1_i32x16 = _mm512_setr_epi32(1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9, 13, 14, 12, 15);
|
|
1522
|
+
__m512i const idx_rotation_2_i32x16 = _mm512_setr_epi32(2, 0, 1, 5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 15);
|
|
2037
1523
|
|
|
2038
|
-
__m512
|
|
2039
|
-
__m512
|
|
2040
|
-
__m512
|
|
2041
|
-
__m512
|
|
2042
|
-
__m512
|
|
2043
|
-
__m512
|
|
2044
|
-
|
|
2045
|
-
nk_size_t i = 0;
|
|
2046
|
-
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
1524
|
+
__m512 const zeros_f32x16 = _mm512_setzero_ps();
|
|
1525
|
+
__m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
|
|
1526
|
+
__m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
|
|
1527
|
+
__m512 product_diagonal_f32x16 = zeros_f32x16;
|
|
1528
|
+
__m512 product_rotation_1_f32x16 = zeros_f32x16;
|
|
1529
|
+
__m512 product_rotation_2_f32x16 = zeros_f32x16;
|
|
2047
1530
|
|
|
2048
|
-
|
|
2049
|
-
|
|
2050
|
-
|
|
2051
|
-
|
|
2052
|
-
|
|
2053
|
-
|
|
2054
|
-
|
|
2055
|
-
|
|
2056
|
-
|
|
2057
|
-
|
|
2058
|
-
|
|
2059
|
-
|
|
2060
|
-
|
|
2061
|
-
|
|
2062
|
-
|
|
2063
|
-
cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
|
|
2064
|
-
cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
|
|
2065
|
-
cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
|
|
2066
|
-
cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
|
|
2067
|
-
cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
|
|
2068
|
-
|
|
2069
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
|
|
2070
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
|
|
2071
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
|
|
1531
|
+
nk_size_t index = 0;
|
|
1532
|
+
for (; index + 5 <= n; index += 5) {
|
|
1533
|
+
__m256i a_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
|
|
1534
|
+
__m256i b_f16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
|
|
1535
|
+
__m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
|
|
1536
|
+
__m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
|
|
1537
|
+
__m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
|
|
1538
|
+
__m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
|
|
1539
|
+
sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
|
|
1540
|
+
sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
|
|
1541
|
+
norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
|
|
1542
|
+
norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
|
|
1543
|
+
product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
|
|
1544
|
+
product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
|
|
1545
|
+
product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
|
|
2072
1546
|
}
|
|
2073
1547
|
|
|
2074
|
-
|
|
2075
|
-
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2080
|
-
|
|
2081
|
-
|
|
2082
|
-
|
|
2083
|
-
|
|
2084
|
-
|
|
2085
|
-
|
|
2086
|
-
|
|
2087
|
-
|
|
2088
|
-
|
|
2089
|
-
cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
|
|
2090
|
-
cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
|
|
2091
|
-
cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
|
|
2092
|
-
cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
|
|
2093
|
-
cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
|
|
2094
|
-
cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
|
|
2095
|
-
cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
|
|
2096
|
-
|
|
2097
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
|
|
2098
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
|
|
2099
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
|
|
1548
|
+
if (index < n) {
|
|
1549
|
+
__mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
|
|
1550
|
+
__m256i a_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
|
|
1551
|
+
__m256i b_f16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
|
|
1552
|
+
__m512 a_f32x16 = _mm512_cvtph_ps(a_f16x16);
|
|
1553
|
+
__m512 b_f32x16 = _mm512_cvtph_ps(b_f16x16);
|
|
1554
|
+
__m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
|
|
1555
|
+
__m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
|
|
1556
|
+
sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
|
|
1557
|
+
sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
|
|
1558
|
+
norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
|
|
1559
|
+
norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
|
|
1560
|
+
product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
|
|
1561
|
+
product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
|
|
1562
|
+
product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
|
|
2100
1563
|
}
|
|
2101
1564
|
|
|
2102
|
-
|
|
2103
|
-
|
|
2104
|
-
|
|
2105
|
-
|
|
2106
|
-
nk_f32_t
|
|
2107
|
-
nk_f32_t
|
|
2108
|
-
nk_f32_t
|
|
2109
|
-
nk_f32_t
|
|
2110
|
-
nk_f32_t
|
|
2111
|
-
nk_f32_t
|
|
2112
|
-
nk_f32_t
|
|
2113
|
-
nk_f32_t
|
|
2114
|
-
|
|
2115
|
-
nk_f32_t
|
|
2116
|
-
nk_f32_t
|
|
2117
|
-
nk_f32_t
|
|
1565
|
+
__mmask16 const mask_channel_x_f32 = 0x1249;
|
|
1566
|
+
__mmask16 const mask_channel_y_f32 = 0x2492;
|
|
1567
|
+
__mmask16 const mask_channel_z_f32 = 0x4924;
|
|
1568
|
+
|
|
1569
|
+
nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
|
|
1570
|
+
nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
|
|
1571
|
+
nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
|
|
1572
|
+
nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
|
|
1573
|
+
nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
|
|
1574
|
+
nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
|
|
1575
|
+
nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
|
|
1576
|
+
nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
|
|
1577
|
+
|
|
1578
|
+
nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
|
|
1579
|
+
nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
|
|
1580
|
+
nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
|
|
1581
|
+
nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
|
|
1582
|
+
nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
|
|
1583
|
+
nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
|
|
1584
|
+
nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
|
|
1585
|
+
nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
|
|
1586
|
+
nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);
|
|
2118
1587
|
|
|
2119
1588
|
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
2120
1589
|
nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
2121
1590
|
nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
2122
|
-
|
|
2123
1591
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
2124
1592
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
2125
1593
|
|
|
2126
|
-
nk_f32_t
|
|
2127
|
-
|
|
2128
|
-
|
|
2129
|
-
|
|
2130
|
-
|
|
2131
|
-
|
|
2132
|
-
|
|
2133
|
-
|
|
2134
|
-
|
|
2135
|
-
|
|
2136
|
-
|
|
2137
|
-
|
|
2138
|
-
|
|
2139
|
-
|
|
2140
|
-
|
|
2141
|
-
|
|
2142
|
-
|
|
2143
|
-
|
|
2144
|
-
|
|
2145
|
-
|
|
2146
|
-
|
|
2147
|
-
|
|
2148
|
-
|
|
2149
|
-
|
|
2150
|
-
|
|
2151
|
-
|
|
2152
|
-
|
|
2153
|
-
|
|
2154
|
-
|
|
2155
|
-
|
|
2156
|
-
nk_f32_t det = nk_det3x3_f32_(r);
|
|
2157
|
-
nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
|
|
2158
|
-
nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
|
|
2159
|
-
nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
|
|
1594
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a -
|
|
1595
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1596
|
+
centroid_a_z * centroid_a_z);
|
|
1597
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b -
|
|
1598
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1599
|
+
centroid_b_z * centroid_b_z);
|
|
1600
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
1601
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
1602
|
+
|
|
1603
|
+
nk_f32_t cross_covariance[9];
|
|
1604
|
+
cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1605
|
+
cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
1606
|
+
cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
|
|
1607
|
+
cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
|
|
1608
|
+
cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
|
|
1609
|
+
cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
|
|
1610
|
+
cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
|
|
1611
|
+
cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1612
|
+
cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
1613
|
+
|
|
1614
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1615
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1616
|
+
nk_f32_t optimal_rotation[9];
|
|
1617
|
+
nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
|
|
1618
|
+
|
|
1619
|
+
// Scale factor: c = trace(D · S) / ‖a-ā‖², with reflection sign via d3.
|
|
1620
|
+
nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
|
|
1621
|
+
nk_f32_t d3 = det < 0.0f ? -1.0f : 1.0f;
|
|
1622
|
+
nk_f32_t trace_ds = nk_sum_three_products_f32_(svd_diagonal[0], 1.0f, svd_diagonal[4], 1.0f, svd_diagonal[8], d3);
|
|
1623
|
+
nk_f32_t c = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
|
|
2160
1624
|
if (scale) *scale = c;
|
|
2161
1625
|
|
|
2162
|
-
if (det < 0) {
|
|
2163
|
-
|
|
2164
|
-
|
|
2165
|
-
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
2166
|
-
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
2167
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
2168
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
2169
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
2170
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
2171
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
2172
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1626
|
+
if (det < 0.0f) {
|
|
1627
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1628
|
+
nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
|
|
2173
1629
|
}
|
|
2174
|
-
|
|
2175
1630
|
if (rotation)
|
|
2176
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
2177
|
-
|
|
2178
|
-
|
|
2179
|
-
|
|
1631
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
1632
|
+
|
|
1633
|
+
// Folded SSD with scale:
|
|
1634
|
+
// SSD = c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
|
|
1635
|
+
nk_f32_t trace_rotation_covariance =
|
|
1636
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1637
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1638
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1639
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1640
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1641
|
+
nk_f32_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
|
|
1642
|
+
2.0f * c * trace_rotation_covariance;
|
|
1643
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
2180
1644
|
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
2181
1645
|
}
|
|
2182
1646
|
|
|
2183
1647
|
NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
2184
1648
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
2185
|
-
|
|
1649
|
+
// Same 15-lane streaming-stats pattern as kabsch_bf16_skylake; adds the Umeyama scale.
|
|
1650
|
+
__m512i const idx_rotation_1_i32x16 = _mm512_setr_epi32(1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9, 13, 14, 12, 15);
|
|
1651
|
+
__m512i const idx_rotation_2_i32x16 = _mm512_setr_epi32(2, 0, 1, 5, 3, 4, 8, 6, 7, 11, 9, 10, 14, 12, 13, 15);
|
|
2186
1652
|
|
|
2187
|
-
__m512
|
|
2188
|
-
__m512
|
|
2189
|
-
__m512
|
|
2190
|
-
__m512
|
|
2191
|
-
__m512
|
|
2192
|
-
__m512
|
|
2193
|
-
|
|
2194
|
-
nk_size_t i = 0;
|
|
2195
|
-
__m512 a_x_f32x16, a_y_f32x16, a_z_f32x16, b_x_f32x16, b_y_f32x16, b_z_f32x16;
|
|
1653
|
+
__m512 const zeros_f32x16 = _mm512_setzero_ps();
|
|
1654
|
+
__m512 sum_a_f32x16 = zeros_f32x16, sum_b_f32x16 = zeros_f32x16;
|
|
1655
|
+
__m512 norm_squared_a_f32x16 = zeros_f32x16, norm_squared_b_f32x16 = zeros_f32x16;
|
|
1656
|
+
__m512 product_diagonal_f32x16 = zeros_f32x16;
|
|
1657
|
+
__m512 product_rotation_1_f32x16 = zeros_f32x16;
|
|
1658
|
+
__m512 product_rotation_2_f32x16 = zeros_f32x16;
|
|
2196
1659
|
|
|
2197
|
-
|
|
2198
|
-
|
|
2199
|
-
|
|
2200
|
-
|
|
2201
|
-
|
|
2202
|
-
|
|
2203
|
-
|
|
2204
|
-
|
|
2205
|
-
|
|
2206
|
-
|
|
2207
|
-
|
|
2208
|
-
|
|
2209
|
-
|
|
2210
|
-
|
|
2211
|
-
|
|
2212
|
-
cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
|
|
2213
|
-
cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
|
|
2214
|
-
cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
|
|
2215
|
-
cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
|
|
2216
|
-
cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
|
|
2217
|
-
|
|
2218
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
|
|
2219
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
|
|
2220
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
|
|
1660
|
+
nk_size_t index = 0;
|
|
1661
|
+
for (; index + 5 <= n; index += 5) {
|
|
1662
|
+
__m256i a_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(a + index * 3));
|
|
1663
|
+
__m256i b_bf16x16 = _mm256_maskz_loadu_epi16(0x7FFF, (__m256i const *)(b + index * 3));
|
|
1664
|
+
__m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
|
|
1665
|
+
__m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
|
|
1666
|
+
__m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
|
|
1667
|
+
__m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
|
|
1668
|
+
sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
|
|
1669
|
+
sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
|
|
1670
|
+
norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
|
|
1671
|
+
norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
|
|
1672
|
+
product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
|
|
1673
|
+
product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
|
|
1674
|
+
product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
|
|
2221
1675
|
}
|
|
2222
1676
|
|
|
2223
|
-
|
|
2224
|
-
|
|
2225
|
-
|
|
2226
|
-
|
|
2227
|
-
|
|
2228
|
-
|
|
2229
|
-
|
|
2230
|
-
|
|
2231
|
-
|
|
2232
|
-
|
|
2233
|
-
|
|
2234
|
-
|
|
2235
|
-
|
|
2236
|
-
|
|
2237
|
-
|
|
2238
|
-
cov_xz_f32x16 = _mm512_fmadd_ps(a_x_f32x16, b_z_f32x16, cov_xz_f32x16);
|
|
2239
|
-
cov_yx_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_x_f32x16, cov_yx_f32x16);
|
|
2240
|
-
cov_yy_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_y_f32x16, cov_yy_f32x16);
|
|
2241
|
-
cov_yz_f32x16 = _mm512_fmadd_ps(a_y_f32x16, b_z_f32x16, cov_yz_f32x16);
|
|
2242
|
-
cov_zx_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_x_f32x16, cov_zx_f32x16);
|
|
2243
|
-
cov_zy_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_y_f32x16, cov_zy_f32x16);
|
|
2244
|
-
cov_zz_f32x16 = _mm512_fmadd_ps(a_z_f32x16, b_z_f32x16, cov_zz_f32x16);
|
|
2245
|
-
|
|
2246
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_x_f32x16, a_x_f32x16, variance_a_f32x16);
|
|
2247
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_y_f32x16, a_y_f32x16, variance_a_f32x16);
|
|
2248
|
-
variance_a_f32x16 = _mm512_fmadd_ps(a_z_f32x16, a_z_f32x16, variance_a_f32x16);
|
|
1677
|
+
if (index < n) {
|
|
1678
|
+
__mmask16 tail_mask = (__mmask16)_bzhi_u32(0x7FFF, (nk_u32_t)((n - index) * 3));
|
|
1679
|
+
__m256i a_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(a + index * 3));
|
|
1680
|
+
__m256i b_bf16x16 = _mm256_maskz_loadu_epi16(tail_mask, (__m256i const *)(b + index * 3));
|
|
1681
|
+
__m512 a_f32x16 = nk_bf16x16_to_f32x16_skylake_(a_bf16x16);
|
|
1682
|
+
__m512 b_f32x16 = nk_bf16x16_to_f32x16_skylake_(b_bf16x16);
|
|
1683
|
+
__m512 b_rotation_1_f32x16 = _mm512_permutexvar_ps(idx_rotation_1_i32x16, b_f32x16);
|
|
1684
|
+
__m512 b_rotation_2_f32x16 = _mm512_permutexvar_ps(idx_rotation_2_i32x16, b_f32x16);
|
|
1685
|
+
sum_a_f32x16 = _mm512_add_ps(sum_a_f32x16, a_f32x16);
|
|
1686
|
+
sum_b_f32x16 = _mm512_add_ps(sum_b_f32x16, b_f32x16);
|
|
1687
|
+
norm_squared_a_f32x16 = _mm512_fmadd_ps(a_f32x16, a_f32x16, norm_squared_a_f32x16);
|
|
1688
|
+
norm_squared_b_f32x16 = _mm512_fmadd_ps(b_f32x16, b_f32x16, norm_squared_b_f32x16);
|
|
1689
|
+
product_diagonal_f32x16 = _mm512_fmadd_ps(a_f32x16, b_f32x16, product_diagonal_f32x16);
|
|
1690
|
+
product_rotation_1_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_1_f32x16, product_rotation_1_f32x16);
|
|
1691
|
+
product_rotation_2_f32x16 = _mm512_fmadd_ps(a_f32x16, b_rotation_2_f32x16, product_rotation_2_f32x16);
|
|
2249
1692
|
}
|
|
2250
1693
|
|
|
2251
|
-
|
|
2252
|
-
|
|
2253
|
-
|
|
2254
|
-
|
|
2255
|
-
nk_f32_t
|
|
2256
|
-
nk_f32_t
|
|
2257
|
-
nk_f32_t
|
|
2258
|
-
nk_f32_t
|
|
2259
|
-
nk_f32_t
|
|
2260
|
-
nk_f32_t
|
|
2261
|
-
nk_f32_t
|
|
2262
|
-
nk_f32_t
|
|
2263
|
-
|
|
2264
|
-
nk_f32_t
|
|
2265
|
-
nk_f32_t
|
|
2266
|
-
nk_f32_t
|
|
1694
|
+
__mmask16 const mask_channel_x_f32 = 0x1249;
|
|
1695
|
+
__mmask16 const mask_channel_y_f32 = 0x2492;
|
|
1696
|
+
__mmask16 const mask_channel_z_f32 = 0x4924;
|
|
1697
|
+
|
|
1698
|
+
nk_f32_t sum_a_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_a_f32x16);
|
|
1699
|
+
nk_f32_t sum_a_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_a_f32x16);
|
|
1700
|
+
nk_f32_t sum_a_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_a_f32x16);
|
|
1701
|
+
nk_f32_t sum_b_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, sum_b_f32x16);
|
|
1702
|
+
nk_f32_t sum_b_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, sum_b_f32x16);
|
|
1703
|
+
nk_f32_t sum_b_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, sum_b_f32x16);
|
|
1704
|
+
nk_f32_t norm_squared_a = _mm512_reduce_add_ps(norm_squared_a_f32x16);
|
|
1705
|
+
nk_f32_t norm_squared_b = _mm512_reduce_add_ps(norm_squared_b_f32x16);
|
|
1706
|
+
|
|
1707
|
+
nk_f32_t covariance_x_x = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_diagonal_f32x16);
|
|
1708
|
+
nk_f32_t covariance_x_y = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_1_f32x16);
|
|
1709
|
+
nk_f32_t covariance_x_z = _mm512_mask_reduce_add_ps(mask_channel_x_f32, product_rotation_2_f32x16);
|
|
1710
|
+
nk_f32_t covariance_y_x = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_2_f32x16);
|
|
1711
|
+
nk_f32_t covariance_y_y = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_diagonal_f32x16);
|
|
1712
|
+
nk_f32_t covariance_y_z = _mm512_mask_reduce_add_ps(mask_channel_y_f32, product_rotation_1_f32x16);
|
|
1713
|
+
nk_f32_t covariance_z_x = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_1_f32x16);
|
|
1714
|
+
nk_f32_t covariance_z_y = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_rotation_2_f32x16);
|
|
1715
|
+
nk_f32_t covariance_z_z = _mm512_mask_reduce_add_ps(mask_channel_z_f32, product_diagonal_f32x16);
|
|
2267
1716
|
|
|
2268
1717
|
nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
|
|
2269
1718
|
nk_f32_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
|
|
2270
1719
|
nk_f32_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
|
|
2271
|
-
|
|
2272
1720
|
if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
|
|
2273
1721
|
if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
|
|
2274
1722
|
|
|
2275
|
-
nk_f32_t
|
|
2276
|
-
|
|
2277
|
-
|
|
2278
|
-
|
|
2279
|
-
|
|
2280
|
-
|
|
2281
|
-
|
|
2282
|
-
|
|
2283
|
-
|
|
2284
|
-
|
|
2285
|
-
|
|
2286
|
-
|
|
2287
|
-
|
|
2288
|
-
|
|
2289
|
-
|
|
2290
|
-
|
|
2291
|
-
|
|
2292
|
-
|
|
2293
|
-
|
|
2294
|
-
|
|
2295
|
-
|
|
2296
|
-
|
|
2297
|
-
|
|
2298
|
-
|
|
2299
|
-
|
|
2300
|
-
|
|
2301
|
-
|
|
2302
|
-
|
|
2303
|
-
|
|
2304
|
-
|
|
2305
|
-
nk_f32_t det = nk_det3x3_f32_(r);
|
|
2306
|
-
nk_f32_t d3 = det < 0 ? -1.0f : 1.0f;
|
|
2307
|
-
nk_f32_t trace_ds = svd_s[0] + svd_s[4] + d3 * svd_s[8];
|
|
2308
|
-
nk_f32_t c = trace_ds / ((nk_f32_t)n * variance_a);
|
|
1723
|
+
nk_f32_t centered_norm_squared_a = norm_squared_a -
|
|
1724
|
+
(nk_f32_t)n * (centroid_a_x * centroid_a_x + centroid_a_y * centroid_a_y +
|
|
1725
|
+
centroid_a_z * centroid_a_z);
|
|
1726
|
+
nk_f32_t centered_norm_squared_b = norm_squared_b -
|
|
1727
|
+
(nk_f32_t)n * (centroid_b_x * centroid_b_x + centroid_b_y * centroid_b_y +
|
|
1728
|
+
centroid_b_z * centroid_b_z);
|
|
1729
|
+
if (centered_norm_squared_a < 0.0f) centered_norm_squared_a = 0.0f;
|
|
1730
|
+
if (centered_norm_squared_b < 0.0f) centered_norm_squared_b = 0.0f;
|
|
1731
|
+
|
|
1732
|
+
nk_f32_t cross_covariance[9];
|
|
1733
|
+
cross_covariance[0] = covariance_x_x - (nk_f32_t)n * centroid_a_x * centroid_b_x;
|
|
1734
|
+
cross_covariance[1] = covariance_x_y - (nk_f32_t)n * centroid_a_x * centroid_b_y;
|
|
1735
|
+
cross_covariance[2] = covariance_x_z - (nk_f32_t)n * centroid_a_x * centroid_b_z;
|
|
1736
|
+
cross_covariance[3] = covariance_y_x - (nk_f32_t)n * centroid_a_y * centroid_b_x;
|
|
1737
|
+
cross_covariance[4] = covariance_y_y - (nk_f32_t)n * centroid_a_y * centroid_b_y;
|
|
1738
|
+
cross_covariance[5] = covariance_y_z - (nk_f32_t)n * centroid_a_y * centroid_b_z;
|
|
1739
|
+
cross_covariance[6] = covariance_z_x - (nk_f32_t)n * centroid_a_z * centroid_b_x;
|
|
1740
|
+
cross_covariance[7] = covariance_z_y - (nk_f32_t)n * centroid_a_z * centroid_b_y;
|
|
1741
|
+
cross_covariance[8] = covariance_z_z - (nk_f32_t)n * centroid_a_z * centroid_b_z;
|
|
1742
|
+
|
|
1743
|
+
nk_f32_t svd_left[9], svd_diagonal[9], svd_right[9];
|
|
1744
|
+
nk_svd3x3_f32_(cross_covariance, svd_left, svd_diagonal, svd_right);
|
|
1745
|
+
nk_f32_t optimal_rotation[9];
|
|
1746
|
+
nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
|
|
1747
|
+
|
|
1748
|
+
// Scale factor: c = trace(D · S) / ‖a-ā‖², with reflection sign via d3.
|
|
1749
|
+
nk_f32_t det = nk_det3x3_f32_(optimal_rotation);
|
|
1750
|
+
nk_f32_t d3 = det < 0.0f ? -1.0f : 1.0f;
|
|
1751
|
+
nk_f32_t trace_ds = nk_sum_three_products_f32_(svd_diagonal[0], 1.0f, svd_diagonal[4], 1.0f, svd_diagonal[8], d3);
|
|
1752
|
+
nk_f32_t c = centered_norm_squared_a > 0.0f ? trace_ds / centered_norm_squared_a : 0.0f;
|
|
2309
1753
|
if (scale) *scale = c;
|
|
2310
1754
|
|
|
2311
|
-
if (det < 0) {
|
|
2312
|
-
|
|
2313
|
-
|
|
2314
|
-
r[1] = svd_v[0] * svd_u[3] + svd_v[1] * svd_u[4] + svd_v[2] * svd_u[5];
|
|
2315
|
-
r[2] = svd_v[0] * svd_u[6] + svd_v[1] * svd_u[7] + svd_v[2] * svd_u[8];
|
|
2316
|
-
r[3] = svd_v[3] * svd_u[0] + svd_v[4] * svd_u[1] + svd_v[5] * svd_u[2];
|
|
2317
|
-
r[4] = svd_v[3] * svd_u[3] + svd_v[4] * svd_u[4] + svd_v[5] * svd_u[5];
|
|
2318
|
-
r[5] = svd_v[3] * svd_u[6] + svd_v[4] * svd_u[7] + svd_v[5] * svd_u[8];
|
|
2319
|
-
r[6] = svd_v[6] * svd_u[0] + svd_v[7] * svd_u[1] + svd_v[8] * svd_u[2];
|
|
2320
|
-
r[7] = svd_v[6] * svd_u[3] + svd_v[7] * svd_u[4] + svd_v[8] * svd_u[5];
|
|
2321
|
-
r[8] = svd_v[6] * svd_u[6] + svd_v[7] * svd_u[7] + svd_v[8] * svd_u[8];
|
|
1755
|
+
if (det < 0.0f) {
|
|
1756
|
+
svd_right[2] = -svd_right[2], svd_right[5] = -svd_right[5], svd_right[8] = -svd_right[8];
|
|
1757
|
+
nk_rotation_from_svd_f32_serial_(svd_left, svd_right, optimal_rotation);
|
|
2322
1758
|
}
|
|
2323
|
-
|
|
2324
1759
|
if (rotation)
|
|
2325
|
-
for (int j = 0; j < 9; ++j) rotation[j] =
|
|
2326
|
-
|
|
2327
|
-
|
|
2328
|
-
|
|
1760
|
+
for (int j = 0; j < 9; ++j) rotation[j] = optimal_rotation[j];
|
|
1761
|
+
|
|
1762
|
+
// Folded SSD with scale:
|
|
1763
|
+
// SSD = c²·‖a-ā‖² + ‖b-b̄‖² − 2c·trace(R · H_centered).
|
|
1764
|
+
nk_f32_t trace_rotation_covariance =
|
|
1765
|
+
optimal_rotation[0] * cross_covariance[0] + optimal_rotation[1] * cross_covariance[3] +
|
|
1766
|
+
optimal_rotation[2] * cross_covariance[6] + optimal_rotation[3] * cross_covariance[1] +
|
|
1767
|
+
optimal_rotation[4] * cross_covariance[4] + optimal_rotation[5] * cross_covariance[7] +
|
|
1768
|
+
optimal_rotation[6] * cross_covariance[2] + optimal_rotation[7] * cross_covariance[5] +
|
|
1769
|
+
optimal_rotation[8] * cross_covariance[8];
|
|
1770
|
+
nk_f32_t sum_squared = c * c * centered_norm_squared_a + centered_norm_squared_b -
|
|
1771
|
+
2.0f * c * trace_rotation_covariance;
|
|
1772
|
+
if (sum_squared < 0.0f) sum_squared = 0.0f;
|
|
2329
1773
|
*result = nk_f32_sqrt_haswell(sum_squared * inv_n);
|
|
2330
1774
|
}
|
|
2331
1775
|
|