numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
package/include/numkong/reduce.hpp

@@ -0,0 +1,633 @@

```cpp
/**
 * @brief Reduction kernels: reduce_moments (sum + sum-of-squares), reduce_minmax (min + max with indices).
 * @file include/numkong/reduce.hpp
 * @author Ash Vardanian
 * @date February 5, 2026
 */
#ifndef NK_REDUCE_HPP
#define NK_REDUCE_HPP

#include <cstddef>     // `std::byte`, `std::size_t`
#include <cstdint>     // `std::uint32_t`
#include <memory>      // `std::allocator_traits`
#include <type_traits> // `std::is_same_v`

#include "numkong/reduce.h"

#include "numkong/types.hpp"
#include "numkong/vector.hpp"

namespace ashvardanian::numkong {

/**
 * @brief Compute sum and sum-of-squares in a single pass: sum = Σ dataᵢ, sumsq = Σ dataᵢ²
 * @param[in] data Input array
 * @param[in] count Number of elements
 * @param[in] stride_bytes Stride between elements in bytes (use sizeof(in_type_) for contiguous)
 * @param[out] sum Output sum
 * @param[out] sumsq Output sum of squares
 *
 * @tparam in_type_ Input vector element type
 * @tparam sum_type_ Sum accumulator type, defaults to `in_type_::reduce_moments_sum_t` (often widened)
 * @tparam sumsq_type_ Sum-of-squares accumulator type, defaults to `in_type_::reduce_moments_sumsq_t`
 * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
 */
template <numeric_dtype in_type_, numeric_dtype sum_type_ = typename in_type_::reduce_moments_sum_t,
          numeric_dtype sumsq_type_ = typename in_type_::reduce_moments_sumsq_t,
          allow_simd_t allow_simd_ = prefer_simd_k>
void reduce_moments(in_type_ const *data, std::size_t count, std::size_t stride_bytes, sum_type_ *sum,
                    sumsq_type_ *sumsq) noexcept {
    constexpr bool simd = allow_simd_ == prefer_simd_k &&
                          std::is_same_v<sum_type_, typename in_type_::reduce_moments_sum_t> &&
                          std::is_same_v<sumsq_type_, typename in_type_::reduce_moments_sumsq_t>;

    if constexpr (std::is_same_v<in_type_, f64_t> && simd)
        nk_reduce_moments_f64(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
        nk_reduce_moments_f32(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
        nk_reduce_moments_f16(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
        nk_reduce_moments_bf16(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd)
        nk_reduce_moments_e4m3(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd)
        nk_reduce_moments_e5m2(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd)
        nk_reduce_moments_e2m3(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd)
        nk_reduce_moments_e3m2(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd)
        nk_reduce_moments_i4(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd)
        nk_reduce_moments_u4(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, u1x8_t> && simd)
        nk_reduce_moments_u1(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, i8_t> && simd)
        nk_reduce_moments_i8(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, u8_t> && simd)
        nk_reduce_moments_u8(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, i16_t> && simd)
        nk_reduce_moments_i16(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, u16_t> && simd)
        nk_reduce_moments_u16(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, i32_t> && simd)
        nk_reduce_moments_i32(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, u32_t> && simd)
        nk_reduce_moments_u32(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, i64_t> && simd)
        nk_reduce_moments_i64(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    else if constexpr (std::is_same_v<in_type_, u64_t> && simd)
        nk_reduce_moments_u64(&data->raw_, count, stride_bytes, &sum->raw_, &sumsq->raw_);
    // Scalar fallback
    else {
        sum_type_ running_sum {};
        sumsq_type_ running_sumsq {};
        vector_view<in_type_> values(reinterpret_cast<char const *>(data), count, stride_bytes);
        for (std::size_t i = 0; i < count; ++i) {
            auto val = values[i];
            running_sum = saturating_add(running_sum, val);
            running_sumsq = saturating_fma(val, val, running_sumsq);
        }
        *sum = running_sum;
        *sumsq = running_sumsq;
    }
}

/**
 * @brief Find minimum and maximum elements with their indices in a single pass.
 * @param[in] data Input array
 * @param[in] count Number of elements
 * @param[in] stride_bytes Stride between elements in bytes (use sizeof(in_type_) for contiguous)
 * @param[out] min_value Output minimum value
 * @param[out] min_index Output index of minimum value
 * @param[out] max_value Output maximum value
 * @param[out] max_index Output index of maximum value
 *
 * @tparam in_type_ Input vector element type
 * @tparam minmax_type_ Result type for min/max values, defaults to `in_type_::reduce_minmax_value_t`
 * @tparam allow_simd_ Enable SIMD kernel dispatch when `prefer_simd_k`
 */
template <numeric_dtype in_type_, numeric_dtype minmax_type_ = typename in_type_::reduce_minmax_value_t,
          allow_simd_t allow_simd_ = prefer_simd_k>
void reduce_minmax(in_type_ const *data, std::size_t count, std::size_t stride_bytes, minmax_type_ *min_value,
                   std::size_t *min_index, minmax_type_ *max_value, std::size_t *max_index) noexcept {
    constexpr bool simd = allow_simd_ == prefer_simd_k &&
                          std::is_same_v<minmax_type_, typename in_type_::reduce_minmax_value_t>;
    static_assert(sizeof(std::size_t) == sizeof(nk_size_t), "size_t and nk_size_t must have the same width");
    nk_size_t min_offset = 0, max_offset = 0;

    // For types where minmax_type_ matches the C function output type directly,
    // dispatch to the C kernel and pass raw pointers through.
    if constexpr (std::is_same_v<in_type_, f64_t> && simd)
        nk_reduce_minmax_f64(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, f32_t> && simd)
        nk_reduce_minmax_f32(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, i8_t> && simd)
        nk_reduce_minmax_i8(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                            &max_offset);
    else if constexpr (std::is_same_v<in_type_, u8_t> && simd)
        nk_reduce_minmax_u8(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                            &max_offset);
    else if constexpr (std::is_same_v<in_type_, i16_t> && simd)
        nk_reduce_minmax_i16(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, u16_t> && simd)
        nk_reduce_minmax_u16(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, i32_t> && simd)
        nk_reduce_minmax_i32(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, u32_t> && simd)
        nk_reduce_minmax_u32(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, i64_t> && simd)
        nk_reduce_minmax_i64(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, u64_t> && simd)
        nk_reduce_minmax_u64(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, e2m3_t> && simd)
        nk_reduce_minmax_e2m3(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                              &max_offset);
    else if constexpr (std::is_same_v<in_type_, e3m2_t> && simd)
        nk_reduce_minmax_e3m2(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                              &max_offset);
    else if constexpr (std::is_same_v<in_type_, f16_t> && simd)
        nk_reduce_minmax_f16(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                             &max_offset);
    else if constexpr (std::is_same_v<in_type_, bf16_t> && simd)
        nk_reduce_minmax_bf16(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                              &max_offset);
    else if constexpr (std::is_same_v<in_type_, e4m3_t> && simd)
        nk_reduce_minmax_e4m3(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                              &max_offset);
    else if constexpr (std::is_same_v<in_type_, e5m2_t> && simd)
        nk_reduce_minmax_e5m2(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                              &max_offset);
    else if constexpr (std::is_same_v<in_type_, i4x2_t> && simd)
        nk_reduce_minmax_i4(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                            &max_offset);
    else if constexpr (std::is_same_v<in_type_, u4x2_t> && simd)
        nk_reduce_minmax_u4(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                            &max_offset);
    else if constexpr (std::is_same_v<in_type_, u1x8_t> && simd)
        nk_reduce_minmax_u1(&data->raw_, count, stride_bytes, &min_value->raw_, &min_offset, &max_value->raw_,
                            &max_offset);
    // Scalar fallback
    else {
        minmax_type_ best_min = finite_max<minmax_type_>();
        minmax_type_ best_max = finite_min<minmax_type_>();
        vector_view<in_type_> values(reinterpret_cast<char const *>(data), count, stride_bytes);
        for (nk_size_t i = 0; i < count; ++i) {
            minmax_type_ v = minmax_type_(values[i]);
            if (v < best_min) best_min = v, min_offset = i;
            if (v > best_max) best_max = v, max_offset = i;
        }
        *min_value = best_min, *max_value = best_max;
    }
    if (min_index) *min_index = static_cast<std::size_t>(min_offset);
    if (max_index) *max_index = static_cast<std::size_t>(max_offset);
}

} // namespace ashvardanian::numkong

#include "numkong/tensor.hpp"

namespace ashvardanian::numkong {

#pragma region - Tensor Reduction Helpers

template <numeric_dtype value_type_, std::size_t max_rank_>
bool reduce_rank1_moments_(tensor_view<value_type_, max_rank_> input, typename value_type_::reduce_moments_sum_t &sum,
                           typename value_type_::reduce_moments_sumsq_t &sumsq) noexcept {
    using sum_t = typename value_type_::reduce_moments_sum_t;
    using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
    if (input.rank() != 1 || !tensor_layout_supported_(input) || input.byte_data() == nullptr) return false;
    if (can_reduce_rank1_with_kernel_(input)) {
        auto lane = normalize_rank1_lane_(input);
        numkong::reduce_moments<value_type_>(lane.data, lane.count, lane.stride_bytes, &sum, &sumsq);
        return true;
    }
    auto values = input.as_vector();
    sum = sum_t {};
    sumsq = sumsq_t {};
    for (std::size_t i = 0; i < values.size(); ++i) {
        auto value = values[i];
        sum = saturating_add(sum, value);
        sumsq = saturating_fma(value, value, sumsq);
    }
    return true;
}

template <numeric_dtype value_type_, std::size_t max_rank_>
bool reduce_rank1_minmax_(tensor_view<value_type_, max_rank_> input,
                          minmax_result<typename value_type_::reduce_minmax_value_t> &result) noexcept {
    using minmax_t = typename value_type_::reduce_minmax_value_t;
    if (input.rank() != 1 || !tensor_layout_supported_(input) || input.byte_data() == nullptr) return false;
    if (can_reduce_rank1_with_kernel_(input)) {
        auto lane = normalize_rank1_lane_(input);
        numkong::reduce_minmax<value_type_>(lane.data, lane.count, lane.stride_bytes, &result.min_value,
                                            &result.min_index, &result.max_value, &result.max_index);
        if (lane.reversed) {
            result.min_index = lane.count - 1 - result.min_index;
            result.max_index = lane.count - 1 - result.max_index;
        }
        return true;
    }
    auto values = input.as_vector();
    result.min_value = finite_max<minmax_t>();
    result.max_value = finite_min<minmax_t>();
    result.min_index = 0;
    result.max_index = 0;
    for (std::size_t i = 0; i < values.size(); ++i) {
        minmax_t value = minmax_t(values[i]);
        if (value < result.min_value) result.min_value = value, result.min_index = i;
        if (value > result.max_value) result.max_value = value, result.max_index = i;
    }
    return true;
}

template <numeric_dtype value_type_, std::size_t max_rank_>
bool accumulate_moments_tensor_(tensor_view<value_type_, max_rank_> input,
                                tensor_span<typename value_type_::reduce_moments_sum_t, max_rank_> sums,
                                tensor_span<typename value_type_::reduce_moments_sumsq_t, max_rank_> sumsqs) noexcept {
    using sum_t = typename value_type_::reduce_moments_sum_t;
    using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
    if (!tensor_layout_supported_(input) || !shapes_match_out_(input, sums) || !shapes_match_out_(input, sumsqs))
        return false;
    if (input.rank() == 1) {
        auto src = input.as_vector();
        auto dst_sum = sums.as_vector();
        auto dst_sumsq = sumsqs.as_vector();
        for (std::size_t i = 0; i < src.size(); ++i) {
            auto value = src[i];
            dst_sum[i] = saturating_add(dst_sum[i], sum_t(value));
            dst_sumsq[i] = saturating_fma(value, value, sumsq_t(dst_sumsq[i]));
        }
        return true;
    }
    for (std::size_t i = 0; i < input.extent(0); ++i) {
        if (!accumulate_moments_tensor_(input.slice_leading(i), sums.slice_leading(i), sumsqs.slice_leading(i)))
            return false;
    }
    return true;
}

template <numeric_dtype value_type_, std::size_t max_rank_>
bool update_minmax_tensor_(tensor_view<value_type_, max_rank_> input,
                           tensor_span<typename value_type_::reduce_minmax_value_t, max_rank_> mins,
                           tensor_span<typename value_type_::reduce_minmax_value_t, max_rank_> maxs) noexcept {
    using minmax_t = typename value_type_::reduce_minmax_value_t;
    if (!tensor_layout_supported_(input) || !shapes_match_out_(input, mins) || !shapes_match_out_(input, maxs))
        return false;
    if (input.rank() == 1) {
        auto src = input.as_vector();
        auto dst_min = mins.as_vector();
        auto dst_max = maxs.as_vector();
        for (std::size_t i = 0; i < src.size(); ++i) {
            minmax_t value = minmax_t(src[i]);
            if (value < dst_min[i]) dst_min[i] = value;
            if (value > dst_max[i]) dst_max[i] = value;
        }
        return true;
    }
    for (std::size_t i = 0; i < input.extent(0); ++i) {
        if (!update_minmax_tensor_(input.slice_leading(i), mins.slice_leading(i), maxs.slice_leading(i))) return false;
    }
    return true;
}

template <numeric_dtype value_type_, std::size_t max_rank_>
bool reduce_moments_axis_(tensor_view<value_type_, max_rank_> input, std::size_t axis,
                          typename value_type_::reduce_moments_sum_t *sums,
                          typename value_type_::reduce_moments_sumsq_t *sumsqs) noexcept {
    return for_each_axis_lane_(input, axis,
                               [&](tensor_view<value_type_, max_rank_> lane, std::size_t output_index) noexcept {
                                   typename value_type_::reduce_moments_sum_t sum {};
                                   typename value_type_::reduce_moments_sumsq_t sumsq {};
                                   if (!reduce_rank1_moments_(lane, sum, sumsq)) return false;
                                   if (sums) sums[output_index] = sum;
                                   if (sumsqs) sumsqs[output_index] = sumsq;
                                   return true;
                               });
}

template <numeric_dtype value_type_, std::size_t max_rank_>
bool reduce_moments_axis_packed_(tensor_view<value_type_, max_rank_> input, std::size_t axis,
                                 tensor_span<typename value_type_::reduce_moments_sum_t, max_rank_> sums,
                                 tensor_span<typename value_type_::reduce_moments_sumsq_t, max_rank_> sumsqs,
                                 keep_dims_t keep_dims) noexcept {
    using sum_t = typename value_type_::reduce_moments_sum_t;
    using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
    if (!tensor_layout_supported_(input) || axis >= input.rank()) return false;
    if (axis == 0) {
        auto sum_target = keep_dims ? sums.slice_leading(0) : sums;
        auto sumsq_target = keep_dims ? sumsqs.slice_leading(0) : sumsqs;
        if (input.rank() == 1) {
            sum_t sum {};
            sumsq_t sumsq {};
            if (!reduce_rank1_moments_(input, sum, sumsq)) return false;
            sum_target.scalar_ref() = sum;
            sumsq_target.scalar_ref() = sumsq;
            return true;
        }
        for (std::size_t i = 0; i < input.extent(0); ++i)
            if (!accumulate_moments_tensor_(input.slice_leading(i), sum_target, sumsq_target)) return false;
        return true;
    }
    if (input.rank() == 1) return false;
    for (std::size_t i = 0; i < input.extent(0); ++i)
        if (!reduce_moments_axis_packed_(input.slice_leading(i), axis - 1, sums.slice_leading(i),
                                         sumsqs.slice_leading(i), keep_dims))
            return false;
    return true;
}

template <numeric_dtype value_type_, std::size_t max_rank_>
bool reduce_minmax_axis_(tensor_view<value_type_, max_rank_> input, std::size_t axis,
                         typename value_type_::reduce_minmax_value_t *mins, std::size_t *argmins,
                         typename value_type_::reduce_minmax_value_t *maxs, std::size_t *argmaxs) noexcept {
    return for_each_axis_lane_(input, axis,
                               [&](tensor_view<value_type_, max_rank_> lane, std::size_t output_index) noexcept {
                                   minmax_result<typename value_type_::reduce_minmax_value_t> result {};
                                   if (!reduce_rank1_minmax_(lane, result)) return false;
                                   if (mins) mins[output_index] = result.min_value;
                                   if (argmins) argmins[output_index] = result.min_index;
                                   if (maxs) maxs[output_index] = result.max_value;
                                   if (argmaxs) argmaxs[output_index] = result.max_index;
                                   return true;
                               });
}

template <numeric_dtype value_type_, std::size_t max_rank_>
bool reduce_minmax_axis_packed_(tensor_view<value_type_, max_rank_> input, std::size_t axis,
                                tensor_span<typename value_type_::reduce_minmax_value_t, max_rank_> mins,
                                tensor_span<typename value_type_::reduce_minmax_value_t, max_rank_> maxs,
                                keep_dims_t keep_dims) noexcept {
    using minmax_t = typename value_type_::reduce_minmax_value_t;
    if (!tensor_layout_supported_(input) || axis >= input.rank()) return false;
    if (axis == 0) {
        auto min_target = keep_dims ? mins.slice_leading(0) : mins;
        auto max_target = keep_dims ? maxs.slice_leading(0) : maxs;
        if (input.rank() == 1) {
            minmax_result<minmax_t> result {};
            if (!reduce_rank1_minmax_(input, result)) return false;
            min_target.scalar_ref() = result.min_value;
            max_target.scalar_ref() = result.max_value;
            return true;
        }
        for (std::size_t i = 0; i < input.extent(0); ++i)
            if (!update_minmax_tensor_(input.slice_leading(i), min_target, max_target)) return false;
        return true;
    }
    if (input.rank() == 1) return false;
    for (std::size_t i = 0; i < input.extent(0); ++i)
        if (!reduce_minmax_axis_packed_(input.slice_leading(i), axis - 1, mins.slice_leading(i), maxs.slice_leading(i),
                                        keep_dims))
            return false;
    return true;
}

#pragma endregion - Tensor Reduction Helpers

#pragma region - Scalar Reductions

/** @brief Compute Σxᵢ and Σxᵢ² in a single pass. Returns zeroed result for empty tensors. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
moments_result<typename value_type_::reduce_moments_sum_t, typename value_type_::reduce_moments_sumsq_t> moments(
    tensor_view<value_type_, max_rank_> input) noexcept {
    using sum_t = typename value_type_::reduce_moments_sum_t;
    using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
    moments_result<sum_t, sumsq_t> result {};
    if (input.empty() || input.numel() == 0 || !tensor_layout_supported_(input)) return result;
    if (input.is_contiguous()) {
        numkong::reduce_moments<value_type_>(input.data(), input.numel(), sizeof(value_type_), &result.sum,
                                             &result.sumsq);
        return result;
    }
    if (input.rank() == 1) {
        reduce_rank1_moments_(input, result.sum, result.sumsq);
        return result;
    }
    for (std::size_t i = 0; i < input.extent(0); ++i) {
        auto slice_result = moments<value_type_, max_rank_>(input.slice_leading(static_cast<std::ptrdiff_t>(i)));
        result.sum = saturating_add(result.sum, slice_result.sum);
        result.sumsq = saturating_add(result.sumsq, slice_result.sumsq);
    }
    return result;
}

/** @brief Find min and max values with their flat indices. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
minmax_result<typename value_type_::reduce_minmax_value_t> minmax(tensor_view<value_type_, max_rank_> input) noexcept {
    using minmax_t = typename value_type_::reduce_minmax_value_t;
    minmax_result<minmax_t> result {};
    if (input.empty() || input.numel() == 0 || !tensor_layout_supported_(input)) return result;
    if (input.is_contiguous()) {
        numkong::reduce_minmax<value_type_>(input.data(), input.numel(), sizeof(value_type_), &result.min_value,
                                            &result.min_index, &result.max_value, &result.max_index);
        return result;
    }
    if (input.rank() == 1) {
        reduce_rank1_minmax_(input, result);
        return result;
    }
    result.min_value = finite_max<minmax_t>();
    result.max_value = finite_min<minmax_t>();
    std::size_t base = 0;
    for (std::size_t i = 0; i < input.extent(0); ++i) {
        auto slice = input.slice_leading(static_cast<std::ptrdiff_t>(i));
        auto slice_result = minmax<value_type_, max_rank_>(slice);
        if (slice_result.min_value < result.min_value) {
            result.min_value = slice_result.min_value;
            result.min_index = base + slice_result.min_index;
        }
        if (slice_result.max_value > result.max_value) {
            result.max_value = slice_result.max_value;
            result.max_index = base + slice_result.max_index;
        }
        base += slice.numel();
    }
    return result;
}

/** @brief Σ of all elements. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
typename value_type_::reduce_moments_sum_t sum(tensor_view<value_type_, max_rank_> input) noexcept {
    return moments(input).sum;
}

/** @brief Find the minimum element value. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
typename value_type_::reduce_minmax_value_t min(tensor_view<value_type_, max_rank_> input) noexcept {
    return minmax(input).min_value;
}

/** @brief Find the maximum element value. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
typename value_type_::reduce_minmax_value_t max(tensor_view<value_type_, max_rank_> input) noexcept {
    return minmax(input).max_value;
}

/** @brief Index of the minimum element (flat). */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
std::size_t argmin(tensor_view<value_type_, max_rank_> input) noexcept {
    return minmax(input).min_index;
}

/** @brief Index of the maximum element (flat). */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8>
std::size_t argmax(tensor_view<value_type_, max_rank_> input) noexcept {
    return minmax(input).max_index;
}

#pragma endregion - Scalar Reductions

#pragma region - Axis Reductions

/** @brief Σ along a single axis. Returns empty tensor on failure. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
          typename allocator_type_ = aligned_allocator<typename value_type_::reduce_moments_sum_t>>
tensor<typename value_type_::reduce_moments_sum_t, allocator_type_, max_rank_> try_sum(
    tensor_view<value_type_, max_rank_> input, std::size_t axis, keep_dims_t keep_dims = collapse_dims_k) noexcept {
    using sum_t = typename value_type_::reduce_moments_sum_t;
    using sum_tensor_t = tensor<sum_t, allocator_type_, max_rank_>;

    if (input.empty() || axis >= input.rank() || !tensor_layout_supported_(input)) return sum_tensor_t {};

    auto out_shape = reduced_shape_<sum_t>(input.shape(), axis, keep_dims);
    auto sums = sum_tensor_t::try_zeros(out_shape.extents, out_shape.rank);
    if (sums.empty() || !shape_matches_(out_shape, sums.span())) return sum_tensor_t {};
    if constexpr (dimensions_per_value<value_type_>() > 1) {
        using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
        using sumsq_alloc_t = typename std::allocator_traits<allocator_type_>::template rebind_alloc<sumsq_t>;
        using sumsq_tensor_t = tensor<sumsq_t, sumsq_alloc_t, max_rank_>;
        auto scratch = sumsq_tensor_t::try_zeros(out_shape.extents, out_shape.rank);
        if (scratch.empty() || !shape_matches_(reduced_shape_<sumsq_t>(input.shape(), axis, keep_dims), scratch.span()))
            return sum_tensor_t {};
        if (!reduce_moments_axis_packed_(input, axis, sums.span(), scratch.span(), keep_dims)) return sum_tensor_t {};
    }
    else if (!reduce_moments_axis_(input, axis, sums.data(), nullptr)) return sum_tensor_t {};
    return sums;
}

/** @brief Moments along an axis (Σxᵢ and Σxᵢ² per slice). */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
          typename allocator_type_ = aligned_allocator<typename value_type_::reduce_moments_sum_t>>
moments_result<tensor<typename value_type_::reduce_moments_sum_t, allocator_type_, max_rank_>,
               tensor<typename value_type_::reduce_moments_sumsq_t,
                      typename std::allocator_traits<allocator_type_>::template rebind_alloc<
                          typename value_type_::reduce_moments_sumsq_t>,
                      max_rank_>>
try_moments(tensor_view<value_type_, max_rank_> input, std::size_t axis,
            keep_dims_t keep_dims = collapse_dims_k) noexcept {
    using sum_t = typename value_type_::reduce_moments_sum_t;
    using sumsq_t = typename value_type_::reduce_moments_sumsq_t;
    using sum_tensor_t = tensor<sum_t, allocator_type_, max_rank_>;
    using sumsq_alloc_t = typename std::allocator_traits<allocator_type_>::template rebind_alloc<sumsq_t>;
    using sumsq_tensor_t = tensor<sumsq_t, sumsq_alloc_t, max_rank_>;

    if (input.empty() || axis >= input.rank() || !tensor_layout_supported_(input))
        return {sum_tensor_t {}, sumsq_tensor_t {}};

    auto out_shape_sum = reduced_shape_<sum_t>(input.shape(), axis, keep_dims);
    auto out_shape_sq = reduced_shape_<sumsq_t>(input.shape(), axis, keep_dims);

    auto sums = sum_tensor_t::try_zeros(out_shape_sum.extents, out_shape_sum.rank);
    auto sumsqs = sumsq_tensor_t::try_zeros(out_shape_sq.extents, out_shape_sq.rank);
    if (sums.empty() || sumsqs.empty() || !shape_matches_(out_shape_sum, sums.span()) ||
        !shape_matches_(out_shape_sq, sumsqs.span()))
        return {sum_tensor_t {}, sumsq_tensor_t {}};

    if constexpr (dimensions_per_value<value_type_>() > 1) {
        if (!reduce_moments_axis_packed_(input, axis, sums.span(), sumsqs.span(), keep_dims))
            return {sum_tensor_t {}, sumsq_tensor_t {}};
    }
    else if (!reduce_moments_axis_(input, axis, sums.data(), sumsqs.data()))
        return {sum_tensor_t {}, sumsq_tensor_t {}};

    return {std::move(sums), std::move(sumsqs)};
}

/** @brief Min and max along an axis. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
          typename allocator_type_ = aligned_allocator<typename value_type_::reduce_minmax_value_t>>
minmax_result<tensor<typename value_type_::reduce_minmax_value_t, allocator_type_, max_rank_>> try_minmax(
    tensor_view<value_type_, max_rank_> input, std::size_t axis, keep_dims_t keep_dims = collapse_dims_k) noexcept {
    using minmax_t = typename value_type_::reduce_minmax_value_t;
    using out_tensor_t = tensor<minmax_t, allocator_type_, max_rank_>;
    if (input.empty() || axis >= input.rank() || !tensor_layout_supported_(input))
        return {out_tensor_t {}, 0, out_tensor_t {}, 0};

    auto out_shape = reduced_shape_<minmax_t>(input.shape(), axis, keep_dims);
    auto mins = out_tensor_t::try_full(out_shape.extents, out_shape.rank, finite_max<minmax_t>());
    auto maxs = out_tensor_t::try_full(out_shape.extents, out_shape.rank, finite_min<minmax_t>());
    if (mins.empty() || maxs.empty() || !shape_matches_(out_shape, mins.span()) ||
        !shape_matches_(out_shape, maxs.span()))
        return {out_tensor_t {}, 0, out_tensor_t {}, 0};

    if constexpr (dimensions_per_value<value_type_>() > 1) {
        if (!reduce_minmax_axis_packed_(input, axis, mins.span(), maxs.span(), keep_dims))
            return {out_tensor_t {}, 0, out_tensor_t {}, 0};
    }
    else if (!reduce_minmax_axis_(input, axis, mins.data(), nullptr, maxs.data(), nullptr))
        return {out_tensor_t {}, 0, out_tensor_t {}, 0};
    return {std::move(mins), 0, std::move(maxs), 0};
}

/** @brief Argmin along an axis. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
          typename allocator_type_ = aligned_allocator<std::size_t>>
tensor<std::size_t, allocator_type_, max_rank_> try_argmin(tensor_view<value_type_, max_rank_> input, std::size_t axis,
                                                           keep_dims_t keep_dims = collapse_dims_k) noexcept {
    using out_tensor_t = tensor<std::size_t, allocator_type_, max_rank_>;
    if (input.empty() || axis >= input.rank() || !tensor_layout_supported_(input)) return out_tensor_t {};
    if constexpr (dimensions_per_value<value_type_>() > 1) return out_tensor_t {};

    auto out_shape = reduced_shape_<std::size_t>(input.shape(), axis, keep_dims);
    auto indices = out_tensor_t::try_zeros(out_shape.extents, out_shape.rank);
    if (indices.empty() || !shape_matches_(out_shape, indices.span())) return out_tensor_t {};
    if (!reduce_minmax_axis_(input, axis, nullptr, indices.data(), nullptr, nullptr)) return out_tensor_t {};
    return indices;
}

/** @brief Argmax along an axis. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
          typename allocator_type_ = aligned_allocator<std::size_t>>
tensor<std::size_t, allocator_type_, max_rank_> try_argmax(tensor_view<value_type_, max_rank_> input, std::size_t axis,
                                                           keep_dims_t keep_dims = collapse_dims_k) noexcept {
    using out_tensor_t = tensor<std::size_t, allocator_type_, max_rank_>;
    if (input.empty() || axis >= input.rank() || !tensor_layout_supported_(input)) return out_tensor_t {};
    if constexpr (dimensions_per_value<value_type_>() > 1) return out_tensor_t {};

    auto out_shape = reduced_shape_<std::size_t>(input.shape(), axis, keep_dims);
    auto indices = out_tensor_t::try_zeros(out_shape.extents, out_shape.rank);
    if (indices.empty() || !shape_matches_(out_shape, indices.span())) return out_tensor_t {};
    if (!reduce_minmax_axis_(input, axis, nullptr, nullptr, nullptr, indices.data())) return out_tensor_t {};
    return indices;
}

/** @brief Min along an axis. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
          typename allocator_type_ = aligned_allocator<typename value_type_::reduce_minmax_value_t>>
tensor<typename value_type_::reduce_minmax_value_t, allocator_type_, max_rank_> try_min(
    tensor_view<value_type_, max_rank_> input, std::size_t axis, keep_dims_t keep_dims = collapse_dims_k) noexcept {
    return try_minmax<value_type_, max_rank_, allocator_type_>(input, axis, keep_dims).min_value;
}

/** @brief Max along an axis. */
template <numeric_dtype value_type_, std::size_t max_rank_ = 8,
          typename allocator_type_ = aligned_allocator<typename value_type_::reduce_minmax_value_t>>
tensor<typename value_type_::reduce_minmax_value_t, allocator_type_, max_rank_> try_max(
    tensor_view<value_type_, max_rank_> input, std::size_t axis, keep_dims_t keep_dims = collapse_dims_k) noexcept {
    return try_minmax<value_type_, max_rank_, allocator_type_>(input, axis, keep_dims).max_value;
}

#pragma endregion - Axis Reductions

} // namespace ashvardanian::numkong

#endif // NK_REDUCE_HPP
```
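For orientation, here is a minimal usage sketch of the two free-function kernels declared at the top of this header. It is not part of the package; it assumes the headers install under a `numkong/` include prefix and that the `f32_t` wrapper is explicitly constructible from `float` and convertible to `double` — plausible from this listing, but not verified here:

```cpp
#include <cstdio>
#include <vector>

#include "numkong/reduce.hpp" // hypothetical include path for this package's header

using namespace ashvardanian::numkong;

int main() {
    // Contiguous buffer of wrapped floats, so the stride is just sizeof(f32_t).
    std::vector<f32_t> data;
    for (std::size_t i = 0; i < 1024; ++i) data.push_back(f32_t(float(i % 7)));

    // Single-pass Σxᵢ and Σxᵢ²; leaving the accumulators at their defaults
    // (`reduce_moments_sum_t` / `reduce_moments_sumsq_t`) keeps SIMD dispatch eligible.
    f32_t::reduce_moments_sum_t total {};
    f32_t::reduce_moments_sumsq_t total_sq {};
    reduce_moments<f32_t>(data.data(), data.size(), sizeof(f32_t), &total, &total_sq);

    // Single-pass min/max with positions; the index pointers are nullable.
    f32_t::reduce_minmax_value_t lo {}, hi {};
    std::size_t lo_at = 0, hi_at = 0;
    reduce_minmax<f32_t>(data.data(), data.size(), sizeof(f32_t), &lo, &lo_at, &hi, &hi_at);

    std::printf("sum=%g sumsq=%g min@%zu max@%zu\n", double(total), double(total_sq), lo_at, hi_at);
    return 0;
}
```

The tensor-level helpers later in the file — `moments`, `minmax`, `sum`, `argmin`, `argmax`, and the allocating `try_*` axis reductions — bottom out in these same two kernels, falling back to the strided scalar loops whenever the accumulator types are customized or the layout rules out a contiguous rank-1 lane.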