numkong 7.4.5 → 7.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -0
- package/binding.gyp +81 -5
- package/c/dispatch_f16.c +23 -0
- package/c/numkong.c +0 -13
- package/include/numkong/attention/sme.h +34 -31
- package/include/numkong/capabilities.h +2 -15
- package/include/numkong/cast/neon.h +15 -0
- package/include/numkong/curved/smef64.h +82 -62
- package/include/numkong/dot/rvvbf16.h +1 -1
- package/include/numkong/dot/rvvhalf.h +1 -1
- package/include/numkong/dot/sve.h +6 -5
- package/include/numkong/dot/svebfdot.h +2 -1
- package/include/numkong/dot/svehalf.h +6 -5
- package/include/numkong/dot/svesdot.h +3 -2
- package/include/numkong/dots/graniteamx.h +733 -0
- package/include/numkong/dots/serial.h +11 -4
- package/include/numkong/dots/sme.h +172 -140
- package/include/numkong/dots/smebi32.h +14 -11
- package/include/numkong/dots/smef64.h +31 -26
- package/include/numkong/dots.h +29 -3
- package/include/numkong/each/serial.h +22 -0
- package/include/numkong/geospatial/haswell.h +1 -1
- package/include/numkong/geospatial/neon.h +1 -1
- package/include/numkong/geospatial/serial.h +1 -1
- package/include/numkong/geospatial/skylake.h +1 -1
- package/include/numkong/maxsim/sme.h +34 -33
- package/include/numkong/mesh/serial.h +22 -0
- package/include/numkong/reduce/neon.h +29 -0
- package/include/numkong/reduce/neonbfdot.h +2 -2
- package/include/numkong/reduce/neonfhm.h +4 -4
- package/include/numkong/reduce/sve.h +52 -0
- package/include/numkong/reduce.h +4 -0
- package/include/numkong/set/sve.h +6 -5
- package/include/numkong/sets/smebi32.h +35 -30
- package/include/numkong/sparse/sve2.h +3 -2
- package/include/numkong/spatial/sve.h +7 -6
- package/include/numkong/spatial/svebfdot.h +7 -4
- package/include/numkong/spatial/svehalf.h +5 -4
- package/include/numkong/spatial/svesdot.h +9 -8
- package/include/numkong/spatials/graniteamx.h +173 -0
- package/include/numkong/spatials/serial.h +22 -0
- package/include/numkong/spatials/sme.h +391 -350
- package/include/numkong/spatials/smef64.h +79 -70
- package/include/numkong/spatials.h +37 -4
- package/include/numkong/types.h +59 -0
- package/javascript/dist/cjs/numkong.js +13 -0
- package/javascript/dist/esm/numkong.js +13 -0
- package/javascript/numkong.c +56 -12
- package/javascript/numkong.ts +13 -0
- package/package.json +7 -7
- package/probes/probe.js +2 -2
- package/wasm/numkong.wasm +0 -0
|
@@ -39,9 +39,9 @@ extern "C" {
|
|
|
39
39
|
* BMOPA gives matching = popcount(XNOR(a,b)).
|
|
40
40
|
* dot(a,b) = popcount(a AND b) = (pop_a + pop_b - depth_bits + matching) / 2
|
|
41
41
|
*/
|
|
42
|
-
|
|
42
|
+
__arm_new("za") static void nk_dots_packed_u1_smebi32_streaming_( //
|
|
43
43
|
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
44
|
-
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
44
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) NK_STREAMING_ {
|
|
45
45
|
|
|
46
46
|
nk_sets_smebi32_packed_header_t const *header = (nk_sets_smebi32_packed_header_t const *)b_packed;
|
|
47
47
|
nk_size_t const row_tile_count_b = header->row_tile_count;
|
|
@@ -204,20 +204,22 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_u1_smebi32_st
|
|
|
204
204
|
}
|
|
205
205
|
}
|
|
206
206
|
|
|
207
|
-
NK_PUBLIC void nk_dots_packed_u1_smebi32(
|
|
208
|
-
|
|
209
|
-
|
|
207
|
+
NK_PUBLIC void nk_dots_packed_u1_smebi32( //
|
|
208
|
+
nk_u1x8_t const *a, void const *b_packed, nk_u32_t *c, nk_size_t row_count_a, nk_size_t row_count_b,
|
|
209
|
+
nk_size_t depth_bits, nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
210
|
+
nk_sme_start_streaming_();
|
|
210
211
|
nk_dots_packed_u1_smebi32_streaming_(a, b_packed, c, row_count_a, row_count_b, depth_bits, a_stride_in_bytes,
|
|
211
212
|
c_stride_in_bytes);
|
|
213
|
+
nk_sme_stop_streaming_();
|
|
212
214
|
}
|
|
213
215
|
|
|
214
216
|
/**
|
|
215
217
|
* Symmetric u1 dot-product using ZA0 time-sharing + 3-tile fast path.
|
|
216
218
|
* Same ZA transpose pattern as hammings_symmetric, but with dot extraction.
|
|
217
219
|
*/
|
|
218
|
-
|
|
220
|
+
__arm_new("za") static void nk_dots_symmetric_u1_smebi32_streaming_( //
|
|
219
221
|
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
220
|
-
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
222
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
221
223
|
|
|
222
224
|
nk_size_t const tile_dim = svcntw(); // 16 for 512-bit SVL
|
|
223
225
|
nk_size_t const depth_tile_size = svcntw(); // 16 u32 per depth tile
|
|
@@ -451,12 +453,13 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_u1_smebi32
|
|
|
451
453
|
}
|
|
452
454
|
}
|
|
453
455
|
|
|
454
|
-
NK_PUBLIC void nk_dots_symmetric_u1_smebi32(
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
456
|
+
NK_PUBLIC void nk_dots_symmetric_u1_smebi32( //
|
|
457
|
+
nk_u1x8_t const *vectors, nk_size_t vectors_count, nk_size_t depth_bits, nk_size_t stride_in_bytes,
|
|
458
|
+
nk_u32_t *result, nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
459
|
+
nk_sme_start_streaming_();
|
|
458
460
|
nk_dots_symmetric_u1_smebi32_streaming_(vectors, vectors_count, depth_bits, stride_in_bytes, result,
|
|
459
461
|
result_stride_in_bytes, row_start, row_count);
|
|
462
|
+
nk_sme_stop_streaming_();
|
|
460
463
|
}
|
|
461
464
|
|
|
462
465
|
#if defined(__clang__)
|
|
@@ -153,9 +153,9 @@ NK_PUBLIC void nk_dots_pack_f32_smef64(nk_f32_t const *b, nk_size_t columns, nk_
|
|
|
153
153
|
}
|
|
154
154
|
}
|
|
155
155
|
|
|
156
|
-
|
|
156
|
+
__arm_new("za") static void nk_dots_packed_f32_smef64_streaming_( //
|
|
157
157
|
nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
158
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
158
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
159
159
|
|
|
160
160
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
161
161
|
nk_size_t const column_tile_count = header->column_tile_count;
|
|
@@ -390,14 +390,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f32_smef64_st
|
|
|
390
390
|
}
|
|
391
391
|
}
|
|
392
392
|
|
|
393
|
-
NK_PUBLIC void nk_dots_packed_f32_smef64(
|
|
394
|
-
|
|
395
|
-
|
|
393
|
+
NK_PUBLIC void nk_dots_packed_f32_smef64( //
|
|
394
|
+
nk_f32_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
395
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
396
396
|
|
|
397
397
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f32_t);
|
|
398
398
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
|
|
399
399
|
|
|
400
|
+
nk_sme_start_streaming_();
|
|
400
401
|
nk_dots_packed_f32_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
402
|
+
nk_sme_stop_streaming_();
|
|
401
403
|
}
|
|
402
404
|
|
|
403
405
|
/**
|
|
@@ -406,9 +408,9 @@ NK_PUBLIC void nk_dots_packed_f32_smef64(nk_f32_t const *a, void const *b_packed
|
|
|
406
408
|
* pre-reads A columns into Z registers, then reloads ZA0 with widened B data
|
|
407
409
|
* per column tile. Eliminates all scalar B-packing loops.
|
|
408
410
|
*/
|
|
409
|
-
|
|
411
|
+
__arm_new("za") static void nk_dots_symmetric_f32_smef64_streaming_( //
|
|
410
412
|
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
411
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
413
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
412
414
|
|
|
413
415
|
nk_size_t const tile_dimension = svcntd(); // 8 for SVL=512
|
|
414
416
|
nk_size_t const depth_tile_size = svcntw(); // 16 for SVL=512
|
|
@@ -721,15 +723,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f32_smef64
|
|
|
721
723
|
}
|
|
722
724
|
}
|
|
723
725
|
|
|
724
|
-
NK_PUBLIC void nk_dots_symmetric_f32_smef64(
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
nk_size_t row_count) {
|
|
726
|
+
NK_PUBLIC void nk_dots_symmetric_f32_smef64( //
|
|
727
|
+
nk_f32_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
|
|
728
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
728
729
|
|
|
729
730
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f32_t);
|
|
730
731
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
732
|
+
nk_sme_start_streaming_();
|
|
731
733
|
nk_dots_symmetric_f32_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
732
734
|
result_stride_elements, row_start, row_count);
|
|
735
|
+
nk_sme_stop_streaming_();
|
|
733
736
|
}
|
|
734
737
|
|
|
735
738
|
#pragma endregion F32 Floats
|
|
@@ -783,17 +786,16 @@ NK_PUBLIC void nk_dots_symmetric_f32_smef64(nk_f32_t const *vectors, nk_size_t v
|
|
|
783
786
|
*
|
|
784
787
|
* All slices fit in f32 (24-bit significand). Products: max 19+19 = 38 ≤ 53, exact in f64.
|
|
785
788
|
*/
|
|
786
|
-
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void)
|
|
789
|
+
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_19_bits_(void) {
|
|
787
790
|
return 0xFFFFFFFC00000000ULL; // keep top 19 sig bits
|
|
788
791
|
}
|
|
789
|
-
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void)
|
|
792
|
+
NK_PUBLIC nk_u64_t nk_f64_smef64_ozaki_mask_17_bits_(void) {
|
|
790
793
|
return 0xFFFFFFF000000000ULL; // keep top 17 sig bits
|
|
791
794
|
}
|
|
792
795
|
|
|
793
796
|
/* Split a scalar f64 into 3 non-overlapping Ozaki slices (19+17+17 mantissa bits).
|
|
794
797
|
* Each slice fits in f32. Outputs stored via pointers. */
|
|
795
|
-
NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1,
|
|
796
|
-
nk_f64_t *slice_2) NK_STREAMING_ {
|
|
798
|
+
NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, nk_f64_t *slice_1, nk_f64_t *slice_2) {
|
|
797
799
|
nk_fui64_t pun;
|
|
798
800
|
pun.f = val;
|
|
799
801
|
pun.u &= nk_f64_smef64_ozaki_mask_19_bits_();
|
|
@@ -805,9 +807,9 @@ NK_PUBLIC void nk_f64_smef64_ozaki_split_f64_(nk_f64_t val, nk_f64_t *slice_0, n
|
|
|
805
807
|
*slice_2 = residual - *slice_1;
|
|
806
808
|
}
|
|
807
809
|
|
|
808
|
-
|
|
810
|
+
__arm_new("za") static void nk_dots_symmetric_f64_smef64_streaming_( //
|
|
809
811
|
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_elements, nk_f64_t *result,
|
|
810
|
-
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) {
|
|
812
|
+
nk_size_t result_stride_elements, nk_size_t row_start, nk_size_t row_count) NK_STREAMING_ {
|
|
811
813
|
|
|
812
814
|
nk_size_t const tile_dimension = svcntd();
|
|
813
815
|
nk_size_t const depth_steps_per_batch = tile_dimension;
|
|
@@ -929,15 +931,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_symmetric_f64_smef64
|
|
|
929
931
|
}
|
|
930
932
|
}
|
|
931
933
|
|
|
932
|
-
NK_PUBLIC void nk_dots_symmetric_f64_smef64(
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
nk_size_t row_count) {
|
|
934
|
+
NK_PUBLIC void nk_dots_symmetric_f64_smef64( //
|
|
935
|
+
nk_f64_t const *vectors, nk_size_t vectors_count, nk_size_t depth, nk_size_t stride_in_bytes, nk_f64_t *result,
|
|
936
|
+
nk_size_t result_stride_in_bytes, nk_size_t row_start, nk_size_t row_count) {
|
|
936
937
|
|
|
937
938
|
nk_size_t const stride_elements = stride_in_bytes / sizeof(nk_f64_t);
|
|
938
939
|
nk_size_t const result_stride_elements = result_stride_in_bytes / sizeof(nk_f64_t);
|
|
940
|
+
nk_sme_start_streaming_();
|
|
939
941
|
nk_dots_symmetric_f64_smef64_streaming_(vectors, vectors_count, depth, stride_elements, result,
|
|
940
942
|
result_stride_elements, row_start, row_count);
|
|
943
|
+
nk_sme_stop_streaming_();
|
|
941
944
|
}
|
|
942
945
|
|
|
943
946
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f64_smef64(nk_size_t columns, nk_size_t depth) {
|
|
@@ -1018,9 +1021,9 @@ NK_PUBLIC void nk_dots_pack_f64_smef64(nk_f64_t const *b, nk_size_t columns, nk_
|
|
|
1018
1021
|
}
|
|
1019
1022
|
}
|
|
1020
1023
|
|
|
1021
|
-
|
|
1024
|
+
__arm_new("za") static void nk_dots_packed_f64_smef64_streaming_( //
|
|
1022
1025
|
nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1023
|
-
nk_size_t a_stride_elements, nk_size_t c_stride_elements) {
|
|
1026
|
+
nk_size_t a_stride_elements, nk_size_t c_stride_elements) NK_STREAMING_ {
|
|
1024
1027
|
|
|
1025
1028
|
// Read header
|
|
1026
1029
|
nk_dots_sme_packed_header_t const *header = (nk_dots_sme_packed_header_t const *)b_packed;
|
|
@@ -1296,14 +1299,16 @@ __arm_locally_streaming __arm_new("za") static void nk_dots_packed_f64_smef64_st
|
|
|
1296
1299
|
}
|
|
1297
1300
|
}
|
|
1298
1301
|
|
|
1299
|
-
NK_PUBLIC void nk_dots_packed_f64_smef64(
|
|
1300
|
-
|
|
1301
|
-
|
|
1302
|
+
NK_PUBLIC void nk_dots_packed_f64_smef64( //
|
|
1303
|
+
nk_f64_t const *a, void const *b_packed, nk_f64_t *c, nk_size_t rows, nk_size_t columns, nk_size_t depth,
|
|
1304
|
+
nk_size_t a_stride_in_bytes, nk_size_t c_stride_in_bytes) {
|
|
1302
1305
|
|
|
1303
1306
|
nk_size_t const a_stride_elements = a_stride_in_bytes / sizeof(nk_f64_t);
|
|
1304
1307
|
nk_size_t const c_stride_elements = c_stride_in_bytes / sizeof(nk_f64_t);
|
|
1305
1308
|
|
|
1309
|
+
nk_sme_start_streaming_();
|
|
1306
1310
|
nk_dots_packed_f64_smef64_streaming_(a, b_packed, c, rows, columns, depth, a_stride_elements, c_stride_elements);
|
|
1311
|
+
nk_sme_stop_streaming_();
|
|
1307
1312
|
}
|
|
1308
1313
|
|
|
1309
1314
|
#pragma endregion F64 Floats
|
package/include/numkong/dots.h
CHANGED
|
@@ -681,6 +681,25 @@ NK_PUBLIC void nk_dots_symmetric_u8_sapphireamx(nk_u8_t const *vectors, nk_size_
|
|
|
681
681
|
nk_size_t row_start, nk_size_t row_count);
|
|
682
682
|
#endif // NK_TARGET_SAPPHIREAMX
|
|
683
683
|
|
|
684
|
+
/* Granite Rapids backends using Intel AMX-FP16 (Advanced Matrix Extensions with FP16 support).
|
|
685
|
+
* AMX-FP16 adds TDPFP16PS (FP16×FP16→FP32 tile multiply-accumulate), same tile geometry as BF16.
|
|
686
|
+
* The F32 Ozaki kernel splits F32 inputs into 2 FP16 halves for ~35-40 bit effective precision.
|
|
687
|
+
*/
|
|
688
|
+
#if NK_TARGET_GRANITEAMX
|
|
689
|
+
/** @copydoc nk_dots_packed_size_f16 */
|
|
690
|
+
NK_PUBLIC nk_size_t nk_dots_packed_size_f16_graniteamx(nk_size_t width, nk_size_t depth);
|
|
691
|
+
/** @copydoc nk_dots_pack_f16 */
|
|
692
|
+
NK_PUBLIC void nk_dots_pack_f16_graniteamx(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
693
|
+
void *b_packed);
|
|
694
|
+
/** @copydoc nk_dots_packed_f16 */
|
|
695
|
+
NK_PUBLIC void nk_dots_packed_f16_graniteamx(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
696
|
+
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride);
|
|
697
|
+
/** @copydoc nk_dots_symmetric_f16 */
|
|
698
|
+
NK_PUBLIC void nk_dots_symmetric_f16_graniteamx(nk_f16_t const *vectors, nk_size_t vectors_count, nk_size_t depth,
|
|
699
|
+
nk_size_t stride, nk_f32_t *result, nk_size_t result_stride,
|
|
700
|
+
nk_size_t row_start, nk_size_t row_count);
|
|
701
|
+
#endif // NK_TARGET_GRANITEAMX
|
|
702
|
+
|
|
684
703
|
/* ARM SME backends using Scalable Matrix Extension.
|
|
685
704
|
* SME provides ZA tile registers for outer product operations.
|
|
686
705
|
* F16/BF16/I8/U8/E4M3 use ZA32 tiles, F32/F64 use ZA64 tiles (FEAT_SME_F64F64).
|
|
@@ -1858,6 +1877,7 @@ NK_PUBLIC void nk_dots_symmetric_u1_loongsonasx(nk_u1x8_t const *vectors, nk_siz
|
|
|
1858
1877
|
#include "numkong/dots/genoa.h"
|
|
1859
1878
|
#include "numkong/dots/diamond.h"
|
|
1860
1879
|
#include "numkong/dots/sapphireamx.h"
|
|
1880
|
+
#include "numkong/dots/graniteamx.h"
|
|
1861
1881
|
#include "numkong/dots/neon.h"
|
|
1862
1882
|
#include "numkong/dots/neonsdot.h"
|
|
1863
1883
|
#include "numkong/dots/neonfhm.h"
|
|
@@ -2002,7 +2022,9 @@ NK_PUBLIC void nk_dots_packed_f64(nk_f64_t const *a, void const *b_packed, nk_f6
|
|
|
2002
2022
|
}
|
|
2003
2023
|
|
|
2004
2024
|
NK_PUBLIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth) {
|
|
2005
|
-
#if
|
|
2025
|
+
#if NK_TARGET_GRANITEAMX
|
|
2026
|
+
return nk_dots_packed_size_f16_graniteamx(width, depth);
|
|
2027
|
+
#elif NK_TARGET_SME
|
|
2006
2028
|
return nk_dots_packed_size_f16_sme(width, depth);
|
|
2007
2029
|
#elif NK_TARGET_NEONFHM
|
|
2008
2030
|
return nk_dots_packed_size_f16_neonfhm(width, depth);
|
|
@@ -2023,7 +2045,9 @@ NK_PUBLIC nk_size_t nk_dots_packed_size_f16(nk_size_t width, nk_size_t depth) {
|
|
|
2023
2045
|
|
|
2024
2046
|
NK_PUBLIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t depth, nk_size_t b_stride,
|
|
2025
2047
|
void *b_packed) {
|
|
2026
|
-
#if
|
|
2048
|
+
#if NK_TARGET_GRANITEAMX
|
|
2049
|
+
nk_dots_pack_f16_graniteamx(b, width, depth, b_stride, b_packed);
|
|
2050
|
+
#elif NK_TARGET_SME
|
|
2027
2051
|
nk_dots_pack_f16_sme(b, width, depth, b_stride, b_packed);
|
|
2028
2052
|
#elif NK_TARGET_NEONFHM
|
|
2029
2053
|
nk_dots_pack_f16_neonfhm(b, width, depth, b_stride, b_packed);
|
|
@@ -2044,7 +2068,9 @@ NK_PUBLIC void nk_dots_pack_f16(nk_f16_t const *b, nk_size_t width, nk_size_t de
|
|
|
2044
2068
|
|
|
2045
2069
|
NK_PUBLIC void nk_dots_packed_f16(nk_f16_t const *a, void const *b_packed, nk_f32_t *c, nk_size_t height,
|
|
2046
2070
|
nk_size_t width, nk_size_t depth, nk_size_t a_stride, nk_size_t c_stride) {
|
|
2047
|
-
#if
|
|
2071
|
+
#if NK_TARGET_GRANITEAMX
|
|
2072
|
+
nk_dots_packed_f16_graniteamx(a, b_packed, c, height, width, depth, a_stride, c_stride);
|
|
2073
|
+
#elif NK_TARGET_SME
|
|
2048
2074
|
nk_dots_packed_f16_sme(a, b_packed, c, height, width, depth, a_stride, c_stride);
|
|
2049
2075
|
#elif NK_TARGET_NEONFHM
|
|
2050
2076
|
nk_dots_packed_f16_neonfhm(a, b_packed, c, height, width, depth, a_stride, c_stride);
|
|
@@ -76,6 +76,18 @@ extern "C" {
|
|
|
76
76
|
} \
|
|
77
77
|
}
|
|
78
78
|
|
|
79
|
+
/* Optimize serial fallbacks for size — see dots/serial.h for rationale. */
|
|
80
|
+
#if defined(NDEBUG)
|
|
81
|
+
#if defined(_MSC_VER)
|
|
82
|
+
#pragma optimize("s", on)
|
|
83
|
+
#elif defined(__clang__)
|
|
84
|
+
#pragma clang attribute push(__attribute__((minsize)), apply_to = function)
|
|
85
|
+
#elif defined(__GNUC__)
|
|
86
|
+
#pragma GCC push_options
|
|
87
|
+
#pragma GCC optimize("Os")
|
|
88
|
+
#endif
|
|
89
|
+
#endif
|
|
90
|
+
|
|
79
91
|
nk_define_each_sum_(f64, f64, nk_assign_from_to_, nk_assign_from_to_) // nk_each_sum_f64_serial
|
|
80
92
|
nk_define_each_sum_(f32, f32, nk_assign_from_to_, nk_assign_from_to_) // nk_each_sum_f32_serial
|
|
81
93
|
nk_define_each_sum_(f16, f32, nk_f16_to_f32_serial, nk_f32_to_f16_serial) // nk_each_sum_f16_serial
|
|
@@ -253,6 +265,16 @@ NK_PUBLIC void nk_each_fma_f64c_serial(nk_f64c_t const *a, nk_f64c_t const *b, n
|
|
|
253
265
|
}
|
|
254
266
|
}
|
|
255
267
|
|
|
268
|
+
#if defined(NDEBUG)
|
|
269
|
+
#if defined(_MSC_VER)
|
|
270
|
+
#pragma optimize("", on)
|
|
271
|
+
#elif defined(__clang__)
|
|
272
|
+
#pragma clang attribute pop
|
|
273
|
+
#elif defined(__GNUC__)
|
|
274
|
+
#pragma GCC pop_options
|
|
275
|
+
#endif
|
|
276
|
+
#endif
|
|
277
|
+
|
|
256
278
|
#if defined(__cplusplus)
|
|
257
279
|
} // extern "C"
|
|
258
280
|
#endif
|
|
@@ -24,7 +24,7 @@
|
|
|
24
24
|
#if NK_TARGET_HASWELL
|
|
25
25
|
|
|
26
26
|
#include "numkong/types.h"
|
|
27
|
-
#include "numkong/trigonometry/haswell.h" // `nk_sin_f64x4_haswell_`, `nk_cos_f64x4_haswell_`, `nk_atan2_f64x4_haswell_
|
|
27
|
+
#include "numkong/trigonometry/haswell.h" // `nk_sin_f64x4_haswell_`, `nk_cos_f64x4_haswell_`, `nk_atan2_f64x4_haswell_`
|
|
28
28
|
|
|
29
29
|
#if defined(__cplusplus)
|
|
30
30
|
extern "C" {
|
|
@@ -21,7 +21,7 @@
|
|
|
21
21
|
#if NK_TARGET_NEON
|
|
22
22
|
|
|
23
23
|
#include "numkong/types.h"
|
|
24
|
-
#include "numkong/trigonometry/neon.h" // `nk_sin_f64x2_neon_`, `nk_cos_f64x2_neon_`, `nk_atan2_f64x2_neon_
|
|
24
|
+
#include "numkong/trigonometry/neon.h" // `nk_sin_f64x2_neon_`, `nk_cos_f64x2_neon_`, `nk_atan2_f64x2_neon_`
|
|
25
25
|
|
|
26
26
|
#if defined(__cplusplus)
|
|
27
27
|
extern "C" {
|
|
@@ -11,7 +11,7 @@
|
|
|
11
11
|
|
|
12
12
|
#include "numkong/types.h"
|
|
13
13
|
#include "numkong/spatial/serial.h" // `nk_f64_sqrt_serial`, `nk_f32_sqrt_serial`
|
|
14
|
-
#include "numkong/trigonometry/serial.h" // `nk_f64_sin`, `nk_f64_cos`, `nk_f64_atan2
|
|
14
|
+
#include "numkong/trigonometry/serial.h" // `nk_f64_sin`, `nk_f64_cos`, `nk_f64_atan2`
|
|
15
15
|
|
|
16
16
|
#if defined(__cplusplus)
|
|
17
17
|
extern "C" {
|
|
@@ -24,7 +24,7 @@
|
|
|
24
24
|
#if NK_TARGET_SKYLAKE
|
|
25
25
|
|
|
26
26
|
#include "numkong/types.h"
|
|
27
|
-
#include "numkong/trigonometry/skylake.h" // `nk_sin_f64x8_skylake_`, `nk_cos_f64x8_skylake_`, `nk_atan2_f64x8_skylake_
|
|
27
|
+
#include "numkong/trigonometry/skylake.h" // `nk_sin_f64x8_skylake_`, `nk_cos_f64x8_skylake_`, `nk_atan2_f64x8_skylake_`
|
|
28
28
|
|
|
29
29
|
#if defined(__cplusplus)
|
|
30
30
|
extern "C" {
|
|
@@ -46,7 +46,8 @@
|
|
|
46
46
|
#if NK_TARGET_ARM64_
|
|
47
47
|
#if NK_TARGET_SME
|
|
48
48
|
|
|
49
|
-
#include "numkong/dots/sme.h"
|
|
49
|
+
#include "numkong/dots/sme.h" // `nk_dots_sme_packed_header_t`
|
|
50
|
+
#include "numkong/reduce/sve.h" // `nk_svaddv_f64_`
|
|
50
51
|
|
|
51
52
|
#if defined(__cplusplus)
|
|
52
53
|
extern "C" {
|
|
@@ -90,10 +91,9 @@ NK_STATIC_ASSERT(sizeof(nk_maxsim_sme_packed_header_t) == 64, nk_maxsim_sme_pack
|
|
|
90
91
|
*
|
|
91
92
|
* 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
|
|
92
93
|
*/
|
|
93
|
-
|
|
94
|
-
void const *query_packed, void const *document_packed,
|
|
95
|
-
nk_size_t
|
|
96
|
-
nk_size_t depth, nk_f32_t *result) {
|
|
94
|
+
__arm_new("za") static void nk_maxsim_packed_f16_streaming_( //
|
|
95
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
96
|
+
nk_size_t depth, nk_f32_t *result) NK_STREAMING_ {
|
|
97
97
|
|
|
98
98
|
nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
|
|
99
99
|
nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
|
|
@@ -258,18 +258,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f16_streami
|
|
|
258
258
|
document_inverse_norms_f32x);
|
|
259
259
|
svfloat32_t angular_distance_f32x = svmax_f32_x(
|
|
260
260
|
row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
|
|
261
|
-
total_angular_distance +=
|
|
261
|
+
total_angular_distance += nk_svaddv_f32_(row_predicate_b32x, angular_distance_f32x);
|
|
262
262
|
}
|
|
263
263
|
|
|
264
264
|
*result = total_angular_distance;
|
|
265
265
|
}
|
|
266
266
|
|
|
267
|
-
NK_PUBLIC void nk_maxsim_packed_f16_sme(
|
|
268
|
-
void const *query_packed, void const *document_packed,
|
|
269
|
-
nk_size_t
|
|
270
|
-
nk_f32_t *result) { //
|
|
267
|
+
NK_PUBLIC void nk_maxsim_packed_f16_sme( //
|
|
268
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
269
|
+
nk_size_t depth, nk_f32_t *result) {
|
|
271
270
|
|
|
271
|
+
nk_sme_start_streaming_();
|
|
272
272
|
nk_maxsim_packed_f16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
273
|
+
nk_sme_stop_streaming_();
|
|
273
274
|
}
|
|
274
275
|
|
|
275
276
|
/**
|
|
@@ -281,10 +282,9 @@ NK_PUBLIC void nk_maxsim_packed_f16_sme( //
|
|
|
281
282
|
*
|
|
282
283
|
* 1-tile remainder: uses ZA0 only, with predicated loads for partial tiles.
|
|
283
284
|
*/
|
|
284
|
-
|
|
285
|
-
void const *query_packed, void const *document_packed,
|
|
286
|
-
nk_size_t
|
|
287
|
-
nk_size_t depth, nk_f32_t *result) {
|
|
285
|
+
__arm_new("za") static void nk_maxsim_packed_bf16_streaming_( //
|
|
286
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
287
|
+
nk_size_t depth, nk_f32_t *result) NK_STREAMING_ {
|
|
288
288
|
|
|
289
289
|
nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
|
|
290
290
|
nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
|
|
@@ -454,18 +454,19 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_bf16_stream
|
|
|
454
454
|
document_inverse_norms_f32x);
|
|
455
455
|
svfloat32_t angular_distance_f32x = svmax_f32_x(
|
|
456
456
|
row_predicate_b32x, svsub_f32_x(row_predicate_b32x, svdup_f32(1.0f), cosine_f32x), svdup_f32(0.0f));
|
|
457
|
-
total_angular_distance +=
|
|
457
|
+
total_angular_distance += nk_svaddv_f32_(row_predicate_b32x, angular_distance_f32x);
|
|
458
458
|
}
|
|
459
459
|
|
|
460
460
|
*result = total_angular_distance;
|
|
461
461
|
}
|
|
462
462
|
|
|
463
|
-
NK_PUBLIC void nk_maxsim_packed_bf16_sme(
|
|
464
|
-
void const *query_packed, void const *document_packed,
|
|
465
|
-
nk_size_t
|
|
466
|
-
nk_f32_t *result) { //
|
|
463
|
+
NK_PUBLIC void nk_maxsim_packed_bf16_sme( //
|
|
464
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
465
|
+
nk_size_t depth, nk_f32_t *result) {
|
|
467
466
|
|
|
467
|
+
nk_sme_start_streaming_();
|
|
468
468
|
nk_maxsim_packed_bf16_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
469
|
+
nk_sme_stop_streaming_();
|
|
469
470
|
}
|
|
470
471
|
|
|
471
472
|
NK_PUBLIC nk_size_t nk_maxsim_packed_size_bf16_sme(nk_size_t columns, nk_size_t depth) { //
|
|
@@ -649,7 +650,7 @@ NK_PUBLIC nk_f64_t nk_maxsim_reduce_dot_f32_ssve_( //
|
|
|
649
650
|
svfloat64_t b_odd_f64x = svcvtlt_f64_f32_x(predicate_odd_b64x, b_f32x);
|
|
650
651
|
accumulator_odd_f64x = svmla_f64_m(predicate_odd_b64x, accumulator_odd_f64x, a_odd_f64x, b_odd_f64x);
|
|
651
652
|
}
|
|
652
|
-
return
|
|
653
|
+
return nk_svaddv_f64_(svptrue_b64(), accumulator_even_f64x) + nk_svaddv_f64_(svptrue_b64(), accumulator_odd_f64x);
|
|
653
654
|
}
|
|
654
655
|
|
|
655
656
|
/**
|
|
@@ -687,7 +688,7 @@ NK_PUBLIC nk_f64_t nk_maxsim_angular_from_dots_ssve_(
|
|
|
687
688
|
svfloat64_t angular_distance_f64x = svsub_f64_x(predicate_b64x, svdup_f64(1.0), cosine_f64x);
|
|
688
689
|
angular_distance_f64x = svmax_f64_x(predicate_b64x, angular_distance_f64x, svdup_f64(0.0));
|
|
689
690
|
|
|
690
|
-
total_angular_distance_f64 +=
|
|
691
|
+
total_angular_distance_f64 += nk_svaddv_f64_(predicate_b64x, angular_distance_f64x);
|
|
691
692
|
}
|
|
692
693
|
return total_angular_distance_f64;
|
|
693
694
|
}
|
|
@@ -701,10 +702,9 @@ NK_PUBLIC nk_f64_t nk_maxsim_angular_from_dots_ssve_(
|
|
|
701
702
|
* Refinement: tile-wide interleaved f64 dot products for the winning (query, document) pairs.
|
|
702
703
|
* Angular distance: 1 - dot / sqrt(||q||^2 * ||d||^2), accumulated with f64.
|
|
703
704
|
*/
|
|
704
|
-
|
|
705
|
-
void const *query_packed, void const *document_packed,
|
|
706
|
-
nk_size_t
|
|
707
|
-
nk_f64_t *result) {
|
|
705
|
+
__arm_new("za") static void nk_maxsim_packed_f32_streaming_( //
|
|
706
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
707
|
+
nk_size_t depth, nk_f64_t *result) NK_STREAMING_ {
|
|
708
708
|
|
|
709
709
|
nk_maxsim_sme_packed_header_t const *query_header = (nk_maxsim_sme_packed_header_t const *)query_packed;
|
|
710
710
|
nk_maxsim_sme_packed_header_t const *document_header = (nk_maxsim_sme_packed_header_t const *)document_packed;
|
|
@@ -937,10 +937,10 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
937
937
|
|
|
938
938
|
// Reduce SVE accumulators to scalars and compute angular distances
|
|
939
939
|
nk_f64_t dot_products_f64[4];
|
|
940
|
-
dot_products_f64[0] =
|
|
941
|
-
dot_products_f64[1] =
|
|
942
|
-
dot_products_f64[2] =
|
|
943
|
-
dot_products_f64[3] =
|
|
940
|
+
dot_products_f64[0] = nk_svaddv_f64_(svptrue_b64(), accumulator_0_f64x);
|
|
941
|
+
dot_products_f64[1] = nk_svaddv_f64_(svptrue_b64(), accumulator_1_f64x);
|
|
942
|
+
dot_products_f64[2] = nk_svaddv_f64_(svptrue_b64(), accumulator_2_f64x);
|
|
943
|
+
dot_products_f64[3] = nk_svaddv_f64_(svptrue_b64(), accumulator_3_f64x);
|
|
944
944
|
nk_f64_t batch_query_norms_f64[4], batch_document_norms_f64[4];
|
|
945
945
|
for (nk_size_t i = 0; i < 4; i++) {
|
|
946
946
|
batch_query_norms_f64[i] = (nk_f64_t)query_norms[row_start + row_batch_start + i];
|
|
@@ -969,12 +969,13 @@ __arm_locally_streaming __arm_new("za") static void nk_maxsim_packed_f32_streami
|
|
|
969
969
|
*result = total_angular_distance_f64;
|
|
970
970
|
}
|
|
971
971
|
|
|
972
|
-
NK_PUBLIC void nk_maxsim_packed_f32_sme(
|
|
973
|
-
void const *query_packed, void const *document_packed,
|
|
974
|
-
nk_size_t
|
|
975
|
-
nk_f64_t *result) { //
|
|
972
|
+
NK_PUBLIC void nk_maxsim_packed_f32_sme( //
|
|
973
|
+
void const *query_packed, void const *document_packed, nk_size_t query_count, nk_size_t document_count,
|
|
974
|
+
nk_size_t depth, nk_f64_t *result) {
|
|
976
975
|
|
|
976
|
+
nk_sme_start_streaming_();
|
|
977
977
|
nk_maxsim_packed_f32_streaming_(query_packed, document_packed, query_count, document_count, depth, result);
|
|
978
|
+
nk_sme_stop_streaming_();
|
|
978
979
|
}
|
|
979
980
|
|
|
980
981
|
#if defined(__clang__)
|
|
@@ -289,6 +289,18 @@ extern "C" {
|
|
|
289
289
|
m[2] * (m[3] * m[7] - m[4] * m[6]); \
|
|
290
290
|
}
|
|
291
291
|
|
|
292
|
+
/* Optimize serial fallbacks for size — see dots/serial.h for rationale. */
|
|
293
|
+
#if defined(NDEBUG)
|
|
294
|
+
#if defined(_MSC_VER)
|
|
295
|
+
#pragma optimize("s", on)
|
|
296
|
+
#elif defined(__clang__)
|
|
297
|
+
#pragma clang attribute push(__attribute__((minsize)), apply_to = function)
|
|
298
|
+
#elif defined(__GNUC__)
|
|
299
|
+
#pragma GCC push_options
|
|
300
|
+
#pragma GCC optimize("Os")
|
|
301
|
+
#endif
|
|
302
|
+
#endif
|
|
303
|
+
|
|
292
304
|
NK_INTERNAL nk_f32_t nk_sum_three_products_f32_(nk_f32_t left_0, nk_f32_t right_0, nk_f32_t left_1, nk_f32_t right_1,
|
|
293
305
|
nk_f32_t left_2, nk_f32_t right_2) {
|
|
294
306
|
return left_0 * right_0 + left_1 * right_1 + left_2 * right_2;
|
|
@@ -692,6 +704,16 @@ nk_define_umeyama_(bf16, f32, f32, f32, f32, nk_bf16_to_f32_serial, nk_f32_sqrt_
|
|
|
692
704
|
#undef nk_define_kabsch_
|
|
693
705
|
#undef nk_define_umeyama_
|
|
694
706
|
|
|
707
|
+
#if defined(NDEBUG)
|
|
708
|
+
#if defined(_MSC_VER)
|
|
709
|
+
#pragma optimize("", on)
|
|
710
|
+
#elif defined(__clang__)
|
|
711
|
+
#pragma clang attribute pop
|
|
712
|
+
#elif defined(__GNUC__)
|
|
713
|
+
#pragma GCC pop_options
|
|
714
|
+
#endif
|
|
715
|
+
#endif
|
|
716
|
+
|
|
695
717
|
#if defined(__cplusplus)
|
|
696
718
|
} // extern "C"
|
|
697
719
|
#endif
|
|
@@ -3936,6 +3936,35 @@ NK_PUBLIC void nk_reduce_moments_f16_neon( //
|
|
|
3936
3936
|
else nk_reduce_moments_f16_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3937
3937
|
}
|
|
3938
3938
|
|
|
3939
|
+
NK_INTERNAL void nk_reduce_moments_u1_neon_contiguous_( //
|
|
3940
|
+
nk_u1x8_t const *data_ptr, nk_size_t count, //
|
|
3941
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3942
|
+
nk_size_t byte_count = nk_size_divide_round_up_(count, NK_BITS_PER_BYTE);
|
|
3943
|
+
nk_u64_t sum = 0;
|
|
3944
|
+
nk_size_t idx = 0;
|
|
3945
|
+
// Each vcntq_u8 produces values 0-8 per lane; accumulate at u8 level
|
|
3946
|
+
// for up to 31 iterations (31 × 8 = 248, fits in u8) before widening.
|
|
3947
|
+
while (idx + 16 <= byte_count) {
|
|
3948
|
+
uint8x16_t popcount_u8x16 = vdupq_n_u8(0);
|
|
3949
|
+
for (nk_size_t cycle = 0; cycle < 31 && idx + 16 <= byte_count; ++cycle, idx += 16) {
|
|
3950
|
+
uint8x16_t data_u8x16 = vld1q_u8((nk_u8_t const *)data_ptr + idx);
|
|
3951
|
+
popcount_u8x16 = vaddq_u8(popcount_u8x16, vcntq_u8(data_u8x16));
|
|
3952
|
+
}
|
|
3953
|
+
sum += (nk_u64_t)vaddlvq_u8(popcount_u8x16);
|
|
3954
|
+
}
|
|
3955
|
+
for (; idx < byte_count; ++idx) sum += nk_u1x8_popcount_(((nk_u8_t const *)data_ptr)[idx]);
|
|
3956
|
+
*sum_ptr = sum, *sumsq_ptr = sum;
|
|
3957
|
+
}
|
|
3958
|
+
|
|
3959
|
+
NK_PUBLIC void nk_reduce_moments_u1_neon( //
|
|
3960
|
+
nk_u1x8_t const *data_ptr, nk_size_t count, nk_size_t stride_bytes, //
|
|
3961
|
+
nk_u64_t *sum_ptr, nk_u64_t *sumsq_ptr) {
|
|
3962
|
+
count = nk_size_round_up_to_multiple_(count, 8);
|
|
3963
|
+
if (count == 0) *sum_ptr = 0, *sumsq_ptr = 0;
|
|
3964
|
+
else if (stride_bytes == 1) nk_reduce_moments_u1_neon_contiguous_(data_ptr, count, sum_ptr, sumsq_ptr);
|
|
3965
|
+
else nk_reduce_moments_u1_serial(data_ptr, count, stride_bytes, sum_ptr, sumsq_ptr);
|
|
3966
|
+
}
|
|
3967
|
+
|
|
3939
3968
|
#if defined(__clang__)
|
|
3940
3969
|
#pragma clang attribute pop
|
|
3941
3970
|
#elif defined(__GNUC__)
|
|
@@ -33,7 +33,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_contiguous_( //
|
|
|
33
33
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
34
34
|
|
|
35
35
|
// bf16 representation of 1.0 is 0x3F80 (same as upper 16 bits of f32 1.0)
|
|
36
|
-
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(
|
|
36
|
+
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
|
|
37
37
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
38
38
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
39
39
|
nk_size_t idx = 0;
|
|
@@ -61,7 +61,7 @@ NK_INTERNAL void nk_reduce_moments_bf16_neonbfdot_strided_( //
|
|
|
61
61
|
nk_bf16_t const *data_ptr, nk_size_t count, nk_size_t stride_elements, //
|
|
62
62
|
nk_f32_t *sum_ptr, nk_f32_t *sumsq_ptr) {
|
|
63
63
|
|
|
64
|
-
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(
|
|
64
|
+
bfloat16x8_t ones_bf16x8 = vreinterpretq_bf16_u16(nk_u16x8_splat_(0x3F80));
|
|
65
65
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
66
66
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
67
67
|
nk_size_t idx = 0;
|
|
@@ -34,7 +34,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_contiguous_( //
|
|
|
34
34
|
|
|
35
35
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
36
36
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
37
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
37
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
38
38
|
nk_size_t idx = 0;
|
|
39
39
|
|
|
40
40
|
for (; idx + 8 <= count; idx += 8) {
|
|
@@ -67,7 +67,7 @@ NK_INTERNAL void nk_reduce_moments_e4m3_neonfhm_strided_( //
|
|
|
67
67
|
|
|
68
68
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
69
69
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
70
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
70
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
71
71
|
nk_size_t idx = 0;
|
|
72
72
|
|
|
73
73
|
if (stride_elements == 2) {
|
|
@@ -159,7 +159,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_contiguous_( //
|
|
|
159
159
|
|
|
160
160
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
161
161
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
162
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
162
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
163
163
|
nk_size_t idx = 0;
|
|
164
164
|
|
|
165
165
|
for (; idx + 8 <= count; idx += 8) {
|
|
@@ -192,7 +192,7 @@ NK_INTERNAL void nk_reduce_moments_e5m2_neonfhm_strided_( //
|
|
|
192
192
|
|
|
193
193
|
float32x4_t sum_f32x4 = vdupq_n_f32(0);
|
|
194
194
|
float32x4_t sumsq_f32x4 = vdupq_n_f32(0);
|
|
195
|
-
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(
|
|
195
|
+
float16x8_t ones_f16x8 = vreinterpretq_f16_u16(nk_u16x8_splat_(0x3C00));
|
|
196
196
|
nk_size_t idx = 0;
|
|
197
197
|
|
|
198
198
|
if (stride_elements == 2) {
|