numkong 7.5.0 → 7.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/binding.gyp +18 -0
- package/c/dispatch_e5m2.c +23 -3
- package/include/numkong/capabilities.h +1 -1
- package/include/numkong/cast/README.md +3 -0
- package/include/numkong/cast/haswell.h +28 -64
- package/include/numkong/cast/serial.h +17 -0
- package/include/numkong/cast/skylake.h +67 -52
- package/include/numkong/cast.h +1 -0
- package/include/numkong/dot/README.md +1 -0
- package/include/numkong/dot/haswell.h +92 -13
- package/include/numkong/dot/serial.h +15 -0
- package/include/numkong/dot/skylake.h +61 -14
- package/include/numkong/dots/README.md +2 -0
- package/include/numkong/dots/graniteamx.h +434 -0
- package/include/numkong/dots/haswell.h +28 -28
- package/include/numkong/dots/sapphireamx.h +1 -1
- package/include/numkong/dots/serial.h +23 -8
- package/include/numkong/dots/skylake.h +28 -23
- package/include/numkong/dots.h +12 -0
- package/include/numkong/each/serial.h +18 -1
- package/include/numkong/geospatial/serial.h +14 -3
- package/include/numkong/maxsim/serial.h +15 -0
- package/include/numkong/mesh/README.md +50 -44
- package/include/numkong/mesh/genoa.h +462 -0
- package/include/numkong/mesh/haswell.h +806 -933
- package/include/numkong/mesh/neon.h +871 -943
- package/include/numkong/mesh/neonbfdot.h +382 -522
- package/include/numkong/mesh/neonfhm.h +676 -0
- package/include/numkong/mesh/rvv.h +404 -319
- package/include/numkong/mesh/serial.h +204 -162
- package/include/numkong/mesh/skylake.h +1029 -1585
- package/include/numkong/mesh/v128relaxed.h +403 -377
- package/include/numkong/mesh.h +38 -0
- package/include/numkong/reduce/serial.h +15 -1
- package/include/numkong/sparse/serial.h +17 -2
- package/include/numkong/spatial/genoa.h +0 -68
- package/include/numkong/spatial/haswell.h +98 -56
- package/include/numkong/spatial/serial.h +15 -0
- package/include/numkong/spatial/skylake.h +114 -54
- package/include/numkong/spatial.h +0 -12
- package/include/numkong/spatials/graniteamx.h +128 -0
- package/include/numkong/spatials/serial.h +18 -1
- package/include/numkong/spatials/skylake.h +2 -2
- package/include/numkong/spatials.h +17 -0
- package/include/numkong/tensor.hpp +107 -23
- package/javascript/numkong.c +3 -2
- package/package.json +7 -7
- package/wasm/numkong.wasm +0 -0
package/include/numkong/mesh.h
CHANGED
|
@@ -266,6 +266,20 @@ NK_PUBLIC void nk_umeyama_bf16_skylake(nk_bf16_t const *a, nk_bf16_t const *b, n
|
|
|
266
266
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
267
267
|
#endif // NK_TARGET_SKYLAKE
|
|
268
268
|
|
|
269
|
+
/* SIMD-powered backends for AVX512-BF16 CPUs of AMD Genoa / Intel Sapphire Rapids generation and newer.
|
|
270
|
+
*/
|
|
271
|
+
#if NK_TARGET_GENOA
|
|
272
|
+
/** @copydoc nk_rmsd_bf16 */
|
|
273
|
+
NK_PUBLIC void nk_rmsd_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
274
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
275
|
+
/** @copydoc nk_kabsch_bf16 */
|
|
276
|
+
NK_PUBLIC void nk_kabsch_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
277
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
278
|
+
/** @copydoc nk_umeyama_bf16 */
|
|
279
|
+
NK_PUBLIC void nk_umeyama_bf16_genoa(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
280
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
281
|
+
#endif // NK_TARGET_GENOA
|
|
282
|
+
|
|
269
283
|
/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer.
|
|
270
284
|
*/
|
|
271
285
|
#if NK_TARGET_HASWELL
|
|
@@ -357,6 +371,20 @@ NK_PUBLIC void nk_umeyama_bf16_neonbfdot(nk_bf16_t const *a, nk_bf16_t const *b,
|
|
|
357
371
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
358
372
|
#endif // NK_TARGET_NEONBFDOT
|
|
359
373
|
|
|
374
|
+
/* SIMD-powered backends for Arm NEON FHM (FP16 widening FMA) CPUs.
|
|
375
|
+
*/
|
|
376
|
+
#if NK_TARGET_NEONFHM
|
|
377
|
+
/** @copydoc nk_rmsd_f16 */
|
|
378
|
+
NK_PUBLIC void nk_rmsd_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
379
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
380
|
+
/** @copydoc nk_kabsch_f16 */
|
|
381
|
+
NK_PUBLIC void nk_kabsch_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
382
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
383
|
+
/** @copydoc nk_umeyama_f16 */
|
|
384
|
+
NK_PUBLIC void nk_umeyama_f16_neonfhm(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
385
|
+
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result);
|
|
386
|
+
#endif // NK_TARGET_NEONFHM
|
|
387
|
+
|
|
360
388
|
#if NK_TARGET_RVV
|
|
361
389
|
/** @copydoc nk_rmsd_f32 */
|
|
362
390
|
NK_PUBLIC void nk_rmsd_f32_rvv(nk_f32_t const *a, nk_f32_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
@@ -454,8 +482,10 @@ NK_INTERNAL nk_dtype_t nk_mesh_transform_dtype(nk_dtype_t dtype) {
|
|
|
454
482
|
#include "numkong/mesh/serial.h"
|
|
455
483
|
#include "numkong/mesh/neon.h"
|
|
456
484
|
#include "numkong/mesh/neonbfdot.h"
|
|
485
|
+
#include "numkong/mesh/neonfhm.h"
|
|
457
486
|
#include "numkong/mesh/haswell.h"
|
|
458
487
|
#include "numkong/mesh/skylake.h"
|
|
488
|
+
#include "numkong/mesh/genoa.h"
|
|
459
489
|
#include "numkong/mesh/rvv.h"
|
|
460
490
|
#include "numkong/mesh/v128relaxed.h"
|
|
461
491
|
|
|
@@ -505,6 +535,8 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
|
|
|
505
535
|
nk_rmsd_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
506
536
|
#elif NK_TARGET_HASWELL
|
|
507
537
|
nk_rmsd_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
538
|
+
#elif NK_TARGET_NEONFHM
|
|
539
|
+
nk_rmsd_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
508
540
|
#elif NK_TARGET_NEON
|
|
509
541
|
nk_rmsd_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
510
542
|
#elif NK_TARGET_RVV
|
|
@@ -517,6 +549,8 @@ NK_PUBLIC void nk_rmsd_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n, nk
|
|
|
517
549
|
NK_PUBLIC void nk_rmsd_bf16(nk_bf16_t const *a, nk_bf16_t const *b, nk_size_t n, nk_f32_t *a_centroid,
|
|
518
550
|
nk_f32_t *b_centroid, nk_f32_t *rotation, nk_f32_t *scale, nk_f32_t *result) {
|
|
519
551
|
#if NK_TARGET_SKYLAKE
|
|
552
|
+
// Skylake f32-widen path wins on Intel where VDPBF16PS throughput matches FMA; on AMD Zen4+
|
|
553
|
+
// where VDPBF16PS is faster than FMA, users can call `nk_rmsd_bf16_genoa` directly.
|
|
520
554
|
nk_rmsd_bf16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
521
555
|
#elif NK_TARGET_HASWELL
|
|
522
556
|
nk_rmsd_bf16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
@@ -569,6 +603,8 @@ NK_PUBLIC void nk_kabsch_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
|
569
603
|
nk_kabsch_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
570
604
|
#elif NK_TARGET_HASWELL
|
|
571
605
|
nk_kabsch_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
606
|
+
#elif NK_TARGET_NEONFHM
|
|
607
|
+
nk_kabsch_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
572
608
|
#elif NK_TARGET_NEON
|
|
573
609
|
nk_kabsch_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
574
610
|
#elif NK_TARGET_RVV
|
|
@@ -633,6 +669,8 @@ NK_PUBLIC void nk_umeyama_f16(nk_f16_t const *a, nk_f16_t const *b, nk_size_t n,
|
|
|
633
669
|
nk_umeyama_f16_skylake(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
634
670
|
#elif NK_TARGET_HASWELL
|
|
635
671
|
nk_umeyama_f16_haswell(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
672
|
+
#elif NK_TARGET_NEONFHM
|
|
673
|
+
nk_umeyama_f16_neonfhm(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
636
674
|
#elif NK_TARGET_NEON
|
|
637
675
|
nk_umeyama_f16_neon(a, b, n, a_centroid, b_centroid, rotation, scale, result);
|
|
638
676
|
#elif NK_TARGET_RVV
|
|
@@ -14,7 +14,6 @@
|
|
|
14
14
|
#define NK_REDUCE_SERIAL_H
|
|
15
15
|
|
|
16
16
|
#include "numkong/types.h"
|
|
17
|
-
#include "numkong/scalar/serial.h"
|
|
18
17
|
#include "numkong/cast/serial.h"
|
|
19
18
|
#include "numkong/scalar/serial.h"
|
|
20
19
|
|
|
@@ -22,6 +21,15 @@
|
|
|
22
21
|
extern "C" {
|
|
23
22
|
#endif
|
|
24
23
|
|
|
24
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
25
|
+
* See dots/serial.h for rationale. */
|
|
26
|
+
#if defined(__clang__)
|
|
27
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
28
|
+
#elif defined(__GNUC__)
|
|
29
|
+
#pragma GCC push_options
|
|
30
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
31
|
+
#endif
|
|
32
|
+
|
|
25
33
|
NK_INTERNAL nk_f64_t nk_reduce_sum_f64_serial_(nk_f64_t const *values, nk_f64_t const *compensations, int count) {
|
|
26
34
|
nk_f64_t running_sum = 0, accumulated_error = 0;
|
|
27
35
|
for (int i = 0; i < count; i++) {
|
|
@@ -746,6 +754,12 @@ NK_PUBLIC void nk_reduce_minmax_u1_serial( //
|
|
|
746
754
|
*max_value_ptr = max_value, *max_index_ptr = max_idx;
|
|
747
755
|
}
|
|
748
756
|
|
|
757
|
+
#if defined(__clang__)
|
|
758
|
+
#pragma clang attribute pop
|
|
759
|
+
#elif defined(__GNUC__)
|
|
760
|
+
#pragma GCC pop_options
|
|
761
|
+
#endif
|
|
762
|
+
|
|
749
763
|
#if defined(__cplusplus)
|
|
750
764
|
} // extern "C"
|
|
751
765
|
#endif
|
|
@@ -17,7 +17,7 @@ extern "C" {
|
|
|
17
17
|
#endif
|
|
18
18
|
|
|
19
19
|
#define nk_define_sparse_intersect_(input_type) \
|
|
20
|
-
|
|
20
|
+
NK_INTERNAL nk_size_t nk_sparse_intersect_##input_type##_galloping_search_( \
|
|
21
21
|
nk_##input_type##_t const *array, nk_size_t start, nk_size_t length, nk_##input_type##_t val) { \
|
|
22
22
|
nk_size_t low = start; \
|
|
23
23
|
nk_size_t high = start + 1; \
|
|
@@ -32,7 +32,7 @@ extern "C" {
|
|
|
32
32
|
} \
|
|
33
33
|
return low; \
|
|
34
34
|
} \
|
|
35
|
-
|
|
35
|
+
NK_INTERNAL nk_size_t nk_sparse_intersect_##input_type##_linear_scan_( \
|
|
36
36
|
nk_##input_type##_t const *a, nk_##input_type##_t const *b, nk_size_t a_length, nk_size_t b_length, \
|
|
37
37
|
nk_##input_type##_t *result) { \
|
|
38
38
|
nk_size_t intersection_size = 0; \
|
|
@@ -103,6 +103,15 @@ extern "C" {
|
|
|
103
103
|
*product = weights_product; \
|
|
104
104
|
}
|
|
105
105
|
|
|
106
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
107
|
+
* See dots/serial.h for rationale. */
|
|
108
|
+
#if defined(__clang__)
|
|
109
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
110
|
+
#elif defined(__GNUC__)
|
|
111
|
+
#pragma GCC push_options
|
|
112
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
113
|
+
#endif
|
|
114
|
+
|
|
106
115
|
nk_define_sparse_intersect_(u16) // nk_sparse_intersect_u16_serial
|
|
107
116
|
nk_define_sparse_intersect_(u32) // nk_sparse_intersect_u32_serial
|
|
108
117
|
nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
|
|
@@ -110,6 +119,12 @@ nk_define_sparse_intersect_(u64) // nk_sparse_intersect_u64_serial
|
|
|
110
119
|
nk_define_sparse_dot_(u16, bf16, f32, nk_bf16_to_f32_serial) // nk_sparse_dot_u16bf16_serial
|
|
111
120
|
nk_define_sparse_dot_(u32, f32, f64, nk_assign_from_to_) // nk_sparse_dot_u32f32_serial
|
|
112
121
|
|
|
122
|
+
#if defined(__clang__)
|
|
123
|
+
#pragma clang attribute pop
|
|
124
|
+
#elif defined(__GNUC__)
|
|
125
|
+
#pragma GCC pop_options
|
|
126
|
+
#endif
|
|
127
|
+
|
|
113
128
|
#if defined(__cplusplus)
|
|
114
129
|
} // extern "C"
|
|
115
130
|
#endif
|
|
@@ -139,74 +139,6 @@ nk_angular_bf16_genoa_cycle:
|
|
|
139
139
|
*result = nk_angular_normalize_f32_haswell_(dot_product_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
140
140
|
}
|
|
141
141
|
|
|
142
|
-
NK_PUBLIC void nk_sqeuclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
143
|
-
__m512 a_sq_f32x16 = _mm512_setzero_ps();
|
|
144
|
-
__m512 b_sq_f32x16 = _mm512_setzero_ps();
|
|
145
|
-
__m512 ab_f32x16 = _mm512_setzero_ps();
|
|
146
|
-
__m256i a_e5m2x32, b_e5m2x32;
|
|
147
|
-
|
|
148
|
-
nk_sqeuclidean_e5m2_genoa_cycle:
|
|
149
|
-
if (n < 32) {
|
|
150
|
-
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
151
|
-
a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
|
|
152
|
-
b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
|
|
153
|
-
n = 0;
|
|
154
|
-
}
|
|
155
|
-
else {
|
|
156
|
-
a_e5m2x32 = _mm256_loadu_epi8(a);
|
|
157
|
-
b_e5m2x32 = _mm256_loadu_epi8(b);
|
|
158
|
-
a += 32, b += 32, n -= 32;
|
|
159
|
-
}
|
|
160
|
-
__m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
|
|
161
|
-
__m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
|
|
162
|
-
a_sq_f32x16 = _mm512_dpbf16_ps(a_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(a_bf16x32));
|
|
163
|
-
b_sq_f32x16 = _mm512_dpbf16_ps(b_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
164
|
-
ab_f32x16 = _mm512_dpbf16_ps(ab_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
165
|
-
if (n) goto nk_sqeuclidean_e5m2_genoa_cycle;
|
|
166
|
-
|
|
167
|
-
// (a-b)² = a² + b² - 2ab
|
|
168
|
-
__m512 sum_sq_f32x16 = _mm512_add_ps(a_sq_f32x16, b_sq_f32x16);
|
|
169
|
-
*result = nk_reduce_add_f32x16_skylake_(_mm512_fnmadd_ps(_mm512_set1_ps(2.0f), ab_f32x16, sum_sq_f32x16));
|
|
170
|
-
}
|
|
171
|
-
|
|
172
|
-
NK_PUBLIC void nk_euclidean_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
173
|
-
nk_sqeuclidean_e5m2_genoa(a, b, n, result);
|
|
174
|
-
*result = nk_f32_sqrt_haswell(*result);
|
|
175
|
-
}
|
|
176
|
-
|
|
177
|
-
NK_PUBLIC void nk_angular_e5m2_genoa(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
178
|
-
__m512 dot_f32x16 = _mm512_setzero_ps();
|
|
179
|
-
__m512 a_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
180
|
-
__m512 b_norm_sq_f32x16 = _mm512_setzero_ps();
|
|
181
|
-
__m256i a_e5m2x32, b_e5m2x32;
|
|
182
|
-
|
|
183
|
-
nk_angular_e5m2_genoa_cycle:
|
|
184
|
-
if (n < 32) {
|
|
185
|
-
__mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n);
|
|
186
|
-
a_e5m2x32 = _mm256_maskz_loadu_epi8(mask, a);
|
|
187
|
-
b_e5m2x32 = _mm256_maskz_loadu_epi8(mask, b);
|
|
188
|
-
n = 0;
|
|
189
|
-
}
|
|
190
|
-
else {
|
|
191
|
-
a_e5m2x32 = _mm256_loadu_epi8(a);
|
|
192
|
-
b_e5m2x32 = _mm256_loadu_epi8(b);
|
|
193
|
-
a += 32, b += 32, n -= 32;
|
|
194
|
-
}
|
|
195
|
-
__m512i a_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(a_e5m2x32);
|
|
196
|
-
__m512i b_bf16x32 = nk_e5m2x32_to_bf16x32_icelake_(b_e5m2x32);
|
|
197
|
-
dot_f32x16 = _mm512_dpbf16_ps(dot_f32x16, nk_m512bh_from_m512i_(a_bf16x32), nk_m512bh_from_m512i_(b_bf16x32));
|
|
198
|
-
a_norm_sq_f32x16 = _mm512_dpbf16_ps(a_norm_sq_f32x16, nk_m512bh_from_m512i_(a_bf16x32),
|
|
199
|
-
nk_m512bh_from_m512i_(a_bf16x32));
|
|
200
|
-
b_norm_sq_f32x16 = _mm512_dpbf16_ps(b_norm_sq_f32x16, nk_m512bh_from_m512i_(b_bf16x32),
|
|
201
|
-
nk_m512bh_from_m512i_(b_bf16x32));
|
|
202
|
-
if (n) goto nk_angular_e5m2_genoa_cycle;
|
|
203
|
-
|
|
204
|
-
nk_f32_t dot_f32 = nk_reduce_add_f32x16_skylake_(dot_f32x16);
|
|
205
|
-
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(a_norm_sq_f32x16);
|
|
206
|
-
nk_f32_t b_norm_sq_f32 = nk_reduce_add_f32x16_skylake_(b_norm_sq_f32x16);
|
|
207
|
-
*result = nk_angular_normalize_f32_haswell_(dot_f32, a_norm_sq_f32, b_norm_sq_f32);
|
|
208
|
-
}
|
|
209
|
-
|
|
210
142
|
#if defined(__clang__)
|
|
211
143
|
#pragma clang attribute pop
|
|
212
144
|
#elif defined(__GNUC__)
|
|
@@ -840,28 +840,37 @@ nk_angular_e3m2_haswell_cycle:
|
|
|
840
840
|
}
|
|
841
841
|
|
|
842
842
|
NK_PUBLIC void nk_sqeuclidean_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
843
|
-
|
|
843
|
+
// E4M3 has no free widen shift, so we call the Giesen-based 8-lane cast helper
|
|
844
|
+
// twice per 16-lane iter and run with two F32 accumulators to break the FMA chain.
|
|
845
|
+
__m256 first_acc_f32x8 = _mm256_setzero_ps();
|
|
846
|
+
__m256 second_acc_f32x8 = _mm256_setzero_ps();
|
|
847
|
+
__m128i a_u8x16, b_u8x16;
|
|
844
848
|
|
|
845
849
|
nk_sqeuclidean_e4m3_haswell_cycle:
|
|
846
|
-
if (n <
|
|
850
|
+
if (n < 16) {
|
|
847
851
|
nk_b128_vec_t a_vec, b_vec;
|
|
848
852
|
nk_partial_load_b8x16_serial_(a, &a_vec, n);
|
|
849
853
|
nk_partial_load_b8x16_serial_(b, &b_vec, n);
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
|
|
854
|
+
a_u8x16 = a_vec.xmm;
|
|
855
|
+
b_u8x16 = b_vec.xmm;
|
|
856
|
+
n = 0;
|
|
854
857
|
}
|
|
855
858
|
else {
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
|
|
860
|
-
n -= 8, a += 8, b += 8;
|
|
861
|
-
goto nk_sqeuclidean_e4m3_haswell_cycle;
|
|
859
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
860
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
861
|
+
a += 16, b += 16, n -= 16;
|
|
862
862
|
}
|
|
863
|
+
__m256 a_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_u8x16);
|
|
864
|
+
__m256 a_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_u8x16, a_u8x16));
|
|
865
|
+
__m256 b_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_u8x16);
|
|
866
|
+
__m256 b_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_u8x16, b_u8x16));
|
|
867
|
+
__m256 diff_low_f32x8 = _mm256_sub_ps(a_low_f32x8, b_low_f32x8);
|
|
868
|
+
__m256 diff_high_f32x8 = _mm256_sub_ps(a_high_f32x8, b_high_f32x8);
|
|
869
|
+
first_acc_f32x8 = _mm256_fmadd_ps(diff_low_f32x8, diff_low_f32x8, first_acc_f32x8);
|
|
870
|
+
second_acc_f32x8 = _mm256_fmadd_ps(diff_high_f32x8, diff_high_f32x8, second_acc_f32x8);
|
|
871
|
+
if (n) goto nk_sqeuclidean_e4m3_haswell_cycle;
|
|
863
872
|
|
|
864
|
-
*result = nk_reduce_add_f32x8_haswell_(
|
|
873
|
+
*result = nk_reduce_add_f32x8_haswell_(_mm256_add_ps(first_acc_f32x8, second_acc_f32x8));
|
|
865
874
|
}
|
|
866
875
|
|
|
867
876
|
NK_PUBLIC void nk_euclidean_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -873,27 +882,33 @@ NK_PUBLIC void nk_angular_e4m3_haswell(nk_e4m3_t const *a, nk_e4m3_t const *b, n
|
|
|
873
882
|
__m256 dot_product_f32x8 = _mm256_setzero_ps();
|
|
874
883
|
__m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
875
884
|
__m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
885
|
+
__m128i a_u8x16, b_u8x16;
|
|
876
886
|
|
|
877
887
|
nk_angular_e4m3_haswell_cycle:
|
|
878
|
-
if (n <
|
|
888
|
+
if (n < 16) {
|
|
879
889
|
nk_b128_vec_t a_vec, b_vec;
|
|
880
890
|
nk_partial_load_b8x16_serial_(a, &a_vec, n);
|
|
881
891
|
nk_partial_load_b8x16_serial_(b, &b_vec, n);
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
|
|
886
|
-
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
|
|
892
|
+
a_u8x16 = a_vec.xmm;
|
|
893
|
+
b_u8x16 = b_vec.xmm;
|
|
894
|
+
n = 0;
|
|
887
895
|
}
|
|
888
896
|
else {
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
898
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
899
|
+
a += 16, b += 16, n -= 16;
|
|
900
|
+
}
|
|
901
|
+
__m256 a_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(a_u8x16);
|
|
902
|
+
__m256 a_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(a_u8x16, a_u8x16));
|
|
903
|
+
__m256 b_low_f32x8 = nk_e4m3x8_to_f32x8_haswell_(b_u8x16);
|
|
904
|
+
__m256 b_high_f32x8 = nk_e4m3x8_to_f32x8_haswell_(_mm_unpackhi_epi64(b_u8x16, b_u8x16));
|
|
905
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_low_f32x8, b_low_f32x8, dot_product_f32x8);
|
|
906
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_high_f32x8, b_high_f32x8, dot_product_f32x8);
|
|
907
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_low_f32x8, a_low_f32x8, a_norm_sq_f32x8);
|
|
908
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_high_f32x8, a_high_f32x8, a_norm_sq_f32x8);
|
|
909
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_low_f32x8, b_low_f32x8, b_norm_sq_f32x8);
|
|
910
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_high_f32x8, b_high_f32x8, b_norm_sq_f32x8);
|
|
911
|
+
if (n) goto nk_angular_e4m3_haswell_cycle;
|
|
897
912
|
|
|
898
913
|
nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
|
|
899
914
|
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
|
|
@@ -902,28 +917,44 @@ nk_angular_e4m3_haswell_cycle:
|
|
|
902
917
|
}
|
|
903
918
|
|
|
904
919
|
NK_PUBLIC void nk_sqeuclidean_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
905
|
-
|
|
920
|
+
// E5M2 shares F16's exponent bias (15): `byte << 8` equals the matching F16 encoding.
|
|
921
|
+
// `vpunpck*bw` against zero is the free widen+shift: zero byte in low half of each
|
|
922
|
+
// 16-bit lane, E5M2 byte in high half. Per-128-bit-lane scrambled; commutative sum
|
|
923
|
+
// reduction is invariant under that.
|
|
924
|
+
__m256 first_acc_f32x8 = _mm256_setzero_ps();
|
|
925
|
+
__m256 second_acc_f32x8 = _mm256_setzero_ps();
|
|
926
|
+
__m128i const zero_u8x16 = _mm_setzero_si128();
|
|
927
|
+
__m128i a_u8x16, b_u8x16;
|
|
906
928
|
|
|
907
929
|
nk_sqeuclidean_e5m2_haswell_cycle:
|
|
908
|
-
if (n <
|
|
930
|
+
if (n < 16) {
|
|
909
931
|
nk_b128_vec_t a_vec, b_vec;
|
|
910
932
|
nk_partial_load_b8x16_serial_(a, &a_vec, n);
|
|
911
933
|
nk_partial_load_b8x16_serial_(b, &b_vec, n);
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
distance_sq_f32x8 = _mm256_fmadd_ps(diff_f32x8, diff_f32x8, distance_sq_f32x8);
|
|
934
|
+
a_u8x16 = a_vec.xmm;
|
|
935
|
+
b_u8x16 = b_vec.xmm;
|
|
936
|
+
n = 0;
|
|
916
937
|
}
|
|
917
938
|
else {
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
939
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
940
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
941
|
+
a += 16, b += 16, n -= 16;
|
|
942
|
+
}
|
|
943
|
+
__m128i a_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_u8x16);
|
|
944
|
+
__m128i a_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_u8x16);
|
|
945
|
+
__m128i b_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_u8x16);
|
|
946
|
+
__m128i b_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_u8x16);
|
|
947
|
+
__m256 a_first_f32x8 = _mm256_cvtph_ps(a_even_f16x8);
|
|
948
|
+
__m256 a_second_f32x8 = _mm256_cvtph_ps(a_odd_f16x8);
|
|
949
|
+
__m256 b_first_f32x8 = _mm256_cvtph_ps(b_even_f16x8);
|
|
950
|
+
__m256 b_second_f32x8 = _mm256_cvtph_ps(b_odd_f16x8);
|
|
951
|
+
__m256 diff_first_f32x8 = _mm256_sub_ps(a_first_f32x8, b_first_f32x8);
|
|
952
|
+
__m256 diff_second_f32x8 = _mm256_sub_ps(a_second_f32x8, b_second_f32x8);
|
|
953
|
+
first_acc_f32x8 = _mm256_fmadd_ps(diff_first_f32x8, diff_first_f32x8, first_acc_f32x8);
|
|
954
|
+
second_acc_f32x8 = _mm256_fmadd_ps(diff_second_f32x8, diff_second_f32x8, second_acc_f32x8);
|
|
955
|
+
if (n) goto nk_sqeuclidean_e5m2_haswell_cycle;
|
|
956
|
+
|
|
957
|
+
*result = nk_reduce_add_f32x8_haswell_(_mm256_add_ps(first_acc_f32x8, second_acc_f32x8));
|
|
927
958
|
}
|
|
928
959
|
|
|
929
960
|
NK_PUBLIC void nk_euclidean_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, nk_size_t n, nk_f32_t *result) {
|
|
@@ -935,27 +966,38 @@ NK_PUBLIC void nk_angular_e5m2_haswell(nk_e5m2_t const *a, nk_e5m2_t const *b, n
|
|
|
935
966
|
__m256 dot_product_f32x8 = _mm256_setzero_ps();
|
|
936
967
|
__m256 a_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
937
968
|
__m256 b_norm_sq_f32x8 = _mm256_setzero_ps();
|
|
969
|
+
__m128i const zero_u8x16 = _mm_setzero_si128();
|
|
970
|
+
__m128i a_u8x16, b_u8x16;
|
|
938
971
|
|
|
939
972
|
nk_angular_e5m2_haswell_cycle:
|
|
940
|
-
if (n <
|
|
973
|
+
if (n < 16) {
|
|
941
974
|
nk_b128_vec_t a_vec, b_vec;
|
|
942
975
|
nk_partial_load_b8x16_serial_(a, &a_vec, n);
|
|
943
976
|
nk_partial_load_b8x16_serial_(b, &b_vec, n);
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_f32x8, a_f32x8, a_norm_sq_f32x8);
|
|
948
|
-
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_f32x8, b_f32x8, b_norm_sq_f32x8);
|
|
977
|
+
a_u8x16 = a_vec.xmm;
|
|
978
|
+
b_u8x16 = b_vec.xmm;
|
|
979
|
+
n = 0;
|
|
949
980
|
}
|
|
950
981
|
else {
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
982
|
+
a_u8x16 = _mm_loadu_si128((__m128i const *)a);
|
|
983
|
+
b_u8x16 = _mm_loadu_si128((__m128i const *)b);
|
|
984
|
+
a += 16, b += 16, n -= 16;
|
|
985
|
+
}
|
|
986
|
+
__m128i a_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, a_u8x16);
|
|
987
|
+
__m128i a_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, a_u8x16);
|
|
988
|
+
__m128i b_even_f16x8 = _mm_unpacklo_epi8(zero_u8x16, b_u8x16);
|
|
989
|
+
__m128i b_odd_f16x8 = _mm_unpackhi_epi8(zero_u8x16, b_u8x16);
|
|
990
|
+
__m256 a_first_f32x8 = _mm256_cvtph_ps(a_even_f16x8);
|
|
991
|
+
__m256 a_second_f32x8 = _mm256_cvtph_ps(a_odd_f16x8);
|
|
992
|
+
__m256 b_first_f32x8 = _mm256_cvtph_ps(b_even_f16x8);
|
|
993
|
+
__m256 b_second_f32x8 = _mm256_cvtph_ps(b_odd_f16x8);
|
|
994
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_first_f32x8, b_first_f32x8, dot_product_f32x8);
|
|
995
|
+
dot_product_f32x8 = _mm256_fmadd_ps(a_second_f32x8, b_second_f32x8, dot_product_f32x8);
|
|
996
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_first_f32x8, a_first_f32x8, a_norm_sq_f32x8);
|
|
997
|
+
a_norm_sq_f32x8 = _mm256_fmadd_ps(a_second_f32x8, a_second_f32x8, a_norm_sq_f32x8);
|
|
998
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_first_f32x8, b_first_f32x8, b_norm_sq_f32x8);
|
|
999
|
+
b_norm_sq_f32x8 = _mm256_fmadd_ps(b_second_f32x8, b_second_f32x8, b_norm_sq_f32x8);
|
|
1000
|
+
if (n) goto nk_angular_e5m2_haswell_cycle;
|
|
959
1001
|
|
|
960
1002
|
nk_f32_t dot_product_f32 = nk_reduce_add_f32x8_haswell_(dot_product_f32x8);
|
|
961
1003
|
nk_f32_t a_norm_sq_f32 = nk_reduce_add_f32x8_haswell_(a_norm_sq_f32x8);
|
|
@@ -108,6 +108,15 @@ extern "C" {
|
|
|
108
108
|
} \
|
|
109
109
|
}
|
|
110
110
|
|
|
111
|
+
/* Keep the serial instantiations below actually scalar, regardless of build type.
|
|
112
|
+
* See dots/serial.h for rationale. */
|
|
113
|
+
#if defined(__clang__)
|
|
114
|
+
#pragma clang attribute push(__attribute__((noinline)), apply_to = function)
|
|
115
|
+
#elif defined(__GNUC__)
|
|
116
|
+
#pragma GCC push_options
|
|
117
|
+
#pragma GCC optimize("no-tree-vectorize", "no-tree-slp-vectorize", "no-ipa-cp-clone", "no-inline")
|
|
118
|
+
#endif
|
|
119
|
+
|
|
111
120
|
nk_define_angular_(f64, f64, f64, nk_assign_from_to_, nk_f64_rsqrt_serial) // nk_angular_f64_serial
|
|
112
121
|
nk_define_sqeuclidean_(f64, f64, f64, nk_assign_from_to_) // nk_sqeuclidean_f64_serial
|
|
113
122
|
nk_define_euclidean_(f64, f64, f64, f64, nk_assign_from_to_, nk_f64_sqrt_serial) // nk_euclidean_f64_serial
|
|
@@ -340,6 +349,12 @@ NK_INTERNAL void nk_euclidean_through_u32_from_dot_serial_(nk_b128_vec_t dots, n
|
|
|
340
349
|
}
|
|
341
350
|
}
|
|
342
351
|
|
|
352
|
+
#if defined(__clang__)
|
|
353
|
+
#pragma clang attribute pop
|
|
354
|
+
#elif defined(__GNUC__)
|
|
355
|
+
#pragma GCC pop_options
|
|
356
|
+
#endif
|
|
357
|
+
|
|
343
358
|
#if defined(__cplusplus)
|
|
344
359
|
} // extern "C"
|
|
345
360
|
#endif
|