numkong 7.4.3 → 7.4.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -353,18 +353,14 @@ NK_INTERNAL nk_f64_t nk_transformed_ssd_f64_neon_(nk_f64_t const *a, nk_f64_t co
353
353
 
354
354
  NK_PUBLIC void nk_rmsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
355
355
  nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f64_t *result) {
356
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
357
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
356
358
  if (rotation)
357
359
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
358
360
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
359
361
  if (scale) *scale = 1.0f;
360
362
 
361
363
  float64x2_t zero_f64x2 = vdupq_n_f64(0.0);
362
- float64x2_t sum_a_x_low_f64x2 = zero_f64x2, sum_a_x_high_f64x2 = zero_f64x2;
363
- float64x2_t sum_a_y_low_f64x2 = zero_f64x2, sum_a_y_high_f64x2 = zero_f64x2;
364
- float64x2_t sum_a_z_low_f64x2 = zero_f64x2, sum_a_z_high_f64x2 = zero_f64x2;
365
- float64x2_t sum_b_x_low_f64x2 = zero_f64x2, sum_b_x_high_f64x2 = zero_f64x2;
366
- float64x2_t sum_b_y_low_f64x2 = zero_f64x2, sum_b_y_high_f64x2 = zero_f64x2;
367
- float64x2_t sum_b_z_low_f64x2 = zero_f64x2, sum_b_z_high_f64x2 = zero_f64x2;
368
364
  float64x2_t sum_squared_x_low_f64x2 = zero_f64x2, sum_squared_x_high_f64x2 = zero_f64x2;
369
365
  float64x2_t sum_squared_y_low_f64x2 = zero_f64x2, sum_squared_y_high_f64x2 = zero_f64x2;
370
366
  float64x2_t sum_squared_z_low_f64x2 = zero_f64x2, sum_squared_z_high_f64x2 = zero_f64x2;
@@ -375,38 +371,15 @@ NK_PUBLIC void nk_rmsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t
375
371
  nk_deinterleave_f32x4_neon_(a + index * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4),
376
372
  nk_deinterleave_f32x4_neon_(b + index * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
377
373
 
378
- float64x2_t a_x_low_f64x2 = vcvt_f64_f32(vget_low_f32(a_x_f32x4));
379
- float64x2_t a_x_high_f64x2 = vcvt_high_f64_f32(a_x_f32x4);
380
- float64x2_t a_y_low_f64x2 = vcvt_f64_f32(vget_low_f32(a_y_f32x4));
381
- float64x2_t a_y_high_f64x2 = vcvt_high_f64_f32(a_y_f32x4);
382
- float64x2_t a_z_low_f64x2 = vcvt_f64_f32(vget_low_f32(a_z_f32x4));
383
- float64x2_t a_z_high_f64x2 = vcvt_high_f64_f32(a_z_f32x4);
384
- float64x2_t b_x_low_f64x2 = vcvt_f64_f32(vget_low_f32(b_x_f32x4));
385
- float64x2_t b_x_high_f64x2 = vcvt_high_f64_f32(b_x_f32x4);
386
- float64x2_t b_y_low_f64x2 = vcvt_f64_f32(vget_low_f32(b_y_f32x4));
387
- float64x2_t b_y_high_f64x2 = vcvt_high_f64_f32(b_y_f32x4);
388
- float64x2_t b_z_low_f64x2 = vcvt_f64_f32(vget_low_f32(b_z_f32x4));
389
- float64x2_t b_z_high_f64x2 = vcvt_high_f64_f32(b_z_f32x4);
390
-
391
- sum_a_x_low_f64x2 = vaddq_f64(sum_a_x_low_f64x2, a_x_low_f64x2),
392
- sum_a_x_high_f64x2 = vaddq_f64(sum_a_x_high_f64x2, a_x_high_f64x2);
393
- sum_a_y_low_f64x2 = vaddq_f64(sum_a_y_low_f64x2, a_y_low_f64x2),
394
- sum_a_y_high_f64x2 = vaddq_f64(sum_a_y_high_f64x2, a_y_high_f64x2);
395
- sum_a_z_low_f64x2 = vaddq_f64(sum_a_z_low_f64x2, a_z_low_f64x2),
396
- sum_a_z_high_f64x2 = vaddq_f64(sum_a_z_high_f64x2, a_z_high_f64x2);
397
- sum_b_x_low_f64x2 = vaddq_f64(sum_b_x_low_f64x2, b_x_low_f64x2),
398
- sum_b_x_high_f64x2 = vaddq_f64(sum_b_x_high_f64x2, b_x_high_f64x2);
399
- sum_b_y_low_f64x2 = vaddq_f64(sum_b_y_low_f64x2, b_y_low_f64x2),
400
- sum_b_y_high_f64x2 = vaddq_f64(sum_b_y_high_f64x2, b_y_high_f64x2);
401
- sum_b_z_low_f64x2 = vaddq_f64(sum_b_z_low_f64x2, b_z_low_f64x2),
402
- sum_b_z_high_f64x2 = vaddq_f64(sum_b_z_high_f64x2, b_z_high_f64x2);
403
-
404
- float64x2_t delta_x_low_f64x2 = vsubq_f64(a_x_low_f64x2, b_x_low_f64x2);
405
- float64x2_t delta_x_high_f64x2 = vsubq_f64(a_x_high_f64x2, b_x_high_f64x2);
406
- float64x2_t delta_y_low_f64x2 = vsubq_f64(a_y_low_f64x2, b_y_low_f64x2);
407
- float64x2_t delta_y_high_f64x2 = vsubq_f64(a_y_high_f64x2, b_y_high_f64x2);
408
- float64x2_t delta_z_low_f64x2 = vsubq_f64(a_z_low_f64x2, b_z_low_f64x2);
409
- float64x2_t delta_z_high_f64x2 = vsubq_f64(a_z_high_f64x2, b_z_high_f64x2);
374
+ float64x2_t delta_x_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_x_f32x4)),
375
+ vcvt_f64_f32(vget_low_f32(b_x_f32x4)));
376
+ float64x2_t delta_x_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_x_f32x4), vcvt_high_f64_f32(b_x_f32x4));
377
+ float64x2_t delta_y_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_y_f32x4)),
378
+ vcvt_f64_f32(vget_low_f32(b_y_f32x4)));
379
+ float64x2_t delta_y_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_y_f32x4), vcvt_high_f64_f32(b_y_f32x4));
380
+ float64x2_t delta_z_low_f64x2 = vsubq_f64(vcvt_f64_f32(vget_low_f32(a_z_f32x4)),
381
+ vcvt_f64_f32(vget_low_f32(b_z_f32x4)));
382
+ float64x2_t delta_z_high_f64x2 = vsubq_f64(vcvt_high_f64_f32(a_z_f32x4), vcvt_high_f64_f32(b_z_f32x4));
410
383
 
411
384
  sum_squared_x_low_f64x2 = vfmaq_f64(sum_squared_x_low_f64x2, delta_x_low_f64x2, delta_x_low_f64x2),
412
385
  sum_squared_x_high_f64x2 = vfmaq_f64(sum_squared_x_high_f64x2, delta_x_high_f64x2, delta_x_high_f64x2);
@@ -416,71 +389,39 @@ NK_PUBLIC void nk_rmsd_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t
416
389
  sum_squared_z_high_f64x2 = vfmaq_f64(sum_squared_z_high_f64x2, delta_z_high_f64x2, delta_z_high_f64x2);
417
390
  }
418
391
 
419
- nk_f64_t sum_a_x = vaddvq_f64(vaddq_f64(sum_a_x_low_f64x2, sum_a_x_high_f64x2));
420
- nk_f64_t sum_a_y = vaddvq_f64(vaddq_f64(sum_a_y_low_f64x2, sum_a_y_high_f64x2));
421
- nk_f64_t sum_a_z = vaddvq_f64(vaddq_f64(sum_a_z_low_f64x2, sum_a_z_high_f64x2));
422
- nk_f64_t sum_b_x = vaddvq_f64(vaddq_f64(sum_b_x_low_f64x2, sum_b_x_high_f64x2));
423
- nk_f64_t sum_b_y = vaddvq_f64(vaddq_f64(sum_b_y_low_f64x2, sum_b_y_high_f64x2));
424
- nk_f64_t sum_b_z = vaddvq_f64(vaddq_f64(sum_b_z_low_f64x2, sum_b_z_high_f64x2));
425
392
  nk_f64_t sum_squared_x = vaddvq_f64(vaddq_f64(sum_squared_x_low_f64x2, sum_squared_x_high_f64x2));
426
393
  nk_f64_t sum_squared_y = vaddvq_f64(vaddq_f64(sum_squared_y_low_f64x2, sum_squared_y_high_f64x2));
427
394
  nk_f64_t sum_squared_z = vaddvq_f64(vaddq_f64(sum_squared_z_low_f64x2, sum_squared_z_high_f64x2));
428
395
 
429
396
  for (; index < n; ++index) {
430
- nk_f64_t a_x = a[index * 3 + 0], a_y = a[index * 3 + 1], a_z = a[index * 3 + 2];
431
- nk_f64_t b_x = b[index * 3 + 0], b_y = b[index * 3 + 1], b_z = b[index * 3 + 2];
432
- sum_a_x += a_x, sum_a_y += a_y, sum_a_z += a_z;
433
- sum_b_x += b_x, sum_b_y += b_y, sum_b_z += b_z;
434
- nk_f64_t delta_x = a_x - b_x, delta_y = a_y - b_y, delta_z = a_z - b_z;
397
+ nk_f64_t delta_x = (nk_f64_t)a[index * 3 + 0] - (nk_f64_t)b[index * 3 + 0];
398
+ nk_f64_t delta_y = (nk_f64_t)a[index * 3 + 1] - (nk_f64_t)b[index * 3 + 1];
399
+ nk_f64_t delta_z = (nk_f64_t)a[index * 3 + 2] - (nk_f64_t)b[index * 3 + 2];
435
400
  sum_squared_x += delta_x * delta_x, sum_squared_y += delta_y * delta_y, sum_squared_z += delta_z * delta_z;
436
401
  }
437
402
 
438
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
439
- nk_f64_t centroid_a_x = sum_a_x * inv_n, centroid_a_y = sum_a_y * inv_n, centroid_a_z = sum_a_z * inv_n;
440
- nk_f64_t centroid_b_x = sum_b_x * inv_n, centroid_b_y = sum_b_y * inv_n, centroid_b_z = sum_b_z * inv_n;
441
- if (a_centroid)
442
- a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
443
- a_centroid[2] = (nk_f32_t)centroid_a_z;
444
- if (b_centroid)
445
- b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
446
- b_centroid[2] = (nk_f32_t)centroid_b_z;
447
-
448
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x, mean_diff_y = centroid_a_y - centroid_b_y,
449
- mean_diff_z = centroid_a_z - centroid_b_z;
450
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
451
- *result = nk_f64_sqrt_neon((sum_squared_x + sum_squared_y + sum_squared_z) * inv_n - mean_diff_sq);
403
+ *result = nk_f64_sqrt_neon((sum_squared_x + sum_squared_y + sum_squared_z) / (nk_f64_t)n);
452
404
  }
453
405
 
454
406
  NK_PUBLIC void nk_rmsd_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t n, nk_f64_t *a_centroid,
455
407
  nk_f64_t *b_centroid, nk_f64_t *rotation, nk_f64_t *scale, nk_f64_t *result) {
456
- // RMSD uses identity rotation and scale=1.0.
408
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
409
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
457
410
  if (rotation)
458
411
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
459
412
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
460
413
  if (scale) *scale = 1.0;
461
414
 
462
415
  float64x2_t const zeros_f64x2 = vdupq_n_f64(0);
463
-
464
- // Accumulators for centroids and squared differences
465
- float64x2_t sum_a_x_f64x2 = zeros_f64x2, sum_a_y_f64x2 = zeros_f64x2, sum_a_z_f64x2 = zeros_f64x2;
466
- float64x2_t sum_b_x_f64x2 = zeros_f64x2, sum_b_y_f64x2 = zeros_f64x2, sum_b_z_f64x2 = zeros_f64x2;
467
416
  float64x2_t sum_squared_x_f64x2 = zeros_f64x2, sum_squared_y_f64x2 = zeros_f64x2, sum_squared_z_f64x2 = zeros_f64x2;
468
417
 
469
418
  float64x2_t a_x_f64x2, a_y_f64x2, a_z_f64x2, b_x_f64x2, b_y_f64x2, b_z_f64x2;
470
419
  nk_size_t i = 0;
471
420
 
472
- // Main loop processing 2 points at a time
473
421
  for (; i + 2 <= n; i += 2) {
474
422
  nk_deinterleave_f64x2_neon_(a + i * 3, &a_x_f64x2, &a_y_f64x2, &a_z_f64x2);
475
423
  nk_deinterleave_f64x2_neon_(b + i * 3, &b_x_f64x2, &b_y_f64x2, &b_z_f64x2);
476
424
 
477
- sum_a_x_f64x2 = vaddq_f64(sum_a_x_f64x2, a_x_f64x2);
478
- sum_a_y_f64x2 = vaddq_f64(sum_a_y_f64x2, a_y_f64x2);
479
- sum_a_z_f64x2 = vaddq_f64(sum_a_z_f64x2, a_z_f64x2);
480
- sum_b_x_f64x2 = vaddq_f64(sum_b_x_f64x2, b_x_f64x2);
481
- sum_b_y_f64x2 = vaddq_f64(sum_b_y_f64x2, b_y_f64x2);
482
- sum_b_z_f64x2 = vaddq_f64(sum_b_z_f64x2, b_z_f64x2);
483
-
484
425
  float64x2_t delta_x_f64x2 = vsubq_f64(a_x_f64x2, b_x_f64x2);
485
426
  float64x2_t delta_y_f64x2 = vsubq_f64(a_y_f64x2, b_y_f64x2);
486
427
  float64x2_t delta_z_f64x2 = vsubq_f64(a_z_f64x2, b_z_f64x2);
@@ -490,53 +431,23 @@ NK_PUBLIC void nk_rmsd_f64_neon(nk_f64_t const *a, nk_f64_t const *b, nk_size_t
490
431
  sum_squared_z_f64x2 = vfmaq_f64(sum_squared_z_f64x2, delta_z_f64x2, delta_z_f64x2);
491
432
  }
492
433
 
493
- // Reduce vectors to scalars.
494
- nk_f64_t total_ax = nk_reduce_stable_f64x2_neon_(sum_a_x_f64x2), total_ax_compensation = 0.0;
495
- nk_f64_t total_ay = nk_reduce_stable_f64x2_neon_(sum_a_y_f64x2), total_ay_compensation = 0.0;
496
- nk_f64_t total_az = nk_reduce_stable_f64x2_neon_(sum_a_z_f64x2), total_az_compensation = 0.0;
497
- nk_f64_t total_bx = nk_reduce_stable_f64x2_neon_(sum_b_x_f64x2), total_bx_compensation = 0.0;
498
- nk_f64_t total_by = nk_reduce_stable_f64x2_neon_(sum_b_y_f64x2), total_by_compensation = 0.0;
499
- nk_f64_t total_bz = nk_reduce_stable_f64x2_neon_(sum_b_z_f64x2), total_bz_compensation = 0.0;
500
434
  nk_f64_t total_squared_x = nk_reduce_stable_f64x2_neon_(sum_squared_x_f64x2), total_squared_x_compensation = 0.0;
501
435
  nk_f64_t total_squared_y = nk_reduce_stable_f64x2_neon_(sum_squared_y_f64x2), total_squared_y_compensation = 0.0;
502
436
  nk_f64_t total_squared_z = nk_reduce_stable_f64x2_neon_(sum_squared_z_f64x2), total_squared_z_compensation = 0.0;
503
437
 
504
- // Scalar tail
505
438
  for (; i < n; ++i) {
506
- nk_f64_t ax = a[i * 3 + 0], ay = a[i * 3 + 1], az = a[i * 3 + 2];
507
- nk_f64_t bx = b[i * 3 + 0], by = b[i * 3 + 1], bz = b[i * 3 + 2];
508
- nk_accumulate_sum_f64_(&total_ax, &total_ax_compensation, ax);
509
- nk_accumulate_sum_f64_(&total_ay, &total_ay_compensation, ay);
510
- nk_accumulate_sum_f64_(&total_az, &total_az_compensation, az);
511
- nk_accumulate_sum_f64_(&total_bx, &total_bx_compensation, bx);
512
- nk_accumulate_sum_f64_(&total_by, &total_by_compensation, by);
513
- nk_accumulate_sum_f64_(&total_bz, &total_bz_compensation, bz);
514
- nk_f64_t delta_x = ax - bx, delta_y = ay - by, delta_z = az - bz;
439
+ nk_f64_t delta_x = a[i * 3 + 0] - b[i * 3 + 0];
440
+ nk_f64_t delta_y = a[i * 3 + 1] - b[i * 3 + 1];
441
+ nk_f64_t delta_z = a[i * 3 + 2] - b[i * 3 + 2];
515
442
  nk_accumulate_square_f64_(&total_squared_x, &total_squared_x_compensation, delta_x);
516
443
  nk_accumulate_square_f64_(&total_squared_y, &total_squared_y_compensation, delta_y);
517
444
  nk_accumulate_square_f64_(&total_squared_z, &total_squared_z_compensation, delta_z);
518
445
  }
519
446
 
520
- total_ax += total_ax_compensation, total_ay += total_ay_compensation, total_az += total_az_compensation;
521
- total_bx += total_bx_compensation, total_by += total_by_compensation, total_bz += total_bz_compensation;
522
447
  total_squared_x += total_squared_x_compensation, total_squared_y += total_squared_y_compensation,
523
448
  total_squared_z += total_squared_z_compensation;
524
449
 
525
- // Compute centroids
526
- nk_f64_t inv_n = 1.0 / (nk_f64_t)n;
527
- nk_f64_t centroid_a_x = total_ax * inv_n, centroid_a_y = total_ay * inv_n, centroid_a_z = total_az * inv_n;
528
- nk_f64_t centroid_b_x = total_bx * inv_n, centroid_b_y = total_by * inv_n, centroid_b_z = total_bz * inv_n;
529
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
530
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
531
-
532
- // Compute RMSD
533
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x;
534
- nk_f64_t mean_diff_y = centroid_a_y - centroid_b_y;
535
- nk_f64_t mean_diff_z = centroid_a_z - centroid_b_z;
536
- nk_f64_t sum_squared = total_squared_x + total_squared_y + total_squared_z;
537
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
538
-
539
- *result = nk_f64_sqrt_neon(sum_squared * inv_n - mean_diff_sq);
450
+ *result = nk_f64_sqrt_neon((total_squared_x + total_squared_y + total_squared_z) / (nk_f64_t)n);
540
451
  }
541
452
 
542
453
  NK_PUBLIC void nk_kabsch_f32_neon(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
@@ -267,12 +267,12 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
267
267
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
268
268
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
269
269
  if (scale) *scale = 1.0f;
270
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
271
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
270
272
 
271
273
  float32x4_t const zeros_f32x4 = vdupq_n_f32(0);
272
274
 
273
- // Accumulators for centroids and squared differences
274
- float32x4_t sum_a_x_f32x4 = zeros_f32x4, sum_a_y_f32x4 = zeros_f32x4, sum_a_z_f32x4 = zeros_f32x4;
275
- float32x4_t sum_b_x_f32x4 = zeros_f32x4, sum_b_y_f32x4 = zeros_f32x4, sum_b_z_f32x4 = zeros_f32x4;
275
+ // Accumulators for squared differences
276
276
  float32x4_t sum_squared_x_f32x4 = zeros_f32x4, sum_squared_y_f32x4 = zeros_f32x4, sum_squared_z_f32x4 = zeros_f32x4;
277
277
 
278
278
  float32x4_t a_x_f32x4, a_y_f32x4, a_z_f32x4, b_x_f32x4, b_y_f32x4, b_z_f32x4;
@@ -283,13 +283,6 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
283
283
  nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(a + i * 3, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
284
284
  nk_deinterleave_bf16x4_to_f32x4_neonbfdot_(b + i * 3, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
285
285
 
286
- sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
287
- sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
288
- sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
289
- sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
290
- sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
291
- sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
292
-
293
286
  float32x4_t delta_x_f32x4 = vsubq_f32(a_x_f32x4, b_x_f32x4);
294
287
  float32x4_t delta_y_f32x4 = vsubq_f32(a_y_f32x4, b_y_f32x4);
295
288
  float32x4_t delta_z_f32x4 = vsubq_f32(a_z_f32x4, b_z_f32x4);
@@ -305,13 +298,6 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
305
298
  nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(a + i * 3, n - i, &a_x_f32x4, &a_y_f32x4, &a_z_f32x4);
306
299
  nk_partial_deinterleave_bf16_to_f32x4_neonbfdot_(b + i * 3, n - i, &b_x_f32x4, &b_y_f32x4, &b_z_f32x4);
307
300
 
308
- sum_a_x_f32x4 = vaddq_f32(sum_a_x_f32x4, a_x_f32x4);
309
- sum_a_y_f32x4 = vaddq_f32(sum_a_y_f32x4, a_y_f32x4);
310
- sum_a_z_f32x4 = vaddq_f32(sum_a_z_f32x4, a_z_f32x4);
311
- sum_b_x_f32x4 = vaddq_f32(sum_b_x_f32x4, b_x_f32x4);
312
- sum_b_y_f32x4 = vaddq_f32(sum_b_y_f32x4, b_y_f32x4);
313
- sum_b_z_f32x4 = vaddq_f32(sum_b_z_f32x4, b_z_f32x4);
314
-
315
301
  float32x4_t delta_x_f32x4 = vsubq_f32(a_x_f32x4, b_x_f32x4);
316
302
  float32x4_t delta_y_f32x4 = vsubq_f32(a_y_f32x4, b_y_f32x4);
317
303
  float32x4_t delta_z_f32x4 = vsubq_f32(a_z_f32x4, b_z_f32x4);
@@ -322,36 +308,11 @@ NK_PUBLIC void nk_rmsd_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk
322
308
  }
323
309
 
324
310
  // Reduce vectors to scalars
325
- nk_f32_t total_ax = vaddvq_f32(sum_a_x_f32x4);
326
- nk_f32_t total_ay = vaddvq_f32(sum_a_y_f32x4);
327
- nk_f32_t total_az = vaddvq_f32(sum_a_z_f32x4);
328
- nk_f32_t total_bx = vaddvq_f32(sum_b_x_f32x4);
329
- nk_f32_t total_by = vaddvq_f32(sum_b_y_f32x4);
330
- nk_f32_t total_bz = vaddvq_f32(sum_b_z_f32x4);
331
311
  nk_f32_t total_squared_x = vaddvq_f32(sum_squared_x_f32x4);
332
312
  nk_f32_t total_squared_y = vaddvq_f32(sum_squared_y_f32x4);
333
313
  nk_f32_t total_squared_z = vaddvq_f32(sum_squared_z_f32x4);
334
314
 
335
- // Compute centroids
336
- nk_f32_t inv_n = 1.0f / (nk_f32_t)n;
337
- nk_f32_t centroid_a_x = total_ax * inv_n;
338
- nk_f32_t centroid_a_y = total_ay * inv_n;
339
- nk_f32_t centroid_a_z = total_az * inv_n;
340
- nk_f32_t centroid_b_x = total_bx * inv_n;
341
- nk_f32_t centroid_b_y = total_by * inv_n;
342
- nk_f32_t centroid_b_z = total_bz * inv_n;
343
-
344
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
345
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
346
-
347
- // Compute RMSD
348
- nk_f32_t mean_diff_x = centroid_a_x - centroid_b_x;
349
- nk_f32_t mean_diff_y = centroid_a_y - centroid_b_y;
350
- nk_f32_t mean_diff_z = centroid_a_z - centroid_b_z;
351
- nk_f32_t sum_squared = total_squared_x + total_squared_y + total_squared_z;
352
- nk_f32_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
353
-
354
- *result = nk_f32_sqrt_neon(sum_squared * inv_n - mean_diff_sq);
315
+ *result = nk_f32_sqrt_neon((total_squared_x + total_squared_y + total_squared_z) / (nk_f32_t)n);
355
316
  }
356
317
 
357
318
  NK_PUBLIC void nk_kabsch_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
@@ -701,16 +701,10 @@ NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t p
701
701
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
702
702
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
703
703
  if (scale) *scale = 1.0f;
704
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
705
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
704
706
 
705
- // Fused single-pass: accumulate centroids and squared differences simultaneously.
706
- // RMSD = √(E[(a−b)²] − (ā − b̄)²)
707
707
  nk_size_t max_vector_length = __riscv_vsetvlmax_e64m2();
708
- vfloat64m2_t sum_a_x_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
709
- vfloat64m2_t sum_a_y_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
710
- vfloat64m2_t sum_a_z_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
711
- vfloat64m2_t sum_b_x_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
712
- vfloat64m2_t sum_b_y_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
713
- vfloat64m2_t sum_b_z_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
714
708
  vfloat64m2_t sum_squared_f64m2 = __riscv_vfmv_v_f_f64m2(0.0, max_vector_length);
715
709
  nk_f32_t const *a_ptr = a, *b_ptr = b;
716
710
  nk_size_t remaining = points_count;
@@ -725,15 +719,7 @@ NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t p
725
719
  vfloat32m1_t b_x_f32m1 = __riscv_vget_v_f32m1x3_f32m1(b_f32m1x3, 0);
726
720
  vfloat32m1_t b_y_f32m1 = __riscv_vget_v_f32m1x3_f32m1(b_f32m1x3, 1);
727
721
  vfloat32m1_t b_z_f32m1 = __riscv_vget_v_f32m1x3_f32m1(b_f32m1x3, 2);
728
- // Accumulate centroids in f64.
729
- sum_a_x_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_a_x_f64m2, sum_a_x_f64m2, a_x_f32m1, vector_length);
730
- sum_a_y_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_a_y_f64m2, sum_a_y_f64m2, a_y_f32m1, vector_length);
731
- sum_a_z_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_a_z_f64m2, sum_a_z_f64m2, a_z_f32m1, vector_length);
732
- sum_b_x_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_b_x_f64m2, sum_b_x_f64m2, b_x_f32m1, vector_length);
733
- sum_b_y_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_b_y_f64m2, sum_b_y_f64m2, b_y_f32m1, vector_length);
734
- sum_b_z_f64m2 = __riscv_vfwadd_wv_f64m2_tu(sum_b_z_f64m2, sum_b_z_f64m2, b_z_f32m1, vector_length);
735
- // Accumulate (a−b)² per component. Widen a,b to f64 before subtracting to avoid f32
736
- // cancellation in the single-pass formula RMSD = √(E[(a−b)²] − (ā − b̄)²).
722
+ // Accumulate (a−b)² per component, widening to f64.
737
723
  vfloat64m2_t a_x_f64m2 = __riscv_vfwcvt_f_f_v_f64m2(a_x_f32m1, vector_length);
738
724
  vfloat64m2_t b_x_f64m2 = __riscv_vfwcvt_f_f_v_f64m2(b_x_f32m1, vector_length);
739
725
  vfloat64m2_t a_y_f64m2 = __riscv_vfwcvt_f_f_v_f64m2(a_y_f32m1, vector_length);
@@ -748,38 +734,9 @@ NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t p
748
734
  sum_squared_f64m2 = __riscv_vfmacc_vv_f64m2_tu(sum_squared_f64m2, delta_z_f64m2, delta_z_f64m2, vector_length);
749
735
  }
750
736
  vfloat64m1_t zero_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, 1);
751
- nk_f64_t inv_points_count = 1.0 / (nk_f64_t)points_count;
752
- nk_f64_t centroid_a_x = __riscv_vfmv_f_s_f64m1_f64(
753
- __riscv_vfredusum_vs_f64m2_f64m1(sum_a_x_f64m2, zero_f64m1, max_vector_length)) *
754
- inv_points_count;
755
- nk_f64_t centroid_a_y = __riscv_vfmv_f_s_f64m1_f64(
756
- __riscv_vfredusum_vs_f64m2_f64m1(sum_a_y_f64m2, zero_f64m1, max_vector_length)) *
757
- inv_points_count;
758
- nk_f64_t centroid_a_z = __riscv_vfmv_f_s_f64m1_f64(
759
- __riscv_vfredusum_vs_f64m2_f64m1(sum_a_z_f64m2, zero_f64m1, max_vector_length)) *
760
- inv_points_count;
761
- nk_f64_t centroid_b_x = __riscv_vfmv_f_s_f64m1_f64(
762
- __riscv_vfredusum_vs_f64m2_f64m1(sum_b_x_f64m2, zero_f64m1, max_vector_length)) *
763
- inv_points_count;
764
- nk_f64_t centroid_b_y = __riscv_vfmv_f_s_f64m1_f64(
765
- __riscv_vfredusum_vs_f64m2_f64m1(sum_b_y_f64m2, zero_f64m1, max_vector_length)) *
766
- inv_points_count;
767
- nk_f64_t centroid_b_z = __riscv_vfmv_f_s_f64m1_f64(
768
- __riscv_vfredusum_vs_f64m2_f64m1(sum_b_z_f64m2, zero_f64m1, max_vector_length)) *
769
- inv_points_count;
770
- if (a_centroid)
771
- a_centroid[0] = (nk_f32_t)centroid_a_x, a_centroid[1] = (nk_f32_t)centroid_a_y,
772
- a_centroid[2] = (nk_f32_t)centroid_a_z;
773
- if (b_centroid)
774
- b_centroid[0] = (nk_f32_t)centroid_b_x, b_centroid[1] = (nk_f32_t)centroid_b_y,
775
- b_centroid[2] = (nk_f32_t)centroid_b_z;
776
-
777
737
  nk_f64_t sum_squared = __riscv_vfmv_f_s_f64m1_f64(
778
738
  __riscv_vfredusum_vs_f64m2_f64m1(sum_squared_f64m2, zero_f64m1, max_vector_length));
779
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x, mean_diff_y = centroid_a_y - centroid_b_y,
780
- mean_diff_z = centroid_a_z - centroid_b_z;
781
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
782
- *result = nk_f64_sqrt_rvv(sum_squared * inv_points_count - mean_diff_sq);
739
+ *result = nk_f64_sqrt_rvv(sum_squared / (nk_f64_t)points_count);
783
740
  }
784
741
 
785
742
  NK_PUBLIC void nk_rmsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t points_count, nk_f64_t *a_centroid,
@@ -788,22 +745,10 @@ NK_PUBLIC void nk_rmsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t p
788
745
  rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0,
789
746
  rotation[6] = 0, rotation[7] = 0, rotation[8] = 1;
790
747
  if (scale) *scale = 1.0;
748
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0;
749
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0;
791
750
 
792
- // Fused single-pass: accumulate centroids and squared differences simultaneously.
793
- // RMSD = √(E[(a−b)²] − (ā − b̄)²)
794
751
  nk_size_t max_vector_length = __riscv_vsetvlmax_e64m1();
795
- vfloat64m1_t sum_a_x_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
796
- vfloat64m1_t sum_a_y_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
797
- vfloat64m1_t sum_a_z_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
798
- vfloat64m1_t sum_b_x_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
799
- vfloat64m1_t sum_b_y_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
800
- vfloat64m1_t sum_b_z_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
801
- vfloat64m1_t compensation_a_x_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
802
- vfloat64m1_t compensation_a_y_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
803
- vfloat64m1_t compensation_a_z_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
804
- vfloat64m1_t compensation_b_x_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
805
- vfloat64m1_t compensation_b_y_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
806
- vfloat64m1_t compensation_b_z_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
807
752
  vfloat64m1_t sum_squared_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
808
753
  vfloat64m1_t compensation_squared_f64m1 = __riscv_vfmv_v_f_f64m1(0.0, max_vector_length);
809
754
  nk_f64_t const *a_ptr = a, *b_ptr = b;
@@ -819,13 +764,6 @@ NK_PUBLIC void nk_rmsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t p
819
764
  vfloat64m1_t b_x_f64m1 = __riscv_vget_v_f64m1x3_f64m1(b_f64m1x3, 0);
820
765
  vfloat64m1_t b_y_f64m1 = __riscv_vget_v_f64m1x3_f64m1(b_f64m1x3, 1);
821
766
  vfloat64m1_t b_z_f64m1 = __riscv_vget_v_f64m1x3_f64m1(b_f64m1x3, 2);
822
- // Accumulate centroids with Kahan compensation.
823
- nk_accumulate_sum_f64m1_rvv_(&sum_a_x_f64m1, &compensation_a_x_f64m1, a_x_f64m1, vector_length);
824
- nk_accumulate_sum_f64m1_rvv_(&sum_a_y_f64m1, &compensation_a_y_f64m1, a_y_f64m1, vector_length);
825
- nk_accumulate_sum_f64m1_rvv_(&sum_a_z_f64m1, &compensation_a_z_f64m1, a_z_f64m1, vector_length);
826
- nk_accumulate_sum_f64m1_rvv_(&sum_b_x_f64m1, &compensation_b_x_f64m1, b_x_f64m1, vector_length);
827
- nk_accumulate_sum_f64m1_rvv_(&sum_b_y_f64m1, &compensation_b_y_f64m1, b_y_f64m1, vector_length);
828
- nk_accumulate_sum_f64m1_rvv_(&sum_b_z_f64m1, &compensation_b_z_f64m1, b_z_f64m1, vector_length);
829
767
  // Accumulate (a-b)^2 per component.
830
768
  vfloat64m1_t delta_x_f64m1 = __riscv_vfsub_vv_f64m1(a_x_f64m1, b_x_f64m1, vector_length);
831
769
  vfloat64m1_t delta_y_f64m1 = __riscv_vfsub_vv_f64m1(a_y_f64m1, b_y_f64m1, vector_length);
@@ -835,21 +773,8 @@ NK_PUBLIC void nk_rmsd_f64_rvv(nk_f64_t const *a, nk_f64_t const *b, nk_size_t p
835
773
  dist_sq_f64m1 = __riscv_vfmacc_vv_f64m1(dist_sq_f64m1, delta_z_f64m1, delta_z_f64m1, vector_length);
836
774
  nk_accumulate_sum_f64m1_rvv_(&sum_squared_f64m1, &compensation_squared_f64m1, dist_sq_f64m1, vector_length);
837
775
  }
838
- nk_f64_t inv_points_count = 1.0 / (nk_f64_t)points_count;
839
- nk_f64_t centroid_a_x = nk_dot_stable_sum_f64m1_rvv_(sum_a_x_f64m1, compensation_a_x_f64m1) * inv_points_count;
840
- nk_f64_t centroid_a_y = nk_dot_stable_sum_f64m1_rvv_(sum_a_y_f64m1, compensation_a_y_f64m1) * inv_points_count;
841
- nk_f64_t centroid_a_z = nk_dot_stable_sum_f64m1_rvv_(sum_a_z_f64m1, compensation_a_z_f64m1) * inv_points_count;
842
- nk_f64_t centroid_b_x = nk_dot_stable_sum_f64m1_rvv_(sum_b_x_f64m1, compensation_b_x_f64m1) * inv_points_count;
843
- nk_f64_t centroid_b_y = nk_dot_stable_sum_f64m1_rvv_(sum_b_y_f64m1, compensation_b_y_f64m1) * inv_points_count;
844
- nk_f64_t centroid_b_z = nk_dot_stable_sum_f64m1_rvv_(sum_b_z_f64m1, compensation_b_z_f64m1) * inv_points_count;
845
- if (a_centroid) a_centroid[0] = centroid_a_x, a_centroid[1] = centroid_a_y, a_centroid[2] = centroid_a_z;
846
- if (b_centroid) b_centroid[0] = centroid_b_x, b_centroid[1] = centroid_b_y, b_centroid[2] = centroid_b_z;
847
-
848
776
  nk_f64_t sum_squared = nk_dot_stable_sum_f64m1_rvv_(sum_squared_f64m1, compensation_squared_f64m1);
849
- nk_f64_t mean_diff_x = centroid_a_x - centroid_b_x, mean_diff_y = centroid_a_y - centroid_b_y,
850
- mean_diff_z = centroid_a_z - centroid_b_z;
851
- nk_f64_t mean_diff_sq = mean_diff_x * mean_diff_x + mean_diff_y * mean_diff_y + mean_diff_z * mean_diff_z;
852
- *result = nk_f64_sqrt_rvv(sum_squared * inv_points_count - mean_diff_sq);
777
+ *result = nk_f64_sqrt_rvv(sum_squared / (nk_f64_t)points_count);
853
778
  }
854
779
 
855
780
  NK_PUBLIC void nk_kabsch_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t points_count, nk_f32_t *a_centroid,
@@ -392,59 +392,32 @@ nk_define_det3x3_(f64)
392
392
  /* RMSD (Root Mean Square Deviation) without optimal superposition.
393
393
  * Simply computes the RMS of distances between corresponding points.
394
394
  */
395
- #define nk_define_rmsd_(input_type, accumulator_type, output_type, result_type, load_and_convert, compute_sqrt) \
396
- NK_PUBLIC void nk_rmsd_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
397
- nk_size_t n, nk_##output_type##_t *a_centroid, \
398
- nk_##output_type##_t *b_centroid, nk_##output_type##_t *rotation, \
399
- nk_##output_type##_t *scale, nk_##result_type##_t *result) { \
400
- nk_##accumulator_type##_t sum_a_x = 0, sum_a_y = 0, sum_a_z = 0; \
401
- nk_##accumulator_type##_t sum_b_x = 0, sum_b_y = 0, sum_b_z = 0; \
402
- nk_##accumulator_type##_t sum_a_x_compensation = 0, sum_a_y_compensation = 0, sum_a_z_compensation = 0; \
403
- nk_##accumulator_type##_t sum_b_x_compensation = 0, sum_b_y_compensation = 0, sum_b_z_compensation = 0; \
404
- nk_##accumulator_type##_t val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z; \
405
- for (nk_size_t i = 0; i < n; ++i) { \
406
- load_and_convert(a + i * 3 + 0, &val_a_x), load_and_convert(a + i * 3 + 1, &val_a_y); \
407
- load_and_convert(a + i * 3 + 2, &val_a_z), load_and_convert(b + i * 3 + 0, &val_b_x); \
408
- load_and_convert(b + i * 3 + 1, &val_b_y), load_and_convert(b + i * 3 + 2, &val_b_z); \
409
- nk_accumulate_sum_##accumulator_type##_(&sum_a_x, &sum_a_x_compensation, val_a_x); \
410
- nk_accumulate_sum_##accumulator_type##_(&sum_a_y, &sum_a_y_compensation, val_a_y); \
411
- nk_accumulate_sum_##accumulator_type##_(&sum_a_z, &sum_a_z_compensation, val_a_z); \
412
- nk_accumulate_sum_##accumulator_type##_(&sum_b_x, &sum_b_x_compensation, val_b_x); \
413
- nk_accumulate_sum_##accumulator_type##_(&sum_b_y, &sum_b_y_compensation, val_b_y); \
414
- nk_accumulate_sum_##accumulator_type##_(&sum_b_z, &sum_b_z_compensation, val_b_z); \
415
- } \
416
- nk_##accumulator_type##_t inv_n = (nk_##accumulator_type##_t)1.0 / (nk_##accumulator_type##_t)n; \
417
- nk_##accumulator_type##_t centroid_a_x = (sum_a_x + sum_a_x_compensation) * inv_n; \
418
- nk_##accumulator_type##_t centroid_a_y = (sum_a_y + sum_a_y_compensation) * inv_n; \
419
- nk_##accumulator_type##_t centroid_a_z = (sum_a_z + sum_a_z_compensation) * inv_n; \
420
- nk_##accumulator_type##_t centroid_b_x = (sum_b_x + sum_b_x_compensation) * inv_n; \
421
- nk_##accumulator_type##_t centroid_b_y = (sum_b_y + sum_b_y_compensation) * inv_n; \
422
- nk_##accumulator_type##_t centroid_b_z = (sum_b_z + sum_b_z_compensation) * inv_n; \
423
- if (a_centroid) \
424
- a_centroid[0] = (nk_##output_type##_t)centroid_a_x, a_centroid[1] = (nk_##output_type##_t)centroid_a_y, \
425
- a_centroid[2] = (nk_##output_type##_t)centroid_a_z; \
426
- if (b_centroid) \
427
- b_centroid[0] = (nk_##output_type##_t)centroid_b_x, b_centroid[1] = (nk_##output_type##_t)centroid_b_y, \
428
- b_centroid[2] = (nk_##output_type##_t)centroid_b_z; \
429
- /* RMSD uses identity rotation and scale=1.0 */ \
430
- if (rotation) \
431
- rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0, \
432
- rotation[6] = 0, rotation[7] = 0, rotation[8] = 1; \
433
- if (scale) *scale = (nk_##output_type##_t)1; \
434
- nk_##accumulator_type##_t sum_squared = 0, sum_squared_compensation = 0; \
435
- for (nk_size_t i = 0; i < n; ++i) { \
436
- load_and_convert(a + i * 3 + 0, &val_a_x), load_and_convert(b + i * 3 + 0, &val_b_x); \
437
- load_and_convert(a + i * 3 + 1, &val_a_y), load_and_convert(b + i * 3 + 1, &val_b_y); \
438
- load_and_convert(a + i * 3 + 2, &val_a_z), load_and_convert(b + i * 3 + 2, &val_b_z); \
439
- nk_##accumulator_type##_t dx = (val_a_x - centroid_a_x) - (val_b_x - centroid_b_x); \
440
- nk_##accumulator_type##_t dy = (val_a_y - centroid_a_y) - (val_b_y - centroid_b_y); \
441
- nk_##accumulator_type##_t dz = (val_a_z - centroid_a_z) - (val_b_z - centroid_b_z); \
442
- nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dx); \
443
- nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dy); \
444
- nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dz); \
445
- } \
446
- nk_##accumulator_type##_t msd = (sum_squared + sum_squared_compensation) * inv_n; \
447
- *result = msd > 0 ? (nk_##result_type##_t)compute_sqrt(msd) : 0; \
395
+ #define nk_define_rmsd_(input_type, accumulator_type, output_type, result_type, load_and_convert, compute_sqrt) \
396
+ NK_PUBLIC void nk_rmsd_##input_type##_serial(nk_##input_type##_t const *a, nk_##input_type##_t const *b, \
397
+ nk_size_t n, nk_##output_type##_t *a_centroid, \
398
+ nk_##output_type##_t *b_centroid, nk_##output_type##_t *rotation, \
399
+ nk_##output_type##_t *scale, nk_##result_type##_t *result) { \
400
+ if (a_centroid) a_centroid[0] = 0, a_centroid[1] = 0, a_centroid[2] = 0; \
401
+ if (b_centroid) b_centroid[0] = 0, b_centroid[1] = 0, b_centroid[2] = 0; \
402
+ if (rotation) \
403
+ rotation[0] = 1, rotation[1] = 0, rotation[2] = 0, rotation[3] = 0, rotation[4] = 1, rotation[5] = 0, \
404
+ rotation[6] = 0, rotation[7] = 0, rotation[8] = 1; \
405
+ if (scale) *scale = (nk_##output_type##_t)1; \
406
+ nk_##accumulator_type##_t sum_squared = 0, sum_squared_compensation = 0; \
407
+ nk_##accumulator_type##_t val_a_x, val_a_y, val_a_z, val_b_x, val_b_y, val_b_z; \
408
+ for (nk_size_t i = 0; i < n; ++i) { \
409
+ load_and_convert(a + i * 3 + 0, &val_a_x), load_and_convert(b + i * 3 + 0, &val_b_x); \
410
+ load_and_convert(a + i * 3 + 1, &val_a_y), load_and_convert(b + i * 3 + 1, &val_b_y); \
411
+ load_and_convert(a + i * 3 + 2, &val_a_z), load_and_convert(b + i * 3 + 2, &val_b_z); \
412
+ nk_##accumulator_type##_t dx = val_a_x - val_b_x; \
413
+ nk_##accumulator_type##_t dy = val_a_y - val_b_y; \
414
+ nk_##accumulator_type##_t dz = val_a_z - val_b_z; \
415
+ nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dx); \
416
+ nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dy); \
417
+ nk_accumulate_square_##accumulator_type##_(&sum_squared, &sum_squared_compensation, dz); \
418
+ } \
419
+ nk_##accumulator_type##_t msd = (sum_squared + sum_squared_compensation) / (nk_##accumulator_type##_t)n; \
420
+ *result = msd > 0 ? (nk_##result_type##_t)compute_sqrt(msd) : 0; \
448
421
  }
449
422
 
450
423
  /* Kabsch algorithm for optimal rigid body superposition.