numkong 7.4.4 → 7.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +81 -5
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/graniteamx.h +733 -0
- package/include/numkong/dots/serial.h +11 -4
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +29 -3
- package/include/numkong/each/serial.h +22 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +1 -1
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/sme.h +94 -55
- package/include/numkong/mesh/README.md +13 -27
- package/include/numkong/mesh/haswell.h +25 -122
- package/include/numkong/mesh/neon.h +21 -110
- package/include/numkong/mesh/neonbfdot.h +4 -43
- package/include/numkong/mesh/rvv.h +7 -82
- package/include/numkong/mesh/serial.h +48 -53
- package/include/numkong/mesh/skylake.h +7 -123
- package/include/numkong/mesh/v128relaxed.h +9 -93
- package/include/numkong/mesh.h +2 -2
- package/include/numkong/mesh.hpp +35 -96
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatials/graniteamx.h +173 -0
- package/include/numkong/spatials/serial.h +22 -0
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +37 -4
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +56 -12
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -41,8 +41,9 @@
|
|
|
41
41
|
#include "numkong/types.h"
|
|
42
42
|
#include "numkong/set/serial.h"
|
|
43
43
|
#include "numkong/sets/serial.h"
|
|
44
|
-
#include "numkong/
|
|
45
|
-
#include "numkong/reduce.h"
|
|
44
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
45
|
+
#include "numkong/reduce/neon.h" // `nk_reduce_moments_u1_neon`
|
|
46
|
+
#include "numkong/dots/sme.h" // `nk_sme_zero_za32_*`
|
|
46
47
|
|
|
47
48
|
#if defined(__cplusplus)
|
|
48
49
|
extern "C" {
|
|
@@ -100,7 +101,7 @@ NK_PUBLIC nk_u32_t nk_sets_reduce_sumsq_u1_streaming_(nk_u1x8_t const *data, nk_
|
|
|
100
101
|
svbool_t predicate_b8x = svwhilelt_b8_u64(offset, n_bytes);
|
|
101
102
|
acc_u32x = svdot_u32(acc_u32x, svcnt_u8_z(predicate_b8x, svld1_u8(predicate_b8x, data + offset)), ones_u8x);
|
|
102
103
|
}
|
|
103
|
-
return (nk_u32_t)
|
|
104
|
+
return (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), acc_u32x);
|
|
104
105
|
}
|
|
105
106
|
|
|
106
107
|
#pragma region Hamming Distance
|
|
@@ -187,11 +188,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
|
|
|
187
188
|
// Compute per-row population counts
|
|
188
189
|
for (nk_size_t row = 0; row < row_count; row++) {
|
|
189
190
|
nk_u1x8_t const *src_row = (nk_u1x8_t const *)((char const *)b + row * b_stride_in_bytes);
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
norms_ptr[row] = (nk_u32_t)nk_local_sum_;
|
|
194
|
-
}
|
|
191
|
+
nk_u64_t nk_local_sum_, nk_local_sumsq_;
|
|
192
|
+
nk_reduce_moments_u1_neon(src_row, depth_bytes * 8, sizeof(nk_u1x8_t), &nk_local_sum_, &nk_local_sumsq_);
|
|
193
|
+
norms_ptr[row] = (nk_u32_t)nk_local_sum_;
|
|
195
194
|
}
|
|
196
195
|
}
|
|
197
196
|
|
|
@@ -203,9 +202,9 @@ NK_PUBLIC void nk_dots_pack_u1_smebi32(nk_u1x8_t const *b, nk_size_t row_count,
|
|
|
203
202
|
* Each ZA0.S batch covers 16 depth u32 steps (one full depth tile).
|
|
204
203
|
* BMOPA expansion=1 for u32: each u32 contributes 32 bits via XNOR+POPCNT.
|
|
205
204
|
*/
|
|
206
|
-
|
|
205
|
+
__arm_new("za") static void nk_hammings_packed_u1_smebi32_streaming_( //
|
|
207
206
|
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
208
|
-
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
207
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
|
|
209
208
|
|
|
210
209
|
nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
|
|
211
210
|
nk_size_t const row_tile_count_b = header->row_tile_count;
|
|
@@ -344,11 +343,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_packed_u1_smebi3
|
|
|
344
343
|
}
|
|
345
344
|
}
|
|
346
345
|
|
|
347
|
-
NK_PUBLIC void nk_hammings_packed_u1_smebi32(
|
|
348
|
-
|
|
349
|
-
|
|
346
|
+
NK_PUBLIC void nk_hammings_packed_u1_smebi32( //
|
|
347
|
+
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
348
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
349
|
+
nk_sme_start_streaming_();
|
|
350
350
|
nk_hammings_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
|
|
351
351
|
c_stride_in_bytes);
|
|
352
|
+
nk_sme_stop_streaming_();
|
|
352
353
|
}
|
|
353
354
|
|
|
354
355
|
/**
|
|
@@ -357,9 +358,9 @@ NK_PUBLIC void nk_hammings_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
|
|
|
357
358
|
* ZA1-3.S = BMOPA accumulators (3 B column tiles in fast path).
|
|
358
359
|
* Mirrors the unpacked kernel nk_hammings_packed_u1_smebi32_streaming_ pattern.
|
|
359
360
|
*/
|
|
360
|
-
|
|
361
|
+
__arm_new("za") static void nk_hammings_symmetric_u1_smebi32_streaming_( //
|
|
361
362
|
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
362
|
-
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
363
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
363
364
|
|
|
364
365
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
365
366
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
@@ -545,12 +546,13 @@ __arm_locally_streaming __arm_new("za") static void nk_hammings_symmetric_u1_sme
|
|
|
545
546
|
}
|
|
546
547
|
}
|
|
547
548
|
|
|
548
|
-
NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
549
|
+
NK_PUBLIC void nk_hammings_symmetric_u1_smebi32( //
|
|
550
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
551
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
552
|
+
nk_sme_start_streaming_();
|
|
552
553
|
nk_hammings_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
553
554
|
result_stride_in_bytes, row_start, row_count);
|
|
555
|
+
nk_sme_stop_streaming_();
|
|
554
556
|
}
|
|
555
557
|
|
|
556
558
|
#pragma endregion Hamming Distance
|
|
@@ -581,9 +583,9 @@ NK_PUBLIC void nk_hammings_symmetric_u1_smebi32(nk_u1x8_t const *vectors, nk_siz
|
|
|
581
583
|
* union = (norm_a + norm_b + hamming) / 2
|
|
582
584
|
* jaccard = 1 - intersection / union (1.0 when union == 0)
|
|
583
585
|
*/
|
|
584
|
-
|
|
586
|
+
__arm_new("za") static void nk_jaccards_packed_u1_smebi32_streaming_( //
|
|
585
587
|
nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
586
|
-
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
588
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
|
|
587
589
|
|
|
588
590
|
nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
|
|
589
591
|
nk_size_t const row_tile_count_b = header->row_tile_count;
|
|
@@ -796,11 +798,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_packed_u1_smebi3
|
|
|
796
798
|
}
|
|
797
799
|
}
|
|
798
800
|
|
|
799
|
-
NK_PUBLIC void nk_jaccards_packed_u1_smebi32(
|
|
800
|
-
|
|
801
|
-
|
|
801
|
+
NK_PUBLIC void nk_jaccards_packed_u1_smebi32( //
|
|
802
|
+
nk_u1x8_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
803
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
804
|
+
nk_sme_start_streaming_();
|
|
802
805
|
nk_jaccards_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
|
|
803
806
|
c_stride_in_bytes);
|
|
807
|
+
nk_sme_stop_streaming_();
|
|
804
808
|
}
|
|
805
809
|
|
|
806
810
|
/**
|
|
@@ -808,9 +812,9 @@ NK_PUBLIC void nk_jaccards_packed_u1_smebi32(nk_u1x8_t const *a, void const *b_p
|
|
|
808
812
|
* Fills upper triangle only (column_tile >= row_tile); caller sees result[i][j] for j >= i.
|
|
809
813
|
* Norms computed on-the-fly using streaming SVE popcount.
|
|
810
814
|
*/
|
|
811
|
-
|
|
815
|
+
__arm_new("za") static void nk_jaccards_symmetric_u1_smebi32_streaming_( //
|
|
812
816
|
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
813
|
-
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
817
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
814
818
|
|
|
815
819
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
816
820
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
@@ -1104,12 +1108,13 @@ __arm_locally_streaming __arm_new("za") static void nk_jaccards_symmetric_u1_sme
|
|
|
1104
1108
|
}
|
|
1105
1109
|
}
|
|
1106
1110
|
|
|
1107
|
-
NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32(
|
|
1108
|
-
|
|
1109
|
-
|
|
1110
|
-
|
|
1111
|
+
NK_PUBLIC void nk_jaccards_symmetric_u1_smebi32( //
|
|
1112
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
1113
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
1114
|
+
nk_sme_start_streaming_();
|
|
1111
1115
|
nk_jaccards_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
1112
1116
|
result_stride_in_bytes, row_start, row_count);
|
|
1117
|
+
nk_sme_stop_streaming_();
|
|
1113
1118
|
}
|
|
1114
1119
|
|
|
1115
1120
|
#pragma endregion Jaccard Distance
|
|
@@ -12,6 +12,7 @@
|
|
|
12
12
|
#if NK_TARGET_ARM64_
|
|
13
13
|
|
|
14
14
|
#include "numkong/types.h"
|
|
15
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
15
16
|
|
|
16
17
|
#if defined(__cplusplus)
|
|
17
18
|
extern "C" {
|
|
@@ -395,7 +396,7 @@ NK_PUBLIC void nk_sparse_dot_u32f32_sve2( //
|
|
|
395
396
|
a_idx += a_step;
|
|
396
397
|
b_idx += b_step;
|
|
397
398
|
}
|
|
398
|
-
*product =
|
|
399
|
+
*product = nk_svaddv_f64_(predicate_all_b64x, product_f64x);
|
|
399
400
|
}
|
|
400
401
|
|
|
401
402
|
#if defined(__clang__)
|
|
@@ -485,7 +486,7 @@ NK_PUBLIC void nk_sparse_dot_u16bf16_sve2( //
|
|
|
485
486
|
a_idx += a_step;
|
|
486
487
|
b_idx += b_step;
|
|
487
488
|
}
|
|
488
|
-
*product =
|
|
489
|
+
*product = nk_svaddv_f32_(svptrue_b32(), product_f32x);
|
|
489
490
|
}
|
|
490
491
|
|
|
491
492
|
#if defined(__clang__)
|
|
@@ -36,6 +36,7 @@
|
|
|
36
36
|
#if NK_TARGET_SVE
|
|
37
37
|
|
|
38
38
|
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
39
40
|
#include "numkong/spatial/neon.h" // `nk_f64_sqrt_neon`
|
|
40
41
|
#include "numkong/dot/sve.h" // `nk_dot_stable_sum_f64_sve_`
|
|
41
42
|
|
|
@@ -113,7 +114,7 @@ NK_PUBLIC void nk_sqeuclidean_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_s
|
|
|
113
114
|
svfloat64_t diff_odd_f64x = svsub_f64_x(pred_odd_b64x, a_odd_f64x, b_odd_f64x);
|
|
114
115
|
dist_sq_f64x = svmla_f64_m(pred_odd_b64x, dist_sq_f64x, diff_odd_f64x, diff_odd_f64x);
|
|
115
116
|
}
|
|
116
|
-
nk_f64_t dist_sq_f64 =
|
|
117
|
+
nk_f64_t dist_sq_f64 = nk_svaddv_f64_(svptrue_b64(), dist_sq_f64x);
|
|
117
118
|
*result = dist_sq_f64;
|
|
118
119
|
}
|
|
119
120
|
|
|
@@ -149,9 +150,9 @@ NK_PUBLIC void nk_angular_f32_sve(nk_f32_t const *a, nk_f32_t const *b, nk_size_
|
|
|
149
150
|
b2_f64x = svmla_f64_m(pred_odd_b64x, b2_f64x, b_odd_f64x, b_odd_f64x);
|
|
150
151
|
}
|
|
151
152
|
|
|
152
|
-
nk_f64_t ab_f64 =
|
|
153
|
-
nk_f64_t a2_f64 =
|
|
154
|
-
nk_f64_t b2_f64 =
|
|
153
|
+
nk_f64_t ab_f64 = nk_svaddv_f64_(svptrue_b64(), ab_f64x);
|
|
154
|
+
nk_f64_t a2_f64 = nk_svaddv_f64_(svptrue_b64(), a2_f64x);
|
|
155
|
+
nk_f64_t b2_f64 = nk_svaddv_f64_(svptrue_b64(), b2_f64x);
|
|
155
156
|
*result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
|
|
156
157
|
}
|
|
157
158
|
|
|
@@ -225,8 +226,8 @@ NK_PUBLIC void nk_angular_f64_sve(nk_f64_t const *a, nk_f64_t const *b, nk_size_
|
|
|
225
226
|
} while (i < n);
|
|
226
227
|
|
|
227
228
|
nk_f64_t ab_f64 = nk_dot_stable_sum_f64_sve_(predicate_all_b64x, ab_sum_f64x, ab_compensation_f64x);
|
|
228
|
-
nk_f64_t a2_f64 =
|
|
229
|
-
nk_f64_t b2_f64 =
|
|
229
|
+
nk_f64_t a2_f64 = nk_svaddv_f64_(predicate_all_b64x, a2_f64x);
|
|
230
|
+
nk_f64_t b2_f64 = nk_svaddv_f64_(predicate_all_b64x, b2_f64x);
|
|
230
231
|
*result = nk_angular_normalize_f64_neon_(ab_f64, a2_f64, b2_f64);
|
|
231
232
|
}
|
|
232
233
|
|
|
@@ -36,6 +36,7 @@
|
|
|
36
36
|
#if NK_TARGET_SVEBFDOT
|
|
37
37
|
|
|
38
38
|
#include "numkong/types.h"
|
|
39
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
39
40
|
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
40
41
|
|
|
41
42
|
#if defined(__cplusplus)
|
|
@@ -75,7 +76,9 @@ NK_PUBLIC void nk_sqeuclidean_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t c
|
|
|
75
76
|
d2_high_f32x = svmla_f32_m(predicate_high_b32x, d2_high_f32x, a_minus_b_high_f32x, a_minus_b_high_f32x);
|
|
76
77
|
i += svcnth();
|
|
77
78
|
} while (i < n);
|
|
78
|
-
nk_f32_t
|
|
79
|
+
nk_f32_t d2_low = nk_svaddv_f32_(svptrue_b32(), d2_low_f32x);
|
|
80
|
+
nk_f32_t d2_high = nk_svaddv_f32_(svptrue_b32(), d2_high_f32x);
|
|
81
|
+
nk_f32_t d2 = d2_low + d2_high;
|
|
79
82
|
*result = d2;
|
|
80
83
|
}
|
|
81
84
|
NK_PUBLIC void nk_euclidean_bf16_svebfdot(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -101,9 +104,9 @@ NK_PUBLIC void nk_angular_bf16_svebfdot(nk_bf16_t const *a_enum, nk_bf16_t const
|
|
|
101
104
|
i += svcnth();
|
|
102
105
|
} while (i < n);
|
|
103
106
|
|
|
104
|
-
nk_f32_t ab =
|
|
105
|
-
nk_f32_t a2 =
|
|
106
|
-
nk_f32_t b2 =
|
|
107
|
+
nk_f32_t ab = nk_svaddv_f32_(svptrue_b32(), ab_f32x);
|
|
108
|
+
nk_f32_t a2 = nk_svaddv_f32_(svptrue_b32(), a2_f32x);
|
|
109
|
+
nk_f32_t b2 = nk_svaddv_f32_(svptrue_b32(), b2_f32x);
|
|
107
110
|
*result = nk_angular_normalize_f32_neon_(ab, a2, b2);
|
|
108
111
|
}
|
|
109
112
|
|
|
@@ -32,6 +32,7 @@
|
|
|
32
32
|
#if NK_TARGET_SVEHALF
|
|
33
33
|
|
|
34
34
|
#include "numkong/types.h"
|
|
35
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
35
36
|
#include "numkong/spatial/neon.h" // `nk_f32_sqrt_neon`
|
|
36
37
|
|
|
37
38
|
#if defined(__cplusplus)
|
|
@@ -74,7 +75,7 @@ NK_PUBLIC void nk_sqeuclidean_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const
|
|
|
74
75
|
|
|
75
76
|
i += svcnth();
|
|
76
77
|
} while (i < n);
|
|
77
|
-
*result =
|
|
78
|
+
*result = nk_svaddv_f32_(svptrue_b32(), d2_f32x);
|
|
78
79
|
}
|
|
79
80
|
|
|
80
81
|
NK_PUBLIC void nk_euclidean_f16_svehalf(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -114,9 +115,9 @@ NK_PUBLIC void nk_angular_f16_svehalf(nk_f16_t const *a_enum, nk_f16_t const *b_
|
|
|
114
115
|
i += svcnth();
|
|
115
116
|
} while (i < n);
|
|
116
117
|
|
|
117
|
-
nk_f32_t ab_f32 =
|
|
118
|
-
nk_f32_t a2_f32 =
|
|
119
|
-
nk_f32_t b2_f32 =
|
|
118
|
+
nk_f32_t ab_f32 = nk_svaddv_f32_(svptrue_b32(), ab_f32x);
|
|
119
|
+
nk_f32_t a2_f32 = nk_svaddv_f32_(svptrue_b32(), a2_f32x);
|
|
120
|
+
nk_f32_t b2_f32 = nk_svaddv_f32_(svptrue_b32(), b2_f32x);
|
|
120
121
|
*result = nk_angular_normalize_f32_neon_(ab_f32, a2_f32, b2_f32);
|
|
121
122
|
}
|
|
122
123
|
|
|
@@ -34,6 +34,7 @@
|
|
|
34
34
|
#if NK_TARGET_SVESDOT
|
|
35
35
|
|
|
36
36
|
#include "numkong/types.h"
|
|
37
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
37
38
|
#include "numkong/spatial/neon.h" // `nk_angular_normalize_f32_neon_`, `nk_f32_sqrt_neon`
|
|
38
39
|
|
|
39
40
|
#if defined(__cplusplus)
|
|
@@ -58,7 +59,7 @@ NK_PUBLIC void nk_sqeuclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_
|
|
|
58
59
|
distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
|
|
59
60
|
i += svcntb();
|
|
60
61
|
} while (i < n);
|
|
61
|
-
*result = (nk_u32_t)
|
|
62
|
+
*result = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), distance_sq_u32x);
|
|
62
63
|
}
|
|
63
64
|
NK_PUBLIC void nk_euclidean_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
64
65
|
nk_u32_t distance_sq_u32;
|
|
@@ -81,9 +82,9 @@ NK_PUBLIC void nk_angular_i8_svesdot(nk_i8_t const *a, nk_i8_t const *b, nk_size
|
|
|
81
82
|
i += svcntb();
|
|
82
83
|
} while (i < n);
|
|
83
84
|
|
|
84
|
-
nk_i32_t ab = (nk_i32_t)
|
|
85
|
-
nk_i32_t a2 = (nk_i32_t)
|
|
86
|
-
nk_i32_t b2 = (nk_i32_t)
|
|
85
|
+
nk_i32_t ab = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), ab_i32x);
|
|
86
|
+
nk_i32_t a2 = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), a2_i32x);
|
|
87
|
+
nk_i32_t b2 = (nk_i32_t)nk_svaddv_s32_(svptrue_b32(), b2_i32x);
|
|
87
88
|
*result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
88
89
|
}
|
|
89
90
|
|
|
@@ -98,7 +99,7 @@ NK_PUBLIC void nk_sqeuclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_
|
|
|
98
99
|
distance_sq_u32x = svdot_u32(distance_sq_u32x, diff_u8x, diff_u8x);
|
|
99
100
|
i += svcntb();
|
|
100
101
|
} while (i < n);
|
|
101
|
-
*result = (nk_u32_t)
|
|
102
|
+
*result = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), distance_sq_u32x);
|
|
102
103
|
}
|
|
103
104
|
NK_PUBLIC void nk_euclidean_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
104
105
|
nk_u32_t distance_sq_u32;
|
|
@@ -121,9 +122,9 @@ NK_PUBLIC void nk_angular_u8_svesdot(nk_u8_t const *a, nk_u8_t const *b, nk_size
|
|
|
121
122
|
i += svcntb();
|
|
122
123
|
} while (i < n);
|
|
123
124
|
|
|
124
|
-
nk_u32_t ab = (nk_u32_t)
|
|
125
|
-
nk_u32_t a2 = (nk_u32_t)
|
|
126
|
-
nk_u32_t b2 = (nk_u32_t)
|
|
125
|
+
nk_u32_t ab = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), ab_u32x);
|
|
126
|
+
nk_u32_t a2 = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), a2_u32x);
|
|
127
|
+
nk_u32_t b2 = (nk_u32_t)nk_svaddv_u32_(svptrue_b32(), b2_u32x);
|
|
127
128
|
*result = nk_angular_normalize_f32_neon_((nk_f32_t)ab, (nk_f32_t)a2, (nk_f32_t)b2);
|
|
128
129
|
}
|
|
129
130
|
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief Batched Spatial Distances for Granite Rapids (AMX-FP16) with AVX-512 Finalization.
|
|
3
|
+
* @file include/numkong/spatials/graniteamx.h
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date April 9, 2026
|
|
6
|
+
*
|
|
7
|
+
* @sa include/numkong/spatials.h
|
|
8
|
+
*/
|
|
9
|
+
#ifndef NK_SPATIALS_GRANITEAMX_H
|
|
10
|
+
#define NK_SPATIALS_GRANITEAMX_H
|
|
11
|
+
|
|
12
|
+
#if NK_TARGET_X8664_
|
|
13
|
+
#if NK_TARGET_GRANITEAMX
|
|
14
|
+
|
|
15
|
+
#include "numkong/spatial/skylake.h"
|
|
16
|
+
#include "numkong/spatial/serial.h"
|
|
17
|
+
#include "numkong/dots/graniteamx.h"
|
|
18
|
+
|
|
19
|
+
#if defined(__cplusplus)
|
|
20
|
+
extern "C" {
|
|
21
|
+
#endif
|
|
22
|
+
|
|
23
|
+
#if defined(__clang__)
|
|
24
|
+
#pragma clang attribute push( \
|
|
25
|
+
__attribute__((target( \
|
|
26
|
+
"avx2,avx512f,avx512vl,avx512bw,avx512dq,avx512fp16,avx512vbmi,f16c,fma,bmi,bmi2,amx-tile,amx-bf16,amx-int8,amx-fp16"))), \
|
|
27
|
+
apply_to = function)
|
|
28
|
+
#elif defined(__GNUC__)
|
|
29
|
+
#pragma GCC push_options
|
|
30
|
+
#pragma GCC target("avx2", "avx512f", "avx512vl", "avx512bw", "avx512dq", "avx512fp16", "avx512vbmi", "f16c", "fma", \
|
|
31
|
+
"bmi", "bmi2", "amx-tile", "amx-bf16", "amx-int8", "amx-fp16")
|
|
32
|
+
#endif
|
|
33
|
+
|
|
34
|
+
#pragma region F16 Packed
|
|
35
|
+
|
|
36
|
+
NK_INTERNAL void nk_angulars_packed_f16_graniteamx_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
37
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
38
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
39
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
40
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
|
|
41
|
+
for (nk_size_t row = 0; row < rows; row++) {
|
|
42
|
+
nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_f16_(a + row * a_stride_elements, depth);
|
|
43
|
+
nk_angulars_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
|
|
44
|
+
}
|
|
45
|
+
}
|
|
46
|
+
|
|
47
|
+
NK_PUBLIC void nk_angulars_packed_f16_graniteamx( //
|
|
48
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
49
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
50
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
51
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
52
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
53
|
+
nk_dots_packed_f16_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
54
|
+
nk_angulars_packed_f16_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
55
|
+
c_stride_elements);
|
|
56
|
+
}
|
|
57
|
+
|
|
58
|
+
NK_INTERNAL void nk_euclideans_packed_f16_graniteamx_finalize_(nk_f16_t const *a, void const *b_packed, nk_f32_t *c,
|
|
59
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
60
|
+
nk_size_t a_stride_elements,
|
|
61
|
+
nk_size_t c_stride_elements) {
|
|
62
|
+
nk_dots_amx_packed_header_t const *header = (nk_dots_amx_packed_header_t const *)b_packed;
|
|
63
|
+
nk_f32_t const *b_norms = (nk_f32_t const *)((char const *)b_packed + header->norms_byte_offset);
|
|
64
|
+
for (nk_size_t row = 0; row < rows; row++) {
|
|
65
|
+
nk_f32_t query_norm_sq = nk_dots_reduce_sumsq_f16_(a + row * a_stride_elements, depth);
|
|
66
|
+
nk_euclideans_row_f32dots_sapphireamx_(c + row * c_stride_elements, b_norms, query_norm_sq, columns);
|
|
67
|
+
}
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
NK_PUBLIC void nk_euclideans_packed_f16_graniteamx( //
|
|
71
|
+
nk_f16_t const *a, void const *b_packed, nk_f32_t *c, //
|
|
72
|
+
nk_size_t rows, nk_size_t columns, nk_size_t depth, //
|
|
73
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
74
|
+
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f16_t);
|
|
75
|
+
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f32_t);
|
|
76
|
+
nk_dots_packed_f16_graniteamx(a, b_packed, c, rows, columns, depth, a_stride_in_bytes, c_stride_in_bytes);
|
|
77
|
+
nk_euclideans_packed_f16_graniteamx_finalize_(a, b_packed, c, rows, columns, depth, a_stride_elements,
|
|
78
|
+
c_stride_elements);
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
#pragma endregion F16 Packed
|
|
82
|
+
|
|
83
|
+
#pragma region F16 Symmetric
|
|
84
|
+
|
|
85
|
+
NK_INTERNAL void nk_angulars_symmetric_f16_graniteamx_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
|
|
86
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
87
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
88
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
89
|
+
|
|
90
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
91
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_f16_(vectors + row * stride_elements, depth);
|
|
92
|
+
|
|
93
|
+
nk_f32_t column_norms_cache[256];
|
|
94
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
95
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
96
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
97
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
|
|
98
|
+
|
|
99
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
100
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
101
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
102
|
+
if (col_start >= chunk_end) continue;
|
|
103
|
+
nk_angulars_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
104
|
+
r_row[row], chunk_end - col_start);
|
|
105
|
+
}
|
|
106
|
+
}
|
|
107
|
+
|
|
108
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
NK_PUBLIC void nk_angulars_symmetric_f16_graniteamx( //
|
|
112
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
113
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
114
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
115
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
116
|
+
nk_dots_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
117
|
+
row_start, row_count);
|
|
118
|
+
nk_angulars_symmetric_f16_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
119
|
+
result_stride_elements, row_start, row_count);
|
|
120
|
+
}
|
|
121
|
+
|
|
122
|
+
NK_INTERNAL void nk_euclideans_symmetric_f16_graniteamx_finalize_(nk_f16_t const *vectors, nk_size_t vectors_count,
|
|
123
|
+
nk_size_t depth, nk_size_t stride_elements,
|
|
124
|
+
nk_f32_t *result, nk_size_t result_stride_elements,
|
|
125
|
+
nk_size_t row_start, nk_size_t row_count) {
|
|
126
|
+
|
|
127
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++)
|
|
128
|
+
result[row * result_stride_elements + row] = nk_dots_reduce_sumsq_f16_(vectors + row * stride_elements, depth);
|
|
129
|
+
|
|
130
|
+
nk_f32_t column_norms_cache[256];
|
|
131
|
+
for (nk_size_t chunk_start = 0; chunk_start < vectors_count; chunk_start += 256) {
|
|
132
|
+
nk_size_t chunk_end = chunk_start + 256 < vectors_count ? chunk_start + 256 : vectors_count;
|
|
133
|
+
for (nk_size_t col = chunk_start; col < chunk_end; col++)
|
|
134
|
+
column_norms_cache[col - chunk_start] = nk_dots_reduce_sumsq_f16_(vectors + col * stride_elements, depth);
|
|
135
|
+
|
|
136
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) {
|
|
137
|
+
nk_f32_t *r_row = result + row * result_stride_elements;
|
|
138
|
+
nk_size_t col_start = chunk_start > row + 1 ? chunk_start : row + 1;
|
|
139
|
+
if (col_start >= chunk_end) continue;
|
|
140
|
+
nk_euclideans_row_f32dots_sapphireamx_(r_row + col_start, column_norms_cache + col_start - chunk_start,
|
|
141
|
+
r_row[row], chunk_end - col_start);
|
|
142
|
+
}
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
for (nk_size_t row = row_start; row < row_start + row_count; row++) result[row * result_stride_elements + row] = 0;
|
|
146
|
+
}
|
|
147
|
+
|
|
148
|
+
NK_PUBLIC void nk_euclideans_symmetric_f16_graniteamx( //
|
|
149
|
+
nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, //
|
|
150
|
+
nk_f32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
151
|
+
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f16_t);
|
|
152
|
+
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f32_t);
|
|
153
|
+
nk_dots_symmetric_f16_graniteamx(vectors, vectors_count, depth, stride_in_bytes, result, result_stride_in_bytes,
|
|
154
|
+
row_start, row_count);
|
|
155
|
+
nk_euclideans_symmetric_f16_graniteamx_finalize_(vectors, vectors_count, depth, stride_elements, result,
|
|
156
|
+
result_stride_elements, row_start, row_count);
|
|
157
|
+
}
|
|
158
|
+
|
|
159
|
+
#pragma endregion F16 Symmetric
|
|
160
|
+
|
|
161
|
+
#if defined(__clang__)
|
|
162
|
+
#pragma clang attribute pop
|
|
163
|
+
#elif defined(__GNUC__)
|
|
164
|
+
#pragma GCC pop_options
|
|
165
|
+
#endif
|
|
166
|
+
|
|
167
|
+
#if defined(__cplusplus)
|
|
168
|
+
} // extern "C"
|
|
169
|
+
#endif
|
|
170
|
+
|
|
171
|
+
#endif // NK_TARGET_GRANITEAMX
|
|
172
|
+
#endif // NK_TARGET_X8664_
|
|
173
|
+
#endif // NK_SPATIALS_GRANITEAMX_H
|
|
@@ -15,6 +15,18 @@
|
|
|
15
15
|
extern "C" {
|
|
16
16
|
#endif
|
|
17
17
|
|
|
18
|
+
/* Optimize serial fallbacks for size — see dots/serial.h for rationale. */
|
|
19
|
+
#if defined(NDEBUG)
|
|
20
|
+
#if defined(_MSC_VER)
|
|
21
|
+
#pragma optimize("s", on)
|
|
22
|
+
#elif defined(__clang__)
|
|
23
|
+
#pragma clang attribute push(__attribute__((minsize)), apply_to = function)
|
|
24
|
+
#elif defined(__GNUC__)
|
|
25
|
+
#pragma GCC push_options
|
|
26
|
+
#pragma GCC optimize("Os")
|
|
27
|
+
#endif
|
|
28
|
+
#endif
|
|
29
|
+
|
|
18
30
|
nk_define_cross_normalized_packed_(angular, f64, serial, f64, f64, f64, /*norm_value_type=*/f64, f64, nk_b256_vec_t,
|
|
19
31
|
nk_dots_packed_f64_serial, nk_angular_through_f64_from_dot_serial_,
|
|
20
32
|
nk_dots_reduce_sumsq_f64_, nk_load_b256_serial_, nk_partial_load_b64x4_serial_,
|
|
@@ -219,6 +231,16 @@ nk_define_cross_normalized_symmetric_(euclidean, u4, serial, u4x2, u32, /*norm_v
|
|
|
219
231
|
nk_dots_reduce_sumsq_u4_, nk_load_b128_serial_, nk_partial_load_b32x4_serial_,
|
|
220
232
|
nk_store_b128_serial_, nk_partial_store_b32x4_serial_, 2)
|
|
221
233
|
|
|
234
|
+
#if defined(NDEBUG)
|
|
235
|
+
#if defined(_MSC_VER)
|
|
236
|
+
#pragma optimize("", on)
|
|
237
|
+
#elif defined(__clang__)
|
|
238
|
+
#pragma clang attribute pop
|
|
239
|
+
#elif defined(__GNUC__)
|
|
240
|
+
#pragma GCC pop_options
|
|
241
|
+
#endif
|
|
242
|
+
#endif
|
|
243
|
+
|
|
222
244
|
#if defined(__cplusplus)
|
|
223
245
|
} // extern "C"
|
|
224
246
|
#endif
|