numkong 7.4.5 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +99 -5
- package/c/dispatch_e5m2.c +23 -3
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +1167 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +33 -11
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +41 -3
- package/include/numkong/each/serial.h +39 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +15 -4
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +225 -161
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +301 -0
- package/include/numkong/spatials/serial.h +39 -0
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +54 -4
- package/include/numkong/tensor.hpp +107 -23
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +59 -14
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -39,9 +39,9 @@ extern "C" {
|
|
|
39
39
|
* BMOPA gives matching = popcount(XNOR(a,b)).
|
|
40
40
|
* dot(a,b) = popcount(a AND b) = (pop_a + pop_b - depth_bits + matching) / 2
|
|
41
41
|
*/
|
|
42
|
-
|
|
42
|
+
__arm_new("za") static void nk_dots_packed_u1_smebi32_streaming_( //
|
|
43
43
|
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
44
|
-
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
44
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
|
|
45
45
|
|
|
46
46
|
nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
|
|
47
47
|
nk_size_t const row_tile_count_b = header->row_tile_count;
|
|
@@ -204,20 +204,22 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
|
|
|
204
204
|
}
|
|
205
205
|
}
|
|
206
206
|
|
|
207
|
-
NK_PUBLIC void nk_dots_packed_u1_smebi32(
|
|
208
|
-
|
|
209
|
-
|
|
207
|
+
NK_PUBLIC void nk_dots_packed_u1_smebi32( //
|
|
208
|
+
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
209
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
210
|
+
nk_sme_start_streaming_();
|
|
210
211
|
nk_dots_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
|
|
211
212
|
c_stride_in_bytes);
|
|
213
|
+
nk_sme_stop_streaming_();
|
|
212
214
|
}
|
|
213
215
|
|
|
214
216
|
/**
|
|
215
217
|
* Symmetric u1 dot-product using ZA0 time-sharing + 3-tile fast path.
|
|
216
218
|
* Same ZA transpose pattern as hammings_symmetric, but with dot extraction.
|
|
217
219
|
*/
|
|
218
|
-
|
|
220
|
+
__arm_new("za") static void nk_dots_symmetric_u1_smebi32_streaming_( //
|
|
219
221
|
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
220
|
-
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
222
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
221
223
|
|
|
222
224
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
223
225
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
@@ -451,12 +453,13 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
|
|
|
451
453
|
}
|
|
452
454
|
}
|
|
453
455
|
|
|
454
|
-
NK_PUBLIC void nk_dots_symmetric_u1_smebi32(
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
456
|
+
NK_PUBLIC void nk_dots_symmetric_u1_smebi32( //
|
|
457
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
458
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
459
|
+
nk_sme_start_streaming_();
|
|
458
460
|
nk_dots_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
459
461
|
result_stride_in_bytes, row_start, row_count);
|
|
462
|
+
nk_sme_stop_streaming_();
|
|
460
463
|
}
|
|
461
464
|
|
|
462
465
|
#if defined(__clang__)
|
|
@@ -153,9 +153,9 @@ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_
|
|
|
153
153
|
}
|
|
154
154
|
}
|
|
155
155
|
|
|
156
|
-
|
|
156
|
+
__arm_new("za") static void nk_dots_packed_f32_smef64_streaming_( //
|
|
157
157
|
nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
158
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
158
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
159
159
|
|
|
160
160
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
161
161
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -390,14 +390,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
|
|
|
390
390
|
}
|
|
391
391
|
}
|
|
392
392
|
|
|
393
|
-
NK_PUBLIC void nk_dots_packed_f32_smef64(
|
|
394
|
-
|
|
395
|
-
|
|
393
|
+
NK_PUBLIC void nk_dots_packed_f32_smef64( //
|
|
394
|
+
nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
395
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
396
396
|
|
|
397
397
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
|
|
398
398
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
|
|
399
399
|
|
|
400
|
+
nk_sme_start_streaming_();
|
|
400
401
|
nk_dots_packed_f32_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
402
|
+
nk_sme_stop_streaming_();
|
|
401
403
|
}
|
|
402
404
|
|
|
403
405
|
/**
|
|
@@ -406,9 +408,9 @@ NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed
|
|
|
406
408
|
* pre-reads A columns into Z registers, then reloads ZA0 with widened B data
|
|
407
409
|
* per column tile. Eliminates all scalar B-packing loops.
|
|
408
410
|
*/
|
|
409
|
-
|
|
411
|
+
__arm_new("za") static void nk_dots_symmetric_f32_smef64_streaming_( //
|
|
410
412
|
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
411
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
413
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
412
414
|
|
|
413
415
|
nk_size_t const tile_dimension = svcntd(); // 8 for SVL=512
|
|
414
416
|
nk_size_t const depth_tile_size = svcntw(); // 16 for SVL=512
|
|
@@ -721,15 +723,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
|
|
|
721
723
|
}
|
|
722
724
|
}
|
|
723
725
|
|
|
724
|
-
NK_PUBLIC void nk_dots_symmetric_f32_smef64(
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
nk_size_t row_count) {
|
|
726
|
+
NK_PUBLIC void nk_dots_symmetric_f32_smef64( //
|
|
727
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
|
|
728
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
728
729
|
|
|
729
730
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
730
731
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
732
|
+
nk_sme_start_streaming_();
|
|
731
733
|
nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
732
734
|
result_stride_elements, row_start, row_count);
|
|
735
|
+
nk_sme_stop_streaming_();
|
|
733
736
|
}
|
|
734
737
|
|
|
735
738
|
#pragma endregion F32 Floats
|
|
@@ -783,17 +786,16 @@ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t v
|
|
|
783
786
|
*
|
|
784
787
|
* All slices fit in f32 (24-bit significand). Products: max 19+19 = 38 ≤ 53, exact in f64.
|
|
785
788
|
*/
|
|
786
|
-
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void)
|
|
789
|
+
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void) {
|
|
787
790
|
return 0xFFFFFFFC00000000ULL; // keep top 19 sig bits
|
|
788
791
|
}
|
|
789
|
-
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void)
|
|
792
|
+
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void) {
|
|
790
793
|
return 0xFFFFFFF000000000ULL; // keep top 17 sig bits
|
|
791
794
|
}
|
|
792
795
|
|
|
793
796
|
/* Split a scalar f64 into 3 non-overlapping Ozaki slices (19+17+17 mantissa bits).
|
|
794
797
|
* Each slice fits in f32. Outputs stored via pointers. */
|
|
795
|
-
NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1,
|
|
796
|
-
nk_f64_t *slice_2) NK_STREAMING_ {
|
|
798
|
+
NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1, nk_f64_t *slice_2) {
|
|
797
799
|
nk_fui64_t pun;
|
|
798
800
|
pun.f = val;
|
|
799
801
|
pun.u &= nk_f64_smef64_ozaki_mask_19_bits_();
|
|
@@ -805,9 +807,9 @@ NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, n
|
|
|
805
807
|
*slice_2 = residual - *slice_1;
|
|
806
808
|
}
|
|
807
809
|
|
|
808
|
-
|
|
810
|
+
__arm_new("za") static void nk_dots_symmetric_f64_smef64_streaming_( //
|
|
809
811
|
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
810
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
812
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
811
813
|
|
|
812
814
|
nk_size_t const tile_dimension = svcntd();
|
|
813
815
|
nk_size_t const depth_steps_per_batch = tile_dimension;
|
|
@@ -929,15 +931,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64
|
|
|
929
931
|
}
|
|
930
932
|
}
|
|
931
933
|
|
|
932
|
-
NK_PUBLIC void nk_dots_symmetric_f64_smef64(
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
nk_size_t row_count) {
|
|
934
|
+
NK_PUBLIC void nk_dots_symmetric_f64_smef64( //
|
|
935
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
|
|
936
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
936
937
|
|
|
937
938
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
|
|
938
939
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
940
|
+
nk_sme_start_streaming_();
|
|
939
941
|
nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
940
942
|
result_stride_elements, row_start, row_count);
|
|
943
|
+
nk_sme_stop_streaming_();
|
|
941
944
|
}
|
|
942
945
|
|
|
943
946
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t columns, nk_size_t depth) {
|
|
@@ -1018,9 +1021,9 @@ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_
|
|
|
1018
1021
|
}
|
|
1019
1022
|
}
|
|
1020
1023
|
|
|
1021
|
-
|
|
1024
|
+
__arm_new("za") static void nk_dots_packed_f64_smef64_streaming_( //
|
|
1022
1025
|
nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1023
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1026
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1024
1027
|
|
|
1025
1028
|
// Read header
|
|
1026
1029
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
@@ -1296,14 +1299,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1296
1299
|
}
|
|
1297
1300
|
}
|
|
1298
1301
|
|
|
1299
|
-
NK_PUBLIC void nk_dots_packed_f64_smef64(
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
+
NK_PUBLIC void nk_dots_packed_f64_smef64( //
|
|
1303
|
+
nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1304
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1302
1305
|
|
|
1303
1306
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
|
|
1304
1307
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
|
|
1305
1308
|
|
|
1309
|
+
nk_sme_start_streaming_();
|
|
1306
1310
|
nk_dots_packed_f64_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1311
|
+
nk_sme_stop_streaming_();
|
|
1307
1312
|
}
|
|
1308
1313
|
|
|
1309
1314
|
#pragma endregion F64 Floats
|
package/include/numkong/dots.h
CHANGED
|
@@ -681,6 +681,37 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx(nk_u8_t const *vectors, nk_size_
|
|
|
681
681
|
nk_size_t row_start, nk_size_t row_count);
|
|
682
682
|
#endif // NK_TARGET_SAPPHIREAMX
|
|
683
683
|
|
|
684
|
+
/* Granite Rapids backends using Intel AMX-FP16 (Advanced Matrix Extensions with FP16 support).
|
|
685
|
+
* AMX-FP16 adds TDPFP16PS (FP16×FP16→FP32 tile multiply-accumulate), same tile geometry as BF16.
|
|
686
|
+
* The F32 Ozaki kernel splits F32 inputs into 2 FP16 halves for ~35-40 bit effective precision.
|
|
687
|
+
*/
|
|
688
|
+
#if NK_TARGET_GRANITEAMX
|
|
689
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
690
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_graniteamx(nk_size_t width, nk_size_t depth);
|
|
691
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
692
|
+
NK_PUBLIC void nk_dots_pack_f16_graniteamx(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
693
|
+
void *b_packed);
|
|
694
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
695
|
+
NK_PUBLIC void nk_dots_packed_f16_graniteamx(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
696
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
697
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
698
|
+
NK_PUBLIC void nk_dots_symmetric_f16_graniteamx(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
699
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
700
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
701
|
+
/** @copydoc nk_dots_packed_size_e5m2 */
|
|
702
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_e5m2_graniteamx(nk_size_t width, nk_size_t depth);
|
|
703
|
+
/** @copydoc nk_dots_pack_e5m2 */
|
|
704
|
+
NK_PUBLIC void nk_dots_pack_e5m2_graniteamx(nk_e5m2_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
705
|
+
void *b_packed);
|
|
706
|
+
/** @copydoc nk_dots_packed_e5m2 */
|
|
707
|
+
NK_PUBLIC void nk_dots_packed_e5m2_graniteamx(nk_e5m2_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
708
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
709
|
+
/** @copydoc nk_dots_symmetric_e5m2 */
|
|
710
|
+
NK_PUBLIC void nk_dots_symmetric_e5m2_graniteamx(nk_e5m2_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
711
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
712
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
713
|
+
#endif // NK_TARGET_GRANITEAMX
|
|
714
|
+
|
|
684
715
|
/* ARM SME backends using Scalable Matrix Extension.
|
|
685
716
|
* SME provides ZA tile registers for outer product operations.
|
|
686
717
|
* F16/BF16/I8/U8/E4M3 use ZA32 tiles, F32/F64 use ZA64 tiles (FEAT_SME_F64F64).
|
|
@@ -1858,6 +1889,7 @@ NK_PUBLIC void nk_dots_symmetric_u1_loongsonasx(nk_u1x8_t const *vectors, nk_siz
|
|
|
1858
1889
|
#include "numkong/dots/genoa.h"
|
|
1859
1890
|
#include "numkong/dots/diamond.h"
|
|
1860
1891
|
#include "numkong/dots/sapphireamx.h"
|
|
1892
|
+
#include "numkong/dots/graniteamx.h"
|
|
1861
1893
|
#include "numkong/dots/neon.h"
|
|
1862
1894
|
#include "numkong/dots/neonsdot.h"
|
|
1863
1895
|
#include "numkong/dots/neonfhm.h"
|
|
@@ -2002,7 +2034,9 @@ NK_PUBLIC void nk_dots_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f6
|
|
|
2002
2034
|
}
|
|
2003
2035
|
|
|
2004
2036
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth) {
|
|
2005
|
-
#if
|
|
2037
|
+
#if NK_TARGET_GRANITEAMX
|
|
2038
|
+
return nk_dots_packed_size_f16_graniteamx(width, depth);
|
|
2039
|
+
#elif NK_TARGET_SME
|
|
2006
2040
|
return nk_dots_packed_size_f16_sme(width, depth);
|
|
2007
2041
|
#elif NK_TARGET_NEONFHM
|
|
2008
2042
|
return nk_dots_packed_size_f16_neonfhm(width, depth);
|
|
@@ -2023,7 +2057,9 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth) {
|
|
|
2023
2057
|
|
|
2024
2058
|
NK_PUBLIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
2025
2059
|
void *b_packed) {
|
|
2026
|
-
#if
|
|
2060
|
+
#if NK_TARGET_GRANITEAMX
|
|
2061
|
+
nk_dots_pack_f16_graniteamx(b, width, depth, b_stride, b_packed);
|
|
2062
|
+
#elif NK_TARGET_SME
|
|
2027
2063
|
nk_dots_pack_f16_sme(b, width, depth, b_stride, b_packed);
|
|
2028
2064
|
#elif NK_TARGET_NEONFHM
|
|
2029
2065
|
nk_dots_pack_f16_neonfhm(b, width, depth, b_stride, b_packed);
|
|
@@ -2044,7 +2080,9 @@ NK_PUBLIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t de
|
|
|
2044
2080
|
|
|
2045
2081
|
NK_PUBLIC void nk_dots_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
2046
2082
|
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
|
|
2047
|
-
#if
|
|
2083
|
+
#if NK_TARGET_GRANITEAMX
|
|
2084
|
+
nk_dots_packed_f16_graniteamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
|
|
2085
|
+
#elif NK_TARGET_SME
|
|
2048
2086
|
nk_dots_packed_f16_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
|
|
2049
2087
|
#elif NK_TARGET_NEONFHM
|
|
2050
2088
|
nk_dots_packed_f16_neonfhm(a, b_packed, c, height, width, depth, a_stride, c_stride);
|
|
@@ -76,6 +76,29 @@ extern "C" {
|
|
|
76
76
|
} \
|
|
77
77
|
}
|
|
78
78
|
|
|
79
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
80
|
+
* Without this, -O3 + LTO can vectorize or clone the serial kernels under AVX-512
|
|
81
|
+
* callers in dispatch_*.c, which wastes binary size and breaks the nk_*_serial-as-scalar-oracle
|
|
82
|
+
* contract. See dots/serial.h. */
|
|
83
|
+
#if defined(__clang__)
|
|
84
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
85
|
+
#elif defined(__GNUC__)
|
|
86
|
+
#pragma GCC push_options
|
|
87
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
88
|
+
#endif
|
|
89
|
+
|
|
90
|
+
/* Size bias for release. Gated on NDEBUG so Debug builds keep -O0 for stepping. */
|
|
91
|
+
#if defined(NDEBUG)
|
|
92
|
+
#if defined(_MSC_VER)
|
|
93
|
+
#pragma optimize("s", on)
|
|
94
|
+
#elif defined(__clang__)
|
|
95
|
+
#pragma clang attribute push(__attribute__((minsize)), apply_to = function)
|
|
96
|
+
#elif defined(__GNUC__)
|
|
97
|
+
#pragma GCC push_options
|
|
98
|
+
#pragma GCC optimize("Os")
|
|
99
|
+
#endif
|
|
100
|
+
#endif
|
|
101
|
+
|
|
79
102
|
nk_define_each_sum_(f64, f64, nk_assign_from_to_, nk_assign_from_to_) // nk_each_sum_f64_serial
|
|
80
103
|
nk_define_each_sum_(f32, f32, nk_assign_from_to_, nk_assign_from_to_) // nk_each_sum_f32_serial
|
|
81
104
|
nk_define_each_sum_(f16, f32, nk_f16_to_f32_serial, nk_f32_to_f16_serial) // nk_each_sum_f16_serial
|
|
@@ -253,6 +276,22 @@ NK_PUBLIC void nk_each_fma_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, n
|
|
|
253
276
|
}
|
|
254
277
|
}
|
|
255
278
|
|
|
279
|
+
#if defined(NDEBUG)
|
|
280
|
+
#if defined(_MSC_VER)
|
|
281
|
+
#pragma optimize("", on)
|
|
282
|
+
#elif defined(__clang__)
|
|
283
|
+
#pragma clang attribute pop
|
|
284
|
+
#elif defined(__GNUC__)
|
|
285
|
+
#pragma GCC pop_options
|
|
286
|
+
#endif
|
|
287
|
+
#endif
|
|
288
|
+
|
|
289
|
+
#if defined(__clang__)
|
|
290
|
+
#pragma clang attribute pop
|
|
291
|
+
#elif defined(__GNUC__)
|
|
292
|
+
#pragma GCC pop_options
|
|
293
|
+
#endif
|
|
294
|
+
|
|
256
295
|
#if defined(__cplusplus)
|
|
257
296
|
} // extern "C"
|
|
258
297
|
#endif
|
|
@@ -24,7 +24,7 @@
|
|
|
24
24
|
#if NK_TARGET_HASWELL
|
|
25
25
|
|
|
26
26
|
#include "numkong/types.h"
|
|
27
|
-
#include "numkong/trigonometry/haswell.h" // `nk_sin_f64x4_haswell_`, `nk_cos_f64x4_haswell_`, `nk_atan2_f64x4_haswell_
|
|
27
|
+
#include "numkong/trigonometry/haswell.h" // `nk_sin_f64x4_haswell_`, `nk_cos_f64x4_haswell_`, `nk_atan2_f64x4_haswell_`
|
|
28
28
|
|
|
29
29
|
#if defined(__cplusplus)
|
|
30
30
|
extern "C" {
|
|
@@ -21,7 +21,7 @@
|
|
|
21
21
|
#if NK_TARGET_NEON
|
|
22
22
|
|
|
23
23
|
#include "numkong/types.h"
|
|
24
|
-
#include "numkong/trigonometry/neon.h" // `nk_sin_f64x2_neon_`, `nk_cos_f64x2_neon_`, `nk_atan2_f64x2_neon_
|
|
24
|
+
#include "numkong/trigonometry/neon.h" // `nk_sin_f64x2_neon_`, `nk_cos_f64x2_neon_`, `nk_atan2_f64x2_neon_`
|
|
25
25
|
|
|
26
26
|
#if defined(__cplusplus)
|
|
27
27
|
extern "C" {
|
|
@@ -11,15 +11,20 @@
|
|
|
11
11
|
|
|
12
12
|
#include "numkong/types.h"
|
|
13
13
|
#include "numkong/spatial/serial.h" // `nk_f64_sqrt_serial`, `nk_f32_sqrt_serial`
|
|
14
|
-
#include "numkong/trigonometry/serial.h" // `nk_f64_sin`, `nk_f64_cos`, `nk_f64_atan2
|
|
14
|
+
#include "numkong/trigonometry/serial.h" // `nk_f64_sin`, `nk_f64_cos`, `nk_f64_atan2`
|
|
15
15
|
|
|
16
16
|
#if defined(__cplusplus)
|
|
17
17
|
extern "C" {
|
|
18
18
|
#endif
|
|
19
19
|
|
|
20
|
-
/*
|
|
21
|
-
*
|
|
22
|
-
|
|
20
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
21
|
+
* See dots/serial.h for rationale. */
|
|
22
|
+
#if defined(__clang__)
|
|
23
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
24
|
+
#elif defined(__GNUC__)
|
|
25
|
+
#pragma GCC push_options
|
|
26
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
27
|
+
#endif
|
|
23
28
|
|
|
24
29
|
NK_PUBLIC void nk_haversine_f64_serial( //
|
|
25
30
|
nk_f64_t const *a_lats, nk_f64_t const *a_lons, //
|
|
@@ -302,6 +307,12 @@ NK_PUBLIC void nk_vincenty_f32_serial( //
|
|
|
302
307
|
}
|
|
303
308
|
}
|
|
304
309
|
|
|
310
|
+
#if defined(__clang__)
|
|
311
|
+
#pragma clang attribute pop
|
|
312
|
+
#elif defined(__GNUC__)
|
|
313
|
+
#pragma GCC pop_options
|
|
314
|
+
#endif
|
|
315
|
+
|
|
305
316
|
#if defined(__cplusplus)
|
|
306
317
|
} // extern "C"
|
|
307
318
|
#endif
|
|
@@ -24,7 +24,7 @@
|
|
|
24
24
|
#if NK_TARGET_SKYLAKE
|
|
25
25
|
|
|
26
26
|
#include "numkong/types.h"
|
|
27
|
-
#include "numkong/trigonometry/skylake.h" // `nk_sin_f64x8_skylake_`, `nk_cos_f64x8_skylake_`, `nk_atan2_f64x8_skylake_
|
|
27
|
+
#include "numkong/trigonometry/skylake.h" // `nk_sin_f64x8_skylake_`, `nk_cos_f64x8_skylake_`, `nk_atan2_f64x8_skylake_`
|
|
28
28
|
|
|
29
29
|
#if defined(__cplusplus)
|
|
30
30
|
extern "C" {
|
|
@@ -71,6 +71,15 @@ NK_STATIC_ASSERT(sizeof(nk_maxsim_vector_metadata_t) == 12, nk_maxsim_vector_met
|
|
|
71
71
|
*/
|
|
72
72
|
typedef void (*nk_maxsim_to_f32_t)(void const *source, nk_f32_t *destination);
|
|
73
73
|
|
|
74
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
75
|
+
* See dots/serial.h for rationale. */
|
|
76
|
+
#if defined(__clang__)
|
|
77
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
78
|
+
#elif defined(__GNUC__)
|
|
79
|
+
#pragma GCC push_options
|
|
80
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
81
|
+
#endif
|
|
82
|
+
|
|
74
83
|
/** @brief Identity conversion for f32 sources — just a typed memcpy. */
|
|
75
84
|
NK_INTERNAL void nk_f32_to_f32_(void const *source, nk_f32_t *destination) { *destination = *(nk_f32_t const *)source; }
|
|
76
85
|
|
|
@@ -483,6 +492,12 @@ NK_PUBLIC void nk_maxsim_packed_f16_serial( //
|
|
|
483
492
|
*result = (nk_f32_t)total_angular_distance;
|
|
484
493
|
}
|
|
485
494
|
|
|
495
|
+
#if defined(__clang__)
|
|
496
|
+
#pragma clang attribute pop
|
|
497
|
+
#elif defined(__GNUC__)
|
|
498
|
+
#pragma GCC pop_options
|
|
499
|
+
#endif
|
|
500
|
+
|
|
486
501
|
#if defined(__cplusplus)
|
|
487
502
|
} // extern "C"
|
|
488
503
|
#endif
|
|
@@ -46,7 +46,8 @@
|
|
|
46
46
|
#if NK_TARGET_ARM64_
|
|
47
47
|
#if NK_TARGET_SME
|
|
48
48
|
|
|
49
|
-
#include "numkong/dots/sme.h"
|
|
49
|
+
#include "numkong/dots/sme.h" // `nk_dots_sme_packed_header_t`
|
|
50
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
50
51
|
|
|
51
52
|
#if defined(__cplusplus)
|
|
52
53
|
extern "C" {
|
|
@@ -90,10 +91,9 @@ NK_STATIC_ASSERT(sizeof(nk_maxsim_sme_packed_header_t) == 64, nk_maxsim_sme_pack
|
|
|
90
91
|
*
|
|
91
92
|
* 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
|
|
92
93
|
*/
|
|
93
|
-
|
|
94
|
-
void const *query_packed, void const *document_packed,
|
|
95
|
-
nk_size_t
|
|
96
|
-
nk_size_t depth, nk_f32_t *result) {
|
|
94
|
+
__arm_new("za") static void nk_maxsim_packed_f16_streaming_( //
|
|
95
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
96
|
+
nk_size_t depth, nk_f32_t *result) NK_STREAMING_ {
|
|
97
97
|
|
|
98
98
|
nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
|
|
99
99
|
nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
|
|
@@ -258,18 +258,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
258
258
|
document_inverse_norms_f32x);
|
|
259
259
|
svfloat32_t angular_distance_f32x = svmax_f32_x(
|
|
260
260
|
row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
|
|
261
|
-
total_angular_distance +=
|
|
261
|
+
total_angular_distance += nk_svaddv_f32_(row_predicate_b32x, angular_distance_f32x);
|
|
262
262
|
}
|
|
263
263
|
|
|
264
264
|
*result = total_angular_distance;
|
|
265
265
|
}
|
|
266
266
|
|
|
267
|
-
NK_PUBLIC void nk_maxsim_packed_f16_sme(
|
|
268
|
-
void const *query_packed, void const *document_packed,
|
|
269
|
-
nk_size_t
|
|
270
|
-
nk_f32_t *result) { //
|
|
267
|
+
NK_PUBLIC void nk_maxsim_packed_f16_sme( //
|
|
268
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
269
|
+
nk_size_t depth, nk_f32_t *result) {
|
|
271
270
|
|
|
271
|
+
nk_sme_start_streaming_();
|
|
272
272
|
nk_maxsim_packed_f16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
273
|
+
nk_sme_stop_streaming_();
|
|
273
274
|
}
|
|
274
275
|
|
|
275
276
|
/**
|
|
@@ -281,10 +282,9 @@ NK_PUBLIC void nk_maxsim_packed_f16_sme( //
|
|
|
281
282
|
*
|
|
282
283
|
* 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
|
|
283
284
|
*/
|
|
284
|
-
|
|
285
|
-
void const *query_packed, void const *document_packed,
|
|
286
|
-
nk_size_t
|
|
287
|
-
nk_size_t depth, nk_f32_t *result) {
|
|
285
|
+
__arm_new("za") static void nk_maxsim_packed_bf16_streaming_( //
|
|
286
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
287
|
+
nk_size_t depth, nk_f32_t *result) NK_STREAMING_ {
|
|
288
288
|
|
|
289
289
|
nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
|
|
290
290
|
nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
|
|
@@ -454,18 +454,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
454
454
|
document_inverse_norms_f32x);
|
|
455
455
|
svfloat32_t angular_distance_f32x = svmax_f32_x(
|
|
456
456
|
row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
|
|
457
|
-
total_angular_distance +=
|
|
457
|
+
total_angular_distance += nk_svaddv_f32_(row_predicate_b32x, angular_distance_f32x);
|
|
458
458
|
}
|
|
459
459
|
|
|
460
460
|
*result = total_angular_distance;
|
|
461
461
|
}
|
|
462
462
|
|
|
463
|
-
NK_PUBLIC void nk_maxsim_packed_bf16_sme(
|
|
464
|
-
void const *query_packed, void const *document_packed,
|
|
465
|
-
nk_size_t
|
|
466
|
-
nk_f32_t *result) { //
|
|
463
|
+
NK_PUBLIC void nk_maxsim_packed_bf16_sme( //
|
|
464
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
465
|
+
nk_size_t depth, nk_f32_t *result) {
|
|
467
466
|
|
|
467
|
+
nk_sme_start_streaming_();
|
|
468
468
|
nk_maxsim_packed_bf16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
469
|
+
nk_sme_stop_streaming_();
|
|
469
470
|
}
|
|
470
471
|
|
|
471
472
|
NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t columns, nk_size_t depth) { //
|
|
@@ -649,7 +650,7 @@ NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
|
|
|
649
650
|
svfloat64_t b_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, b_f32x);
|
|
650
651
|
accumulator_odd_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_odd_f64x, a_odd_f64x, b_odd_f64x);
|
|
651
652
|
}
|
|
652
|
-
return
|
|
653
|
+
return nk_svaddv_f64_(svptrue_b64(), accumulator_even_f64x) + nk_svaddv_f64_(svptrue_b64(), accumulator_odd_f64x);
|
|
653
654
|
}
|
|
654
655
|
|
|
655
656
|
/**
|
|
@@ -687,7 +688,7 @@ NK_PUBLIC nk_f64_t nk_maxsim_angular_from_dots_ssve_(
|
|
|
687
688
|
svfloat64_t angular_distance_f64x = svsub_f64_x(predicate_b64x, svdup_f64(1.0), cosine_f64x);
|
|
688
689
|
angular_distance_f64x = svmax_f64_x(predicate_b64x, angular_distance_f64x, svdup_f64(0.0));
|
|
689
690
|
|
|
690
|
-
total_angular_distance_f64 +=
|
|
691
|
+
total_angular_distance_f64 += nk_svaddv_f64_(predicate_b64x, angular_distance_f64x);
|
|
691
692
|
}
|
|
692
693
|
return total_angular_distance_f64;
|
|
693
694
|
}
|
|
@@ -701,10 +702,9 @@ NK_PUBLIC nk_f64_t nk_maxsim_angular_from_dots_ssve_(
|
|
|
701
702
|
* Refinement: tile-wide interleaved f64 dot products for the winning (query, document) pairs.
|
|
702
703
|
* Angular distance: 1 - dot / sqrt(||q||^2 * ||d||^2), accumulated with f64.
|
|
703
704
|
*/
|
|
704
|
-
|
|
705
|
-
void const *query_packed, void const *document_packed,
|
|
706
|
-
nk_size_t
|
|
707
|
-
nk_f64_t *result) {
|
|
705
|
+
__arm_new("za") static void nk_maxsim_packed_f32_streaming_( //
|
|
706
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
707
|
+
nk_size_t depth, nk_f64_t *result) NK_STREAMING_ {
|
|
708
708
|
|
|
709
709
|
nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
|
|
710
710
|
nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
|
|
@@ -937,10 +937,10 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
937
937
|
|
|
938
938
|
// Reduce SVE accumulators to scalars and compute angular distances
|
|
939
939
|
nk_f64_t dot_products_f64[4];
|
|
940
|
-
dot_products_f64[0] =
|
|
941
|
-
dot_products_f64[1] =
|
|
942
|
-
dot_products_f64[2] =
|
|
943
|
-
dot_products_f64[3] =
|
|
940
|
+
dot_products_f64[0] = nk_svaddv_f64_(svptrue_b64(), accumulator_0_f64x);
|
|
941
|
+
dot_products_f64[1] = nk_svaddv_f64_(svptrue_b64(), accumulator_1_f64x);
|
|
942
|
+
dot_products_f64[2] = nk_svaddv_f64_(svptrue_b64(), accumulator_2_f64x);
|
|
943
|
+
dot_products_f64[3] = nk_svaddv_f64_(svptrue_b64(), accumulator_3_f64x);
|
|
944
944
|
nk_f64_t batch_query_norms_f64[4], batch_document_norms_f64[4];
|
|
945
945
|
for (nk_size_t i = 0; i < 4; i++) {
|
|
946
946
|
batch_query_norms_f64[i] = (nk_f64_t)query_norms[row_start + row_batch_start + i];
|
|
@@ -969,12 +969,13 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
969
969
|
*result = total_angular_distance_f64;
|
|
970
970
|
}
|
|
971
971
|
|
|
972
|
-
NK_PUBLIC void nk_maxsim_packed_f32_sme(
|
|
973
|
-
void const *query_packed, void const *document_packed,
|
|
974
|
-
nk_size_t
|
|
975
|
-
nk_f64_t *result) { //
|
|
972
|
+
NK_PUBLIC void nk_maxsim_packed_f32_sme( //
|
|
973
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
974
|
+
nk_size_t depth, nk_f64_t *result) {
|
|
976
975
|
|
|
976
|
+
nk_sme_start_streaming_();
|
|
977
977
|
nk_maxsim_packed_f32_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
978
|
+
nk_sme_stop_streaming_();
|
|
978
979
|
}
|
|
979
980
|
|
|
980
981
|
#if defined(__clang__)
|