numkong 7.5.0 → 7.6.0
This diff compares the publicly released contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.
- package/binding.gyp +18 -0
- package/c/dispatch_e5m2.c +23 -3
- package/include/numkong/capabilities.h +1 -1
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +434 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +23 -8
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots.h +12 -0
- package/include/numkong/each/serial.h +18 -1
- package/include/numkong/geospatial/serial.h +14 -3
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +204 -162
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +128 -0
- package/include/numkong/spatials/serial.h +18 -1
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials.h +17 -0
- package/include/numkong/tensor.hpp +107 -23
- package/javascript/numkong.c +3 -2
- package/package.json +7 -7
- package/wasm/numkong.wasm +0 -0
@@ -346,28 +346,36 @@ nk_angular_f16_skylake_cycle:
 }
 
 NK_PUBLIC void nk_sqeuclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
-
-
+    // E4M3 has no free widen shift (its 4-bit exponent doesn't line up with F16's 5-bit
+    // at bit 10), so we call the Giesen-based 16-lane cast helper twice per iter and
+    // run with two F32 accumulators to break the FMA dependency chain.
+    __m512 first_acc_f32x16 = _mm512_setzero_ps();
+    __m512 second_acc_f32x16 = _mm512_setzero_ps();
+    __m256i a_u8x32, b_u8x32;
 
 nk_sqeuclidean_e4m3_skylake_cycle:
-    if (n <
-
-
-
+    if (n < 32) {
+        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
+        a_u8x32 = _mm256_maskz_loadu_epi8(mask, a);
+        b_u8x32 = _mm256_maskz_loadu_epi8(mask, b);
         n = 0;
     }
     else {
-
-
-        a +=
+        a_u8x32 = _mm256_loadu_si256((__m256i const *)a);
+        b_u8x32 = _mm256_loadu_si256((__m256i const *)b);
+        a += 32, b += 32, n -= 32;
     }
-    __m512
-    __m512
-    __m512
-
+    __m512 a_low_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_castsi256_si128(a_u8x32));
+    __m512 a_high_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_extracti128_si256(a_u8x32, 1));
+    __m512 b_low_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_castsi256_si128(b_u8x32));
+    __m512 b_high_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_extracti128_si256(b_u8x32, 1));
+    __m512 diff_low_f32x16 = _mm512_sub_ps(a_low_f32x16, b_low_f32x16);
+    __m512 diff_high_f32x16 = _mm512_sub_ps(a_high_f32x16, b_high_f32x16);
+    first_acc_f32x16 = _mm512_fmadd_ps(diff_low_f32x16, diff_low_f32x16, first_acc_f32x16);
+    second_acc_f32x16 = _mm512_fmadd_ps(diff_high_f32x16, diff_high_f32x16, second_acc_f32x16);
     if (n) goto nk_sqeuclidean_e4m3_skylake_cycle;
 
-    *result = nk_reduce_add_f32x16_skylake_(
+    *result = nk_reduce_add_f32x16_skylake_(_mm512_add_ps(first_acc_f32x16, second_acc_f32x16));
 }
 
 NK_PUBLIC void nk_euclidean_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
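The two-accumulator trick in the hunk above is worth spelling out: each FMA into a register must wait for the previous FMA into that same register, so a single accumulator serializes the loop at FMA latency. A minimal scalar sketch of the same idea (illustrative only, not from the package):

#include <stddef.h>

/* One accumulator: every iteration waits on the previous add. */
float sumsq_1acc(float const *d, size_t n) {
    float acc = 0.0f;
    for (size_t i = 0; i < n; ++i) acc += d[i] * d[i];
    return acc;
}

/* Two accumulators: the two chains run in parallel on a pipelined FMA unit,
 * roughly halving the latency bound; the final add merges them, mirroring
 * the _mm512_add_ps(first_acc, second_acc) before the reduction above. */
float sumsq_2acc(float const *d, size_t n) {
    float acc0 = 0.0f, acc1 = 0.0f;
    size_t i = 0;
    for (; i + 2 <= n; i += 2) {
        acc0 += d[i] * d[i];
        acc1 += d[i + 1] * d[i + 1];
    }
    if (i < n) acc0 += d[i] * d[i];
    return acc0 + acc1;
}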
@@ -379,25 +387,30 @@ NK_PUBLIC void nk_angular_e4m3_skylake(nk_e4m3_t const *a, nk_e4m3_t const *b, n
     __m512 dot_f32x16 = _mm512_setzero_ps();
     __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
     __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
-
+    __m256i a_u8x32, b_u8x32;
 
 nk_angular_e4m3_skylake_cycle:
-    if (n <
-
-
-
+    if (n < 32) {
+        __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, (unsigned int)n);
+        a_u8x32 = _mm256_maskz_loadu_epi8(mask, a);
+        b_u8x32 = _mm256_maskz_loadu_epi8(mask, b);
         n = 0;
     }
     else {
-
-
-        a +=
+        a_u8x32 = _mm256_loadu_si256((__m256i const *)a);
+        b_u8x32 = _mm256_loadu_si256((__m256i const *)b);
+        a += 32, b += 32, n -= 32;
     }
-    __m512
-    __m512
-
-
-
+    __m512 a_low_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_castsi256_si128(a_u8x32));
+    __m512 a_high_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_extracti128_si256(a_u8x32, 1));
+    __m512 b_low_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_castsi256_si128(b_u8x32));
+    __m512 b_high_f32x16 = nk_e4m3x16_to_f32x16_skylake_(_mm256_extracti128_si256(b_u8x32, 1));
+    dot_f32x16 = _mm512_fmadd_ps(a_low_f32x16, b_low_f32x16, dot_f32x16);
+    dot_f32x16 = _mm512_fmadd_ps(a_high_f32x16, b_high_f32x16, dot_f32x16);
+    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_low_f32x16, a_low_f32x16, a_norm_sq_f32x16);
+    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_high_f32x16, a_high_f32x16, a_norm_sq_f32x16);
+    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_low_f32x16, b_low_f32x16, b_norm_sq_f32x16);
+    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_high_f32x16, b_high_f32x16, b_norm_sq_f32x16);
     if (n) goto nk_angular_e4m3_skylake_cycle;
 
     nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
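Both E4M3 kernels use the same tail pattern: instead of a scalar remainder loop, the final (n < 32) elements are read with a zero-masked load, with `_bzhi_u32` building the low-n-bits mask. A self-contained sketch of just that pattern (requires AVX-512BW/VL and BMI2; the helper name is illustrative, not a package API):

#include <immintrin.h>
#include <stddef.h>

/* Loads min(n, 32) bytes; lanes at or beyond n come back as zero, so a
 * following subtract/FMA over the full vector is unaffected by the padding,
 * and masked-off lanes never fault even past the end of the buffer. */
static __m256i load_tail_u8x32(unsigned char const *p, size_t n) {
    if (n >= 32) return _mm256_loadu_si256((__m256i const *)p);
    __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFFu, (unsigned int)n);
    return _mm256_maskz_loadu_epi8(mask, p);
}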
@@ -407,28 +420,53 @@ nk_angular_e4m3_skylake_cycle:
 }
 
 NK_PUBLIC void nk_sqeuclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
-
-
+    // E5M2 shares F16's exponent bias (15): `byte << 8` equals the matching F16 bit-pattern
+    // for normals, subnormals, zero, Inf, and NaN. We get that shift for free by unpacking
+    // against zero — the zero byte lands in the low half of each 16-bit lane, the E5M2 byte
+    // in the high half. `vpunpck*bw` is per-128-bit-lane so the F32 outputs are lane-scrambled
+    // across 512 bits, but the commutative sum reduction is invariant under that.
+    __m512 first_acc_f32x16 = _mm512_setzero_ps();
+    __m512 second_acc_f32x16 = _mm512_setzero_ps();
+    __m512i const zero_u8x64 = _mm512_setzero_si512();
+    __m512i a_u8x64, b_u8x64;
 
 nk_sqeuclidean_e5m2_skylake_cycle:
-    if (n <
-
-
-
+    if (n < 64) {
+        __mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n);
+        a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
+        b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
         n = 0;
     }
     else {
-
-
-        a +=
+        a_u8x64 = _mm512_loadu_si512((__m512i const *)a);
+        b_u8x64 = _mm512_loadu_si512((__m512i const *)b);
+        a += 64, b += 64, n -= 64;
     }
-
-
-
-
+    __m512i a_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, a_u8x64);
+    __m512i a_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, a_u8x64);
+    __m512i b_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, b_u8x64);
+    __m512i b_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, b_u8x64);
+
+    __m512 a_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_even_f16x32));
+    __m512 a_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_even_f16x32, 1));
+    __m512 a_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_odd_f16x32));
+    __m512 a_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_odd_f16x32, 1));
+    __m512 b_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_even_f16x32));
+    __m512 b_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_even_f16x32, 1));
+    __m512 b_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_odd_f16x32));
+    __m512 b_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_odd_f16x32, 1));
+
+    __m512 diff_first_f32x16 = _mm512_sub_ps(a_first_f32x16, b_first_f32x16);
+    __m512 diff_second_f32x16 = _mm512_sub_ps(a_second_f32x16, b_second_f32x16);
+    __m512 diff_third_f32x16 = _mm512_sub_ps(a_third_f32x16, b_third_f32x16);
+    __m512 diff_fourth_f32x16 = _mm512_sub_ps(a_fourth_f32x16, b_fourth_f32x16);
+    first_acc_f32x16 = _mm512_fmadd_ps(diff_first_f32x16, diff_first_f32x16, first_acc_f32x16);
+    second_acc_f32x16 = _mm512_fmadd_ps(diff_second_f32x16, diff_second_f32x16, second_acc_f32x16);
+    first_acc_f32x16 = _mm512_fmadd_ps(diff_third_f32x16, diff_third_f32x16, first_acc_f32x16);
+    second_acc_f32x16 = _mm512_fmadd_ps(diff_fourth_f32x16, diff_fourth_f32x16, second_acc_f32x16);
     if (n) goto nk_sqeuclidean_e5m2_skylake_cycle;
 
-    *result = nk_reduce_add_f32x16_skylake_(
+    *result = nk_reduce_add_f32x16_skylake_(_mm512_add_ps(first_acc_f32x16, second_acc_f32x16));
 }
 
 NK_PUBLIC void nk_euclidean_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
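The comment's claim, that an E5M2 byte shifted left by 8 is bit-for-bit the F16 of the same value, follows from the two formats sharing a sign bit, a 5-bit exponent with bias 15, and left-aligned mantissas. A scalar reference conversion that demonstrates it (a sketch, not a package API; names are illustrative):

#include <stdint.h>
#include <string.h>

/* Expands IEEE binary16 bits to an F32, covering normals, subnormals,
 * zero, Inf, and NaN. */
static float f16_bits_to_f32(uint16_t h) {
    uint32_t sign = (uint32_t)(h & 0x8000) << 16;
    uint32_t expo = (h >> 10) & 0x1F;
    uint32_t mant = h & 0x3FF;
    uint32_t bits;
    if (expo == 0x1F) bits = sign | 0x7F800000u | (mant << 13);             /* Inf/NaN */
    else if (expo != 0) bits = sign | ((expo + 112u) << 23) | (mant << 13); /* normal */
    else if (mant == 0) bits = sign;                                        /* +/-0 */
    else { /* subnormal: renormalize into an F32 normal */
        uint32_t shift = 0;
        while (!(mant & 0x400)) mant <<= 1, shift++;
        bits = sign | ((113u - shift) << 23) | ((mant & 0x3FF) << 13);
    }
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* E5M2 has the same sign/exponent layout as F16's top byte, so widening is
 * exactly the shift the kernel above gets for free from the byte unpack. */
static float e5m2_to_f32(uint8_t byte) { return f16_bits_to_f32((uint16_t)byte << 8); }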
@@ -440,25 +478,47 @@ NK_PUBLIC void nk_angular_e5m2_skylake(nk_e5m2_t const *a, nk_e5m2_t const *b, n
     __m512 dot_f32x16 = _mm512_setzero_ps();
     __m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
     __m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
-
+    __m512i const zero_u8x64 = _mm512_setzero_si512();
+    __m512i a_u8x64, b_u8x64;
 
 nk_angular_e5m2_skylake_cycle:
-    if (n <
-
-
-
+    if (n < 64) {
+        __mmask64 mask = _bzhi_u64(0xFFFFFFFFFFFFFFFFULL, (unsigned int)n);
+        a_u8x64 = _mm512_maskz_loadu_epi8(mask, a);
+        b_u8x64 = _mm512_maskz_loadu_epi8(mask, b);
         n = 0;
     }
     else {
-
-
-        a +=
+        a_u8x64 = _mm512_loadu_si512((__m512i const *)a);
+        b_u8x64 = _mm512_loadu_si512((__m512i const *)b);
+        a += 64, b += 64, n -= 64;
     }
-
-
-
-
-
+    __m512i a_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, a_u8x64);
+    __m512i a_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, a_u8x64);
+    __m512i b_even_f16x32 = _mm512_unpacklo_epi8(zero_u8x64, b_u8x64);
+    __m512i b_odd_f16x32 = _mm512_unpackhi_epi8(zero_u8x64, b_u8x64);
+
+    __m512 a_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_even_f16x32));
+    __m512 a_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_even_f16x32, 1));
+    __m512 a_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(a_odd_f16x32));
+    __m512 a_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(a_odd_f16x32, 1));
+    __m512 b_first_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_even_f16x32));
+    __m512 b_second_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_even_f16x32, 1));
+    __m512 b_third_f32x16 = _mm512_cvtph_ps(_mm512_castsi512_si256(b_odd_f16x32));
+    __m512 b_fourth_f32x16 = _mm512_cvtph_ps(_mm512_extracti64x4_epi64(b_odd_f16x32, 1));
+
+    dot_f32x16 = _mm512_fmadd_ps(a_first_f32x16, b_first_f32x16, dot_f32x16);
+    dot_f32x16 = _mm512_fmadd_ps(a_second_f32x16, b_second_f32x16, dot_f32x16);
+    dot_f32x16 = _mm512_fmadd_ps(a_third_f32x16, b_third_f32x16, dot_f32x16);
+    dot_f32x16 = _mm512_fmadd_ps(a_fourth_f32x16, b_fourth_f32x16, dot_f32x16);
+    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_first_f32x16, a_first_f32x16, a_norm_sq_f32x16);
+    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_second_f32x16, a_second_f32x16, a_norm_sq_f32x16);
+    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_third_f32x16, a_third_f32x16, a_norm_sq_f32x16);
+    a_norm_sq_f32x16 = _mm512_fmadd_ps(a_fourth_f32x16, a_fourth_f32x16, a_norm_sq_f32x16);
+    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_first_f32x16, b_first_f32x16, b_norm_sq_f32x16);
+    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_second_f32x16, b_second_f32x16, b_norm_sq_f32x16);
+    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_third_f32x16, b_third_f32x16, b_norm_sq_f32x16);
+    b_norm_sq_f32x16 = _mm512_fmadd_ps(b_fourth_f32x16, b_fourth_f32x16, b_norm_sq_f32x16);
     if (n) goto nk_angular_e5m2_skylake_cycle;
 
     nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
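The lane-scrambling the E5M2 comment mentions is easy to see in isolation: `_mm512_unpacklo_epi8` interleaves within each 128-bit lane, so unpacking against zero widens bytes to `byte << 8` but picks source bytes 0-7, 16-23, 32-39, 48-55 rather than the first 32. A standalone probe (compile with AVX-512BW support, e.g. `-mavx512bw`; illustrative only):

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    uint8_t src[64];
    for (int i = 0; i < 64; ++i) src[i] = (uint8_t)i;
    __m512i x = _mm512_loadu_si512(src);
    __m512i lo = _mm512_unpacklo_epi8(_mm512_setzero_si512(), x);
    uint16_t out[32];
    _mm512_storeu_si512(out, lo);
    /* Prints 0x0000 0x0100 ... for bytes {0..7, 16..23, 32..39, 48..55}:
     * each value is the source byte shifted into the high half, and the order
     * is permuted per 128-bit lane -- harmless for a commutative sum. */
    for (int i = 0; i < 32; ++i) printf("0x%04x ", out[i]);
    printf("\n");
    return 0;
}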
@@ -604,12 +604,6 @@ NK_PUBLIC void nk_euclidean_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, n
 NK_PUBLIC void nk_sqeuclidean_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
 /** @copydoc nk_angular_f64 */
 NK_PUBLIC void nk_angular_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result);
-/** @copydoc nk_euclidean_f64 */
-NK_PUBLIC void nk_euclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
-/** @copydoc nk_sqeuclidean_f64 */
-NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
-/** @copydoc nk_angular_f64 */
-NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result);
 #endif // NK_TARGET_GENOA
 
 #if NK_TARGET_DIAMOND
@@ -1263,8 +1257,6 @@ NK_PUBLIC void nk_euclidean_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size
     nk_euclidean_e5m2_neonfp8(a, b, n, result);
 #elif NK_TARGET_DIAMOND
     nk_euclidean_e5m2_diamond(a, b, n, result);
-#elif NK_TARGET_GENOA
-    nk_euclidean_e5m2_genoa(a, b, n, result);
 #elif NK_TARGET_SKYLAKE
     nk_euclidean_e5m2_skylake(a, b, n, result);
 #elif NK_TARGET_RVV
@@ -1281,8 +1273,6 @@ NK_PUBLIC void nk_sqeuclidean_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_si
     nk_sqeuclidean_e5m2_neonfp8(a, b, n, result);
 #elif NK_TARGET_DIAMOND
     nk_sqeuclidean_e5m2_diamond(a, b, n, result);
-#elif NK_TARGET_GENOA
-    nk_sqeuclidean_e5m2_genoa(a, b, n, result);
 #elif NK_TARGET_SKYLAKE
     nk_sqeuclidean_e5m2_skylake(a, b, n, result);
 #elif NK_TARGET_RVV
@@ -1299,8 +1289,6 @@ NK_PUBLIC void nk_angular_e5m2(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t
     nk_angular_e5m2_neonfp8(a, b, n, result);
 #elif NK_TARGET_DIAMOND
     nk_angular_e5m2_diamond(a, b, n, result);
-#elif NK_TARGET_GENOA
-    nk_angular_e5m2_genoa(a, b, n, result);
 #elif NK_TARGET_SKYLAKE
     nk_angular_e5m2_skylake(a, b, n, result);
 #elif NK_TARGET_RVV
@@ -158,6 +158,134 @@ NK_PUBLIC void nk_euclideans_symmetric_f16_graniteamx(
 
 #pragma endregion F16 Symmetric
 
+#pragma region E5M2 Packed
+
+NK_INTERNAL void nk_angulars_packed_e5m2_graniteamx_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
+                                                              nk_size_t rows, nk_size_t columns, nk_size_t depth,
+                                                              nk_size_t a_stride_elements,
+                                                              nk_size_t c_stride_elements) {
+    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
+    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
+    for (nk_size_t row = 0; row < rows; row++) {
+        nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_e5m2_(a + row * a_stride_elements, depth);
+        nk_angulars_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
+    }
+}
+
+NK_PUBLIC void nk_angulars_packed_e5m2_graniteamx(          //
+    nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,  //
+    nk_size_t rows, nk_size_t columns, nk_size_t depth,     //
+    nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
+    nk_size_t const a_stride_elements = a_stride_in_bytes;
+    nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
+    nk_dots_packed_e5m2_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
+    nk_angulars_packed_e5m2_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
+                                                 c_stride_elements);
+}
+
+NK_INTERNAL void nk_euclideans_packed_e5m2_graniteamx_finalize_(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,
+                                                                nk_size_t rows, nk_size_t columns, nk_size_t depth,
+                                                                nk_size_t a_stride_elements,
+                                                                nk_size_t c_stride_elements) {
+    nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
+    nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
+    for (nk_size_t row = 0; row < rows; row++) {
+        nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_e5m2_(a + row * a_stride_elements, depth);
+        nk_euclideans_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
+    }
+}
+
+NK_PUBLIC void nk_euclideans_packed_e5m2_graniteamx(        //
+    nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c,  //
+    nk_size_t rows, nk_size_t columns, nk_size_t depth,     //
+    nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
+    nk_size_t const a_stride_elements = a_stride_in_bytes;
+    nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
+    nk_dots_packed_e5m2_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
+    nk_euclideans_packed_e5m2_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
+                                                   c_stride_elements);
+}
+
+#pragma endregion E5M2 Packed
+
+#pragma region E5M2 Symmetric
+
+NK_INTERNAL void nk_angulars_symmetric_e5m2_graniteamx_finalize_(nk_e5m2_t const *vectors, nk_size_t vectors_count,
+                                                                 nk_size_t depth, nk_size_t stride_elements,
+                                                                 nk_f32_t *result, nk_size_t result_stride_elements,
+                                                                 nk_size_t row_start, nk_size_t row_count) {
+
+    for (nk_size_t row = row_start; row < row_start + row_count; row++)
+        result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e5m2_(vectors + row * stride_elements, depth);
+
+    nk_f32_t column_norms_cache[256];
+    for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
+        nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
+        for (nk_size_t col = chunk_start; col < chunk_end; col++)
+            column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
+
+        for (nk_size_t row = row_start; row < row_start + row_count; row++) {
+            nk_f32_t *r_row = result + row * result_stride_elements;
+            nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
+            if (col_start >= chunk_end) continue;
+            nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
+                                                 r_row[row], chunk_end - col_start);
+        }
+    }
+
+    for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
+}
+
+NK_PUBLIC void nk_angulars_symmetric_e5m2_graniteamx(                                               //
+    nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,  //
+    nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
+    nk_size_t const stride_elements = stride_in_bytes;
+    nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
+    nk_dots_symmetric_e5m2_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
+                                      row_start, row_count);
+    nk_angulars_symmetric_e5m2_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
+                                                    result_stride_elements, row_start, row_count);
+}
+
+NK_INTERNAL void nk_euclideans_symmetric_e5m2_graniteamx_finalize_(nk_e5m2_t const *vectors, nk_size_t vectors_count,
+                                                                   nk_size_t depth, nk_size_t stride_elements,
+                                                                   nk_f32_t *result, nk_size_t result_stride_elements,
+                                                                   nk_size_t row_start, nk_size_t row_count) {
+
+    for (nk_size_t row = row_start; row < row_start + row_count; row++)
+        result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_e5m2_(vectors + row * stride_elements, depth);
+
+    nk_f32_t column_norms_cache[256];
+    for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
+        nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
+        for (nk_size_t col = chunk_start; col < chunk_end; col++)
+            column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_e5m2_(vectors + col * stride_elements, depth);
+
+        for (nk_size_t row = row_start; row < row_start + row_count; row++) {
+            nk_f32_t *r_row = result + row * result_stride_elements;
+            nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
+            if (col_start >= chunk_end) continue;
+            nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
+                                                   r_row[row], chunk_end - col_start);
+        }
+    }
+
+    for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
+}
+
+NK_PUBLIC void nk_euclideans_symmetric_e5m2_graniteamx(                                             //
+    nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes,  //
+    nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
+    nk_size_t const stride_elements = stride_in_bytes;
+    nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
+    nk_dots_symmetric_e5m2_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
+                                      row_start, row_count);
+    nk_euclideans_symmetric_e5m2_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
+                                                      result_stride_elements, row_start, row_count);
+}
+
+#pragma endregion E5M2 Symmetric
+
 #if defined(__clang__)
 #pragma clang attribute pop
 #elif defined(__GNUC__)
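The `*_finalize_` helpers above lean on shared row converters (`nk_angulars_row_f32dots_sapphireamx_`, `nk_euclideans_row_f32dots_sapphireamx_`) that rewrite a row of raw dot products in place, given the query's squared norm and the cached column squared norms. Assuming the usual definitions of angular and Euclidean distance, the per-element math amounts to this scalar sketch (the library's exact edge-case handling is not visible in this diff):

#include <math.h>

/* Angular distance from a dot product and the two squared norms. */
static float angular_from_dot(float dot, float a_norm_sq, float b_norm_sq) {
    float denom = sqrtf(a_norm_sq * b_norm_sq);
    return denom > 0.0f ? 1.0f - dot / denom : 0.0f;
}

/* Euclidean distance via the expansion |a - b|^2 = |a|^2 - 2*a.b + |b|^2,
 * clamped at zero to absorb rounding before the square root. */
static float euclidean_from_dot(float dot, float a_norm_sq, float b_norm_sq) {
    float dist_sq = a_norm_sq - 2.0f * dot + b_norm_sq;
    return sqrtf(dist_sq > 0.0f ? dist_sq : 0.0f);
}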
@@ -15,7 +15,18 @@
 extern "C" {
 #endif
 
-/*
+/* Keep the serial instantiations below actually scalar, regardless of build type.
+ * Without this, -O3 + LTO can vectorize or clone the serial kernels under AVX-512
+ * callers in dispatch_*.c, which wastes binary size and breaks the nk_*_serial-as-scalar-oracle
+ * contract that tests and numerical-stability docs rely on. See dots/serial.h. */
+#if defined(__clang__)
+#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
+#elif defined(__GNUC__)
+#pragma GCC push_options
+#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
+#endif
+
+/* Size bias for release. Gated on NDEBUG so Debug builds keep -O0 for stepping. */
 #if defined(NDEBUG)
 #if defined(_MSC_VER)
 #pragma optimize("s", on)
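The push above is paired with a pop later in the same header (next hunk), bracketing every serial kernel definition in between. As a standalone illustration of that bracketing pattern, here is a toy translation unit (not from the package) where the enclosed function keeps scalar, out-of-line codegen:

#if defined(__clang__)
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
#elif defined(__GNUC__)
#pragma GCC push_options
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize")
#endif

/* Stays a scalar loop even at -O3, and can't be inlined into (and then
 * re-vectorized inside) a SIMD-enabled caller. */
float dot_scalar_oracle(float const *a, float const *b, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) sum += a[i] * b[i];
    return sum;
}

#if defined(__clang__)
#pragma clang attribute pop
#elif defined(__GNUC__)
#pragma GCC pop_options
#endif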
@@ -241,6 +252,12 @@ nk_define_cross_normalized_symmetric_(euclidean, u4, serial, u4x2, u32, /*norm_v
 #endif
 #endif
 
+#if defined(__clang__)
+#pragma clang attribute pop
+#elif defined(__GNUC__)
+#pragma GCC pop_options
+#endif
+
 #if defined(__cplusplus)
 } // extern "C"
 #endif
@@ -97,11 +97,11 @@ nk_define_cross_normalized_symmetric_(euclidean, bf16, skylake, bf16, f32, /*nor
                                       nk_dots_reduce_sumsq_bf16_, nk_load_b128_haswell_, nk_partial_load_b32x4_skylake_,
                                       nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_, 1)
 
-nk_define_cross_normalized_packed_(angular, e4m3, skylake, e4m3,
+nk_define_cross_normalized_packed_(angular, e4m3, skylake, e4m3, f16, f32, /*norm_value_type=*/f32, f32, nk_b128_vec_t,
                                    nk_dots_packed_e4m3_skylake, nk_angular_through_f32_from_dot_haswell_,
                                    nk_dots_reduce_sumsq_e4m3_, nk_load_b128_haswell_, nk_partial_load_b32x4_skylake_,
                                    nk_store_b128_haswell_, nk_partial_store_b32x4_skylake_, 1)
-nk_define_cross_normalized_packed_(euclidean, e4m3, skylake, e4m3,
+nk_define_cross_normalized_packed_(euclidean, e4m3, skylake, e4m3, f16, f32, /*norm_value_type=*/f32, f32,
                                    nk_b128_vec_t, nk_dots_packed_e4m3_skylake,
                                    nk_euclidean_through_f32_from_dot_haswell_, nk_dots_reduce_sumsq_e4m3_,
                                    nk_load_b128_haswell_, nk_partial_load_b32x4_skylake_, nk_store_b128_haswell_,
@@ -759,6 +759,23 @@ NK_PUBLIC void nk_euclideans_packed_f16_graniteamx(nk_f16_t const *a, void const
 NK_PUBLIC void nk_euclideans_symmetric_f16_graniteamx(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
                                                       nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
                                                       nk_size_t row_start, nk_size_t row_count);
+/** @copydoc nk_angulars_packed_f16 */
+NK_PUBLIC void nk_angulars_packed_e5m2_graniteamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
+                                                  nk_size_t rows, nk_size_t cols, nk_size_t depth,
+                                                  nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
+/** @copydoc nk_angulars_symmetric_f16 */
+NK_PUBLIC void nk_angulars_symmetric_e5m2_graniteamx(nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
+                                                     nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
+                                                     nk_size_t row_start, nk_size_t row_count);
+/** @copydoc nk_euclideans_packed_f16 */
+NK_PUBLIC void nk_euclideans_packed_e5m2_graniteamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *result,
+                                                    nk_size_t rows, nk_size_t cols, nk_size_t depth,
+                                                    nk_size_t a_stride_in_bytes, nk_size_t r_stride_in_bytes);
+/** @copydoc nk_euclideans_symmetric_f16 */
+NK_PUBLIC void nk_euclideans_symmetric_e5m2_graniteamx(nk_e5m2_t const *vectors, nk_size_t vectors_count,
+                                                       nk_size_t depth, nk_size_t stride, nk_f32_t *result,
+                                                       nk_size_t result_stride, nk_size_t row_start,
+                                                       nk_size_t row_count);
 #endif // NK_TARGET_GRANITEAMX
 
 /* ARM SME backends using Scalable Matrix Extension.
|