numkong 7.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +201 -0
- package/README.md +495 -0
- package/binding.gyp +540 -0
- package/c/dispatch.h +512 -0
- package/c/dispatch_bf16.c +389 -0
- package/c/dispatch_bf16c.c +52 -0
- package/c/dispatch_e2m3.c +263 -0
- package/c/dispatch_e3m2.c +243 -0
- package/c/dispatch_e4m3.c +276 -0
- package/c/dispatch_e5m2.c +272 -0
- package/c/dispatch_f16.c +376 -0
- package/c/dispatch_f16c.c +58 -0
- package/c/dispatch_f32.c +378 -0
- package/c/dispatch_f32c.c +99 -0
- package/c/dispatch_f64.c +296 -0
- package/c/dispatch_f64c.c +98 -0
- package/c/dispatch_i16.c +96 -0
- package/c/dispatch_i32.c +89 -0
- package/c/dispatch_i4.c +150 -0
- package/c/dispatch_i64.c +86 -0
- package/c/dispatch_i8.c +289 -0
- package/c/dispatch_other.c +330 -0
- package/c/dispatch_u1.c +148 -0
- package/c/dispatch_u16.c +124 -0
- package/c/dispatch_u32.c +118 -0
- package/c/dispatch_u4.c +150 -0
- package/c/dispatch_u64.c +102 -0
- package/c/dispatch_u8.c +303 -0
- package/c/numkong.c +950 -0
- package/include/README.md +573 -0
- package/include/module.modulemap +129 -0
- package/include/numkong/attention/sapphireamx.h +1361 -0
- package/include/numkong/attention/sme.h +2066 -0
- package/include/numkong/attention.h +49 -0
- package/include/numkong/capabilities.h +748 -0
- package/include/numkong/cast/README.md +262 -0
- package/include/numkong/cast/haswell.h +975 -0
- package/include/numkong/cast/icelake.h +470 -0
- package/include/numkong/cast/neon.h +1192 -0
- package/include/numkong/cast/rvv.h +1021 -0
- package/include/numkong/cast/sapphire.h +262 -0
- package/include/numkong/cast/serial.h +2262 -0
- package/include/numkong/cast/skylake.h +856 -0
- package/include/numkong/cast/v128relaxed.h +180 -0
- package/include/numkong/cast.h +230 -0
- package/include/numkong/curved/README.md +223 -0
- package/include/numkong/curved/genoa.h +182 -0
- package/include/numkong/curved/haswell.h +276 -0
- package/include/numkong/curved/neon.h +205 -0
- package/include/numkong/curved/neonbfdot.h +212 -0
- package/include/numkong/curved/neonhalf.h +212 -0
- package/include/numkong/curved/rvv.h +305 -0
- package/include/numkong/curved/serial.h +207 -0
- package/include/numkong/curved/skylake.h +457 -0
- package/include/numkong/curved/smef64.h +506 -0
- package/include/numkong/curved.h +517 -0
- package/include/numkong/curved.hpp +144 -0
- package/include/numkong/dot/README.md +425 -0
- package/include/numkong/dot/alder.h +563 -0
- package/include/numkong/dot/genoa.h +315 -0
- package/include/numkong/dot/haswell.h +1688 -0
- package/include/numkong/dot/icelake.h +883 -0
- package/include/numkong/dot/neon.h +818 -0
- package/include/numkong/dot/neonbfdot.h +244 -0
- package/include/numkong/dot/neonfhm.h +360 -0
- package/include/numkong/dot/neonhalf.h +198 -0
- package/include/numkong/dot/neonsdot.h +508 -0
- package/include/numkong/dot/rvv.h +714 -0
- package/include/numkong/dot/rvvbb.h +72 -0
- package/include/numkong/dot/rvvbf16.h +123 -0
- package/include/numkong/dot/rvvhalf.h +129 -0
- package/include/numkong/dot/sapphire.h +141 -0
- package/include/numkong/dot/serial.h +838 -0
- package/include/numkong/dot/sierra.h +405 -0
- package/include/numkong/dot/skylake.h +1084 -0
- package/include/numkong/dot/sve.h +379 -0
- package/include/numkong/dot/svebfdot.h +74 -0
- package/include/numkong/dot/svehalf.h +123 -0
- package/include/numkong/dot/v128relaxed.h +1258 -0
- package/include/numkong/dot.h +1070 -0
- package/include/numkong/dot.hpp +94 -0
- package/include/numkong/dots/README.md +496 -0
- package/include/numkong/dots/alder.h +114 -0
- package/include/numkong/dots/genoa.h +94 -0
- package/include/numkong/dots/haswell.h +295 -0
- package/include/numkong/dots/icelake.h +171 -0
- package/include/numkong/dots/neon.h +120 -0
- package/include/numkong/dots/neonbfdot.h +58 -0
- package/include/numkong/dots/neonfhm.h +94 -0
- package/include/numkong/dots/neonhalf.h +57 -0
- package/include/numkong/dots/neonsdot.h +108 -0
- package/include/numkong/dots/rvv.h +2486 -0
- package/include/numkong/dots/sapphireamx.h +3973 -0
- package/include/numkong/dots/serial.h +2844 -0
- package/include/numkong/dots/sierra.h +97 -0
- package/include/numkong/dots/skylake.h +196 -0
- package/include/numkong/dots/sme.h +5372 -0
- package/include/numkong/dots/smebi32.h +461 -0
- package/include/numkong/dots/smef64.h +1318 -0
- package/include/numkong/dots/smehalf.h +47 -0
- package/include/numkong/dots/v128relaxed.h +294 -0
- package/include/numkong/dots.h +2804 -0
- package/include/numkong/dots.hpp +639 -0
- package/include/numkong/each/README.md +469 -0
- package/include/numkong/each/haswell.h +1658 -0
- package/include/numkong/each/icelake.h +272 -0
- package/include/numkong/each/neon.h +1104 -0
- package/include/numkong/each/neonbfdot.h +212 -0
- package/include/numkong/each/neonhalf.h +410 -0
- package/include/numkong/each/rvv.h +1121 -0
- package/include/numkong/each/sapphire.h +477 -0
- package/include/numkong/each/serial.h +260 -0
- package/include/numkong/each/skylake.h +1562 -0
- package/include/numkong/each.h +2146 -0
- package/include/numkong/each.hpp +434 -0
- package/include/numkong/geospatial/README.md +147 -0
- package/include/numkong/geospatial/haswell.h +593 -0
- package/include/numkong/geospatial/neon.h +571 -0
- package/include/numkong/geospatial/rvv.h +701 -0
- package/include/numkong/geospatial/serial.h +309 -0
- package/include/numkong/geospatial/skylake.h +577 -0
- package/include/numkong/geospatial/v128relaxed.h +613 -0
- package/include/numkong/geospatial.h +453 -0
- package/include/numkong/geospatial.hpp +235 -0
- package/include/numkong/matrix.hpp +336 -0
- package/include/numkong/maxsim/README.md +187 -0
- package/include/numkong/maxsim/alder.h +511 -0
- package/include/numkong/maxsim/genoa.h +115 -0
- package/include/numkong/maxsim/haswell.h +553 -0
- package/include/numkong/maxsim/icelake.h +480 -0
- package/include/numkong/maxsim/neonsdot.h +394 -0
- package/include/numkong/maxsim/sapphireamx.h +877 -0
- package/include/numkong/maxsim/serial.h +490 -0
- package/include/numkong/maxsim/sme.h +929 -0
- package/include/numkong/maxsim/v128relaxed.h +280 -0
- package/include/numkong/maxsim.h +571 -0
- package/include/numkong/maxsim.hpp +133 -0
- package/include/numkong/mesh/README.md +227 -0
- package/include/numkong/mesh/haswell.h +2235 -0
- package/include/numkong/mesh/neon.h +1329 -0
- package/include/numkong/mesh/neonbfdot.h +842 -0
- package/include/numkong/mesh/neonhalf.h +616 -0
- package/include/numkong/mesh/rvv.h +916 -0
- package/include/numkong/mesh/serial.h +742 -0
- package/include/numkong/mesh/skylake.h +1135 -0
- package/include/numkong/mesh/v128relaxed.h +1052 -0
- package/include/numkong/mesh.h +652 -0
- package/include/numkong/mesh.hpp +762 -0
- package/include/numkong/numkong.h +78 -0
- package/include/numkong/numkong.hpp +57 -0
- package/include/numkong/probability/README.md +173 -0
- package/include/numkong/probability/haswell.h +267 -0
- package/include/numkong/probability/neon.h +225 -0
- package/include/numkong/probability/rvv.h +409 -0
- package/include/numkong/probability/serial.h +169 -0
- package/include/numkong/probability/skylake.h +324 -0
- package/include/numkong/probability.h +383 -0
- package/include/numkong/probability.hpp +120 -0
- package/include/numkong/random.h +50 -0
- package/include/numkong/random.hpp +285 -0
- package/include/numkong/reduce/README.md +547 -0
- package/include/numkong/reduce/alder.h +632 -0
- package/include/numkong/reduce/genoa.h +201 -0
- package/include/numkong/reduce/haswell.h +3783 -0
- package/include/numkong/reduce/icelake.h +549 -0
- package/include/numkong/reduce/neon.h +3841 -0
- package/include/numkong/reduce/neonbfdot.h +353 -0
- package/include/numkong/reduce/neonfhm.h +665 -0
- package/include/numkong/reduce/neonhalf.h +157 -0
- package/include/numkong/reduce/neonsdot.h +357 -0
- package/include/numkong/reduce/rvv.h +3407 -0
- package/include/numkong/reduce/serial.h +757 -0
- package/include/numkong/reduce/sierra.h +338 -0
- package/include/numkong/reduce/skylake.h +3792 -0
- package/include/numkong/reduce/v128relaxed.h +2302 -0
- package/include/numkong/reduce.h +1597 -0
- package/include/numkong/reduce.hpp +633 -0
- package/include/numkong/scalar/README.md +89 -0
- package/include/numkong/scalar/haswell.h +113 -0
- package/include/numkong/scalar/neon.h +122 -0
- package/include/numkong/scalar/neonhalf.h +70 -0
- package/include/numkong/scalar/rvv.h +211 -0
- package/include/numkong/scalar/sapphire.h +63 -0
- package/include/numkong/scalar/serial.h +332 -0
- package/include/numkong/scalar/v128relaxed.h +56 -0
- package/include/numkong/scalar.h +683 -0
- package/include/numkong/set/README.md +179 -0
- package/include/numkong/set/haswell.h +334 -0
- package/include/numkong/set/icelake.h +485 -0
- package/include/numkong/set/neon.h +364 -0
- package/include/numkong/set/rvv.h +226 -0
- package/include/numkong/set/rvvbb.h +117 -0
- package/include/numkong/set/serial.h +174 -0
- package/include/numkong/set/sve.h +185 -0
- package/include/numkong/set/v128relaxed.h +240 -0
- package/include/numkong/set.h +457 -0
- package/include/numkong/set.hpp +114 -0
- package/include/numkong/sets/README.md +149 -0
- package/include/numkong/sets/haswell.h +63 -0
- package/include/numkong/sets/icelake.h +66 -0
- package/include/numkong/sets/neon.h +61 -0
- package/include/numkong/sets/serial.h +43 -0
- package/include/numkong/sets/smebi32.h +1099 -0
- package/include/numkong/sets/v128relaxed.h +58 -0
- package/include/numkong/sets.h +339 -0
- package/include/numkong/sparse/README.md +156 -0
- package/include/numkong/sparse/icelake.h +463 -0
- package/include/numkong/sparse/neon.h +288 -0
- package/include/numkong/sparse/serial.h +117 -0
- package/include/numkong/sparse/sve2.h +507 -0
- package/include/numkong/sparse/turin.h +322 -0
- package/include/numkong/sparse.h +363 -0
- package/include/numkong/sparse.hpp +113 -0
- package/include/numkong/spatial/README.md +435 -0
- package/include/numkong/spatial/alder.h +607 -0
- package/include/numkong/spatial/genoa.h +290 -0
- package/include/numkong/spatial/haswell.h +960 -0
- package/include/numkong/spatial/icelake.h +586 -0
- package/include/numkong/spatial/neon.h +773 -0
- package/include/numkong/spatial/neonbfdot.h +165 -0
- package/include/numkong/spatial/neonhalf.h +118 -0
- package/include/numkong/spatial/neonsdot.h +261 -0
- package/include/numkong/spatial/rvv.h +984 -0
- package/include/numkong/spatial/rvvbf16.h +123 -0
- package/include/numkong/spatial/rvvhalf.h +117 -0
- package/include/numkong/spatial/sapphire.h +343 -0
- package/include/numkong/spatial/serial.h +346 -0
- package/include/numkong/spatial/sierra.h +323 -0
- package/include/numkong/spatial/skylake.h +606 -0
- package/include/numkong/spatial/sve.h +224 -0
- package/include/numkong/spatial/svebfdot.h +122 -0
- package/include/numkong/spatial/svehalf.h +109 -0
- package/include/numkong/spatial/v128relaxed.h +717 -0
- package/include/numkong/spatial.h +1425 -0
- package/include/numkong/spatial.hpp +183 -0
- package/include/numkong/spatials/README.md +580 -0
- package/include/numkong/spatials/alder.h +94 -0
- package/include/numkong/spatials/genoa.h +94 -0
- package/include/numkong/spatials/haswell.h +219 -0
- package/include/numkong/spatials/icelake.h +113 -0
- package/include/numkong/spatials/neon.h +109 -0
- package/include/numkong/spatials/neonbfdot.h +60 -0
- package/include/numkong/spatials/neonfhm.h +92 -0
- package/include/numkong/spatials/neonhalf.h +58 -0
- package/include/numkong/spatials/neonsdot.h +109 -0
- package/include/numkong/spatials/rvv.h +1960 -0
- package/include/numkong/spatials/sapphireamx.h +1149 -0
- package/include/numkong/spatials/serial.h +226 -0
- package/include/numkong/spatials/sierra.h +96 -0
- package/include/numkong/spatials/skylake.h +184 -0
- package/include/numkong/spatials/sme.h +1901 -0
- package/include/numkong/spatials/smef64.h +465 -0
- package/include/numkong/spatials/v128relaxed.h +240 -0
- package/include/numkong/spatials.h +3021 -0
- package/include/numkong/spatials.hpp +508 -0
- package/include/numkong/tensor.hpp +1592 -0
- package/include/numkong/trigonometry/README.md +184 -0
- package/include/numkong/trigonometry/haswell.h +652 -0
- package/include/numkong/trigonometry/neon.h +639 -0
- package/include/numkong/trigonometry/rvv.h +699 -0
- package/include/numkong/trigonometry/serial.h +703 -0
- package/include/numkong/trigonometry/skylake.h +721 -0
- package/include/numkong/trigonometry/v128relaxed.h +666 -0
- package/include/numkong/trigonometry.h +467 -0
- package/include/numkong/trigonometry.hpp +166 -0
- package/include/numkong/types.h +1384 -0
- package/include/numkong/types.hpp +5603 -0
- package/include/numkong/vector.hpp +698 -0
- package/javascript/README.md +246 -0
- package/javascript/dist/cjs/numkong-wasm.d.ts +166 -0
- package/javascript/dist/cjs/numkong-wasm.js +617 -0
- package/javascript/dist/cjs/numkong.d.ts +343 -0
- package/javascript/dist/cjs/numkong.js +523 -0
- package/javascript/dist/cjs/package.json +3 -0
- package/javascript/dist/cjs/types.d.ts +284 -0
- package/javascript/dist/cjs/types.js +653 -0
- package/javascript/dist/esm/numkong-wasm.d.ts +166 -0
- package/javascript/dist/esm/numkong-wasm.js +595 -0
- package/javascript/dist/esm/numkong.d.ts +343 -0
- package/javascript/dist/esm/numkong.js +452 -0
- package/javascript/dist/esm/package.json +3 -0
- package/javascript/dist/esm/types.d.ts +284 -0
- package/javascript/dist/esm/types.js +630 -0
- package/javascript/dist-package-cjs.json +3 -0
- package/javascript/dist-package-esm.json +3 -0
- package/javascript/node-gyp-build.d.ts +1 -0
- package/javascript/numkong-wasm.ts +756 -0
- package/javascript/numkong.c +689 -0
- package/javascript/numkong.ts +575 -0
- package/javascript/tsconfig-base.json +39 -0
- package/javascript/tsconfig-cjs.json +8 -0
- package/javascript/tsconfig-esm.json +8 -0
- package/javascript/types.ts +674 -0
- package/package.json +87 -0
|
@@ -0,0 +1,1592 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @brief NumKong Tensor types and tensor-level operations for C++23 and newer.
|
|
3
|
+
* @file include/numkong/tensor.hpp
|
|
4
|
+
* @author Ash Vardanian
|
|
5
|
+
* @date March 2026
|
|
6
|
+
*
|
|
7
|
+
* Provides owning and non-owning N-dimensional tensor types:
|
|
8
|
+
*
|
|
9
|
+
* - `nk::tensor<T, A, max_rank>`: Owning, non-resizable
|
|
10
|
+
* - `nk::tensor_view<T, max_rank>`: Non-owning, const
|
|
11
|
+
* - `nk::tensor_span<T, max_rank>`: Non-owning, mutable
|
|
12
|
+
* - `nk::matrix` / `nk::matrix_view` / `nk::matrix_span`: 2D aliases
|
|
13
|
+
*
|
|
14
|
+
* Tensor-level free functions:
|
|
15
|
+
* - Non-allocating scalar results: `sum(view)`, `min(view)`, `max(view)`, etc.
|
|
16
|
+
* - Allocating ops: `try_add(a, b)`, `try_sum(view, axis)`, etc.
|
|
17
|
+
* - In-place into pre-allocated output: `add_into(a, b, out)`, etc.
|
|
18
|
+
*
|
|
19
|
+
* Features:
|
|
20
|
+
* - Signed strides (ptrdiff_t) for reversed/transposed views
|
|
21
|
+
* - Signed indexing (negative = from end)
|
|
22
|
+
* - C++23 variadic `operator[]` for flat access, exact access, and trailing `slice`
|
|
23
|
+
* - Axis iteration (rows_views(), rows_spans(), axis_iterator)
|
|
24
|
+
* - Conversion to vector_view/vector_span for rank-1 tensors
|
|
25
|
+
*/
|
|
26
|
+
|
|
27
|
+
#ifndef NK_TENSOR_HPP
|
|
28
|
+
#define NK_TENSOR_HPP
|
|
29
|
+
|
|
30
|
+
#include <array> // `std::array`
|
|
31
|
+
#include <cstdio> // `std::fprintf`, `stderr`
|
|
32
|
+
#include <cstdlib> // `std::abort`
|
|
33
|
+
#include <cstring> // `std::memset`
|
|
34
|
+
#include <span> // `std::span`
|
|
35
|
+
#include <tuple> // `std::tuple_element_t`
|
|
36
|
+
#include <type_traits>
|
|
37
|
+
|
|
38
|
+
#include "vector.hpp" // `aligned_allocator`
|
|
39
|
+
|
|
40
|
+
namespace ashvardanian::numkong {
|
|
41
|
+
|
|
42
|
+
// Forward declarations of the three tensor flavours, so helpers below can mention them before
// their definitions. Default template arguments are supplied on the definitions, not here.
template <typename value_type_, std::size_t max_rank_> struct tensor_view;                        // read-only, non-owning
template <typename value_type_, std::size_t max_rank_> struct tensor_span;                        // mutable, non-owning
template <typename value_type_, typename allocator_type_, std::size_t max_rank_> struct tensor;   // owning
|
|
48
|
+
|
|
49
|
+
/** @brief Tag type of the trailing `slice` marker accepted by the tensor `operator[]` overloads. */
struct tensor_slice_t {
    // Intentionally empty: only the type carries meaning.
};

/** @brief The marker value itself, used as the last subscript argument, e.g. `view[2, slice]`. */
inline constexpr tensor_slice_t slice = tensor_slice_t {};
|
|
51
|
+
|
|
52
|
+
template <typename... arg_types_>
|
|
53
|
+
struct trailing_tensor_slice_args_ : std::false_type {};
|
|
54
|
+
|
|
55
|
+
template <>
|
|
56
|
+
struct trailing_tensor_slice_args_<tensor_slice_t> : std::true_type {};
|
|
57
|
+
|
|
58
|
+
template <std::integral index_type_, typename... rest_types_>
|
|
59
|
+
struct trailing_tensor_slice_args_<index_type_, rest_types_...> : trailing_tensor_slice_args_<rest_types_...> {};
|
|
60
|
+
|
|
61
|
+
template <typename... rest_types_>
|
|
62
|
+
struct trailing_tensor_slice_args_<all_t, rest_types_...> : trailing_tensor_slice_args_<rest_types_...> {};
|
|
63
|
+
|
|
64
|
+
template <typename... rest_types_>
|
|
65
|
+
struct trailing_tensor_slice_args_<range, rest_types_...> : trailing_tensor_slice_args_<rest_types_...> {};
|
|
66
|
+
|
|
67
|
+
template <typename... arg_types_>
|
|
68
|
+
inline constexpr bool trailing_tensor_slice_args_v =
|
|
69
|
+
trailing_tensor_slice_args_<std::remove_cvref_t<arg_types_>...>::value;
|
|
70
|
+
|
|
71
|
+
/*  Debug-only assertion used by the tensor types.
 *  Without `NDEBUG`, a failed check prints the expression and source location to `stderr` and aborts.
 *  With `NDEBUG`, `nk_assert_` expands to a no-op and its argument is NOT evaluated. */
#if !defined(NDEBUG)
extern "C" [[noreturn]] inline void nk_assert_failure(char const *expr, char const *file, int line) noexcept {
    std::fprintf(stderr, "NumKong assertion failed: %s (%s:%d)\n", expr, file, line);
    std::abort();
}
#define nk_assert_(expr) ((expr) ? static_cast<void>(0) : nk_assert_failure(#expr, __FILE__, __LINE__))
#else
#define nk_assert_(expr) (static_cast<void>(0))
#endif
|
|
80
|
+
|
|
81
|
+
#pragma region - Shape Storage
|
|
82
|
+
|
|
83
|
+
/**
 * @brief Inline, fixed-capacity shape descriptor: extents plus signed byte strides.
 * @tparam max_rank_ Upper bound on the number of dimensions that can be stored.
 *
 * Only the first `rank` entries of `extents` and `strides` are meaningful; unused slots stay zero.
 * On 64-bit targets `max_rank_ = 2` (matrix) occupies 40 bytes and `max_rank_ = 64` occupies 1032 bytes.
 */
template <std::size_t max_rank_>
struct shape_storage_ {
    std::size_t extents[max_rank_] = {};
    std::ptrdiff_t strides[max_rank_] = {};
    std::size_t rank = 0;

    /** @brief Product of the first `rank` extents; `1` for a rank-0 (scalar) shape. */
    constexpr std::size_t numel() const noexcept {
        std::size_t product = 1;
        for (std::size_t axis = 0; axis != rank; ++axis) product *= extents[axis];
        return product;
    }

    /**
     * @brief Dot product of coordinates with the byte strides.
     * @param coords Pointer to at least `rank` coordinates.
     * @return Signed byte offset relative to the base pointer.
     */
    constexpr std::ptrdiff_t linearize(std::size_t const *coords) const noexcept {
        std::ptrdiff_t byte_offset = 0;
        for (std::size_t axis = 0; axis != rank; ++axis)
            byte_offset += strides[axis] * static_cast<std::ptrdiff_t>(coords[axis]);
        return byte_offset;
    }

    /**
     * @brief Dense row-major layout: the innermost axis has stride `elem_bytes`, each outer axis
     *        spans the full extent of the axes inside it.
     * @param exts Pointer to `rank_val` extents.
     * @param rank_val Number of dimensions (must not exceed `max_rank_`).
     * @param elem_bytes Size of one element in bytes.
     */
    static constexpr shape_storage_ contiguous(std::size_t const *exts, std::size_t rank_val,
                                               std::size_t elem_bytes) noexcept {
        shape_storage_ result {};
        result.rank = rank_val;
        std::ptrdiff_t running_stride = static_cast<std::ptrdiff_t>(elem_bytes);
        std::size_t axis = rank_val;
        while (axis-- > 0) {
            result.extents[axis] = exts[axis];
            result.strides[axis] = running_stride;
            running_stride *= static_cast<std::ptrdiff_t>(exts[axis]);
        }
        return result;
    }
};
|
|
125
|
+
|
|
126
|
+
template <typename value_type_>
|
|
127
|
+
constexpr std::size_t dims_to_values_(std::size_t dims) noexcept {
|
|
128
|
+
return divide_round_up(dims, static_cast<std::size_t>(dimensions_per_value<value_type_>()));
|
|
129
|
+
}
|
|
130
|
+
|
|
131
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
132
|
+
constexpr std::size_t storage_values_for_shape_(shape_storage_<max_rank_> const &shape) noexcept {
|
|
133
|
+
if (shape.rank == 0) return 1;
|
|
134
|
+
std::size_t values = 1;
|
|
135
|
+
for (std::size_t i = 0; i < shape.rank; ++i) {
|
|
136
|
+
bool const is_last = i + 1 == shape.rank;
|
|
137
|
+
values *= is_last ? dims_to_values_<value_type_>(shape.extents[i]) : shape.extents[i];
|
|
138
|
+
}
|
|
139
|
+
return values;
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
143
|
+
constexpr shape_storage_<max_rank_> make_contiguous_shape_(std::size_t const *exts, std::size_t rank_val) noexcept {
|
|
144
|
+
shape_storage_<max_rank_> s;
|
|
145
|
+
s.rank = rank_val;
|
|
146
|
+
auto stride = static_cast<std::ptrdiff_t>(sizeof(value_type_));
|
|
147
|
+
for (std::size_t i = rank_val; i > 0; --i) {
|
|
148
|
+
s.extents[i - 1] = exts[i - 1];
|
|
149
|
+
s.strides[i - 1] = stride;
|
|
150
|
+
auto const extent_factor = i == rank_val ? dims_to_values_<value_type_>(exts[i - 1])
|
|
151
|
+
: static_cast<std::size_t>(exts[i - 1]);
|
|
152
|
+
stride *= static_cast<std::ptrdiff_t>(extent_factor);
|
|
153
|
+
}
|
|
154
|
+
return s;
|
|
155
|
+
}
|
|
156
|
+
|
|
157
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
158
|
+
constexpr bool is_tensor_contiguous_(shape_storage_<max_rank_> const &shape) noexcept {
|
|
159
|
+
if (shape.rank == 0) return true;
|
|
160
|
+
auto expected = static_cast<std::ptrdiff_t>(sizeof(value_type_));
|
|
161
|
+
for (std::size_t i = shape.rank; i > 0; --i) {
|
|
162
|
+
if (shape.strides[i - 1] != expected) return false;
|
|
163
|
+
auto const extent_factor = i == shape.rank ? dims_to_values_<value_type_>(shape.extents[i - 1])
|
|
164
|
+
: static_cast<std::size_t>(shape.extents[i - 1]);
|
|
165
|
+
expected *= static_cast<std::ptrdiff_t>(extent_factor);
|
|
166
|
+
}
|
|
167
|
+
return true;
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
171
|
+
constexpr bool packed_tensor_layout_supported_(shape_storage_<max_rank_> const &shape) noexcept {
|
|
172
|
+
if constexpr (dimensions_per_value<value_type_>() == 1) return true;
|
|
173
|
+
else return is_tensor_contiguous_<value_type_>(shape);
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
template <typename value_type_, std::size_t max_rank_, std::size_t... indices_, typename... index_types_>
|
|
177
|
+
constexpr std::array<std::size_t, sizeof...(indices_)> resolve_tensor_indices_(shape_storage_<max_rank_> const &shape,
|
|
178
|
+
std::index_sequence<indices_...>,
|
|
179
|
+
index_types_... idxs) noexcept {
|
|
180
|
+
return {resolve_index_(idxs, shape.extents[indices_])...};
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
template <typename value_type_, std::size_t max_rank_, std::size_t extent_>
|
|
184
|
+
decltype(auto) tensor_lookup_resolved_(tensor_view<value_type_, max_rank_> input,
|
|
185
|
+
std::span<std::size_t const, extent_> coords) noexcept;
|
|
186
|
+
|
|
187
|
+
template <typename value_type_, std::size_t max_rank_, std::size_t extent_>
|
|
188
|
+
decltype(auto) tensor_lookup_resolved_(tensor_span<value_type_, max_rank_> input,
|
|
189
|
+
std::span<std::size_t const, extent_> coords) noexcept;
|
|
190
|
+
|
|
191
|
+
template <typename value_type_, std::size_t max_rank_, typename index_type_>
|
|
192
|
+
decltype(auto) tensor_flat_lookup_(tensor_view<value_type_, max_rank_> input, index_type_ idx) noexcept;
|
|
193
|
+
|
|
194
|
+
template <typename value_type_, std::size_t max_rank_, typename index_type_>
|
|
195
|
+
decltype(auto) tensor_flat_lookup_(tensor_span<value_type_, max_rank_> input, index_type_ idx) noexcept;
|
|
196
|
+
|
|
197
|
+
template <typename tensor_type_>
|
|
198
|
+
tensor_type_ tensor_slice_suffix_(tensor_type_ input, tensor_slice_t) noexcept;
|
|
199
|
+
|
|
200
|
+
template <typename tensor_type_, std::integral index_type_, typename... rest_types_>
|
|
201
|
+
tensor_type_ tensor_slice_suffix_(tensor_type_ input, index_type_ idx, rest_types_... rest) noexcept;
|
|
202
|
+
|
|
203
|
+
template <typename tensor_type_, typename... rest_types_>
|
|
204
|
+
tensor_type_ tensor_slice_suffix_(tensor_type_ input, all_t, rest_types_... rest) noexcept;
|
|
205
|
+
|
|
206
|
+
template <typename tensor_type_, typename... rest_types_>
|
|
207
|
+
tensor_type_ tensor_slice_suffix_(tensor_type_ input, range r, rest_types_... rest) noexcept;
|
|
208
|
+
|
|
209
|
+
#pragma endregion - Shape Storage
|
|
210
|
+
|
|
211
|
+
#pragma region - Tensor View
|
|
212
|
+
|
|
213
|
+
// Iterator over sub-views along the leading axis; defined after the view types.
template <class view_type_> class axis_iterator;
|
|
215
|
+
|
|
216
|
+
/**
|
|
217
|
+
* @brief Non-owning, immutable, N-dimensional view.
|
|
218
|
+
* @tparam value_type_ Element type.
|
|
219
|
+
* @tparam max_rank_ Maximum number of dimensions.
|
|
220
|
+
*/
|
|
221
|
+
template <typename value_type_, std::size_t max_rank_ = 8>
|
|
222
|
+
struct tensor_view {
|
|
223
|
+
using value_type = value_type_;
|
|
224
|
+
using size_type = std::size_t;
|
|
225
|
+
using difference_type = std::ptrdiff_t;
|
|
226
|
+
|
|
227
|
+
private:
|
|
228
|
+
char const *data_ = nullptr;
|
|
229
|
+
shape_storage_<max_rank_> shape_;
|
|
230
|
+
|
|
231
|
+
public:
|
|
232
|
+
constexpr tensor_view() noexcept = default;
|
|
233
|
+
|
|
234
|
+
constexpr tensor_view(char const *data, shape_storage_<max_rank_> const &shape) noexcept
|
|
235
|
+
: data_(data), shape_(shape) {}
|
|
236
|
+
|
|
237
|
+
/** @brief Convenience constructor for rank-2 views from typed pointer, rows, and cols. */
|
|
238
|
+
tensor_view(value_type const *data, size_type rows, size_type cols) noexcept
|
|
239
|
+
requires(max_rank_ >= 2)
|
|
240
|
+
: data_(reinterpret_cast<char const *>(data)) {
|
|
241
|
+
std::size_t extents[2] = {rows, cols};
|
|
242
|
+
shape_ = make_contiguous_shape_<value_type_, max_rank_>(extents, 2);
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
/** @brief Number of dimensions. */
|
|
246
|
+
constexpr size_type rank() const noexcept { return shape_.rank; }
|
|
247
|
+
|
|
248
|
+
/** @brief Extent along the i-th dimension. */
|
|
249
|
+
constexpr size_type extent(size_type i) const noexcept { return shape_.extents[i]; }
|
|
250
|
+
|
|
251
|
+
/** @brief Stride in bytes along the i-th dimension (signed). */
|
|
252
|
+
constexpr difference_type stride_bytes(size_type i) const noexcept { return shape_.strides[i]; }
|
|
253
|
+
|
|
254
|
+
/** @brief Total number of elements. */
|
|
255
|
+
constexpr size_type numel() const noexcept { return shape_.numel(); }
|
|
256
|
+
|
|
257
|
+
/** @brief True if empty. */
|
|
258
|
+
constexpr bool empty() const noexcept { return data_ == nullptr || shape_.numel() == 0; }
|
|
259
|
+
|
|
260
|
+
/** @brief Raw byte pointer. */
|
|
261
|
+
constexpr char const *byte_data() const noexcept { return data_; }
|
|
262
|
+
|
|
263
|
+
/** @brief Typed pointer (assumes data is contiguous from this pointer). */
|
|
264
|
+
constexpr value_type const *data() const noexcept { return reinterpret_cast<value_type const *>(data_); }
|
|
265
|
+
|
|
266
|
+
/** @brief Access the shape storage. */
|
|
267
|
+
constexpr shape_storage_<max_rank_> const &shape() const noexcept { return shape_; }
|
|
268
|
+
|
|
269
|
+
/** @brief Slice along the leading dimension. */
|
|
270
|
+
template <std::integral index_type_>
|
|
271
|
+
tensor_view<value_type_, max_rank_> slice_leading(index_type_ idx) const noexcept {
|
|
272
|
+
nk_assert_(shape_.rank >= 1);
|
|
273
|
+
if (shape_.rank == 0) return {};
|
|
274
|
+
auto i = resolve_index_(idx, shape_.extents[0]);
|
|
275
|
+
auto offset = static_cast<difference_type>(i) * shape_.strides[0];
|
|
276
|
+
shape_storage_<max_rank_> sub;
|
|
277
|
+
sub.rank = shape_.rank - 1;
|
|
278
|
+
for (size_type d = 0; d < sub.rank; ++d) {
|
|
279
|
+
sub.extents[d] = shape_.extents[d + 1];
|
|
280
|
+
sub.strides[d] = shape_.strides[d + 1];
|
|
281
|
+
}
|
|
282
|
+
return {data_ + offset, sub};
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
/** @brief Row access (alias for slice_leading). */
|
|
286
|
+
template <std::integral index_type_>
|
|
287
|
+
tensor_view<value_type_, max_rank_> row(index_type_ i) const noexcept {
|
|
288
|
+
return slice_leading(i);
|
|
289
|
+
}
|
|
290
|
+
|
|
291
|
+
/** @brief Rank-0 scalar access. */
|
|
292
|
+
template <std::integral index_type_>
|
|
293
|
+
decltype(auto) operator[](index_type_ idx) const noexcept {
|
|
294
|
+
return tensor_flat_lookup_(*this, idx);
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
/** @brief Exact multi-dimensional scalar lookup. */
|
|
298
|
+
template <std::integral... index_types_>
|
|
299
|
+
requires(sizeof...(index_types_) >= 2)
|
|
300
|
+
decltype(auto) operator[](index_types_... idxs) const noexcept {
|
|
301
|
+
nk_assert_(shape_.rank == sizeof...(index_types_));
|
|
302
|
+
auto coords = resolve_tensor_indices_<value_type_>(shape_, std::index_sequence_for<index_types_...> {},
|
|
303
|
+
idxs...);
|
|
304
|
+
return tensor_lookup_resolved_(*this, std::span<std::size_t const, sizeof...(index_types_)>(coords));
|
|
305
|
+
}
|
|
306
|
+
|
|
307
|
+
/** @brief Trailing `slice` returns the same view. */
|
|
308
|
+
constexpr tensor_view operator[](tensor_slice_t) const noexcept { return *this; }
|
|
309
|
+
|
|
310
|
+
/** @brief Prefix leading-axis slicing with a trailing `slice` marker. */
|
|
311
|
+
template <typename first_type_, typename second_type_, typename... rest_types_>
|
|
312
|
+
requires(trailing_tensor_slice_args_v<first_type_, second_type_, rest_types_...>)
|
|
313
|
+
tensor_view operator[](first_type_ first, second_type_ second, rest_types_... rest) const noexcept {
|
|
314
|
+
return tensor_slice_suffix_(*this, first, second, rest...);
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
/** @brief Rank-0 scalar access. */
|
|
318
|
+
decltype(auto) scalar() const noexcept {
|
|
319
|
+
nk_assert_(shape_.rank == 0);
|
|
320
|
+
nk_assert_(data_ != nullptr);
|
|
321
|
+
return *reinterpret_cast<value_type_ const *>(data_);
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
/** @brief Convert to vector_view (requires rank == 1). */
|
|
325
|
+
vector_view<value_type> as_vector() const noexcept {
|
|
326
|
+
nk_assert_(shape_.rank == 1);
|
|
327
|
+
if (shape_.rank != 1) return {};
|
|
328
|
+
return {data_, shape_.extents[0], shape_.strides[0]};
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
/** @brief Reinterpret as a 2D matrix view. Requires rank >= 2. */
|
|
332
|
+
tensor_view<value_type_, 2> as_matrix() const noexcept {
|
|
333
|
+
nk_assert_(shape_.rank >= 2);
|
|
334
|
+
if (shape_.rank < 2) return {};
|
|
335
|
+
shape_storage_<2> matrix_shape;
|
|
336
|
+
matrix_shape.rank = 2;
|
|
337
|
+
matrix_shape.extents[0] = shape_.extents[0];
|
|
338
|
+
matrix_shape.extents[1] = shape_.extents[1];
|
|
339
|
+
matrix_shape.strides[0] = shape_.strides[0];
|
|
340
|
+
matrix_shape.strides[1] = shape_.strides[1];
|
|
341
|
+
return {data_, matrix_shape};
|
|
342
|
+
}
|
|
343
|
+
|
|
344
|
+
/** @brief Check if the tensor is contiguous in memory. */
|
|
345
|
+
constexpr bool is_contiguous() const noexcept { return is_tensor_contiguous_<value_type>(shape_); }
|
|
346
|
+
|
|
347
|
+
/** @brief Transpose: reverse the order of all dimensions (swap extents and strides). */
|
|
348
|
+
constexpr tensor_view transpose() const noexcept {
|
|
349
|
+
if constexpr (dimensions_per_value<value_type>() > 1) {
|
|
350
|
+
if (shape_.rank >= 2) return {};
|
|
351
|
+
}
|
|
352
|
+
if (shape_.rank < 2) return *this;
|
|
353
|
+
auto transposed = shape_;
|
|
354
|
+
for (size_type i = 0; i < transposed.rank / 2; ++i) {
|
|
355
|
+
std::swap(transposed.extents[i], transposed.extents[transposed.rank - 1 - i]);
|
|
356
|
+
std::swap(transposed.strides[i], transposed.strides[transposed.rank - 1 - i]);
|
|
357
|
+
}
|
|
358
|
+
return {data_, transposed};
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
/** @brief Reshape to new extents (requires contiguous layout and matching element count).
|
|
362
|
+
* Returns an empty view if the tensor is not contiguous or element counts don't match. */
|
|
363
|
+
tensor_view reshape(std::initializer_list<size_type> new_extents) const noexcept {
|
|
364
|
+
auto new_rank = new_extents.size();
|
|
365
|
+
if (!is_contiguous() || new_rank > max_rank_ || new_rank == 0) return {};
|
|
366
|
+
auto new_shape = make_contiguous_shape_<value_type, max_rank_>(new_extents.begin(), new_rank);
|
|
367
|
+
if (storage_values_for_shape_<value_type>(new_shape) != storage_values_for_shape_<value_type>(shape_))
|
|
368
|
+
return {};
|
|
369
|
+
return {data_, new_shape};
|
|
370
|
+
}
|
|
371
|
+
|
|
372
|
+
/** @brief Range of sub-views along the leading dimension. */
|
|
373
|
+
struct rows_views_t {
|
|
374
|
+
tensor_view parent;
|
|
375
|
+
axis_iterator<tensor_view> begin() const noexcept { return {parent, 0}; }
|
|
376
|
+
axis_iterator<tensor_view> end() const noexcept { return {parent, parent.extent(0)}; }
|
|
377
|
+
};
|
|
378
|
+
|
|
379
|
+
rows_views_t rows() const noexcept { return {*this}; }
|
|
380
|
+
|
|
381
|
+
/** @brief Flatten to 1D view (requires contiguous layout). Returns empty view if not contiguous. */
|
|
382
|
+
tensor_view flatten() const noexcept { return reshape({numel()}); }
|
|
383
|
+
|
|
384
|
+
/** @brief Remove dimensions of size 1. */
|
|
385
|
+
tensor_view squeeze() const noexcept {
|
|
386
|
+
auto result = shape_;
|
|
387
|
+
size_type new_rank = 0;
|
|
388
|
+
for (size_type i = 0; i < shape_.rank; ++i) {
|
|
389
|
+
if (shape_.extents[i] != 1) {
|
|
390
|
+
result.extents[new_rank] = shape_.extents[i];
|
|
391
|
+
result.strides[new_rank] = shape_.strides[i];
|
|
392
|
+
++new_rank;
|
|
393
|
+
}
|
|
394
|
+
}
|
|
395
|
+
if (new_rank == 0) {
|
|
396
|
+
new_rank = 1;
|
|
397
|
+
result.extents[0] = 1;
|
|
398
|
+
result.strides[0] = static_cast<difference_type>(sizeof(value_type));
|
|
399
|
+
}
|
|
400
|
+
result.rank = new_rank;
|
|
401
|
+
return {data_, result};
|
|
402
|
+
}
|
|
403
|
+
};
|
|
404
|
+
|
|
405
|
+
#pragma endregion - Tensor View
|
|
406
|
+
|
|
407
|
+
#pragma region - Tensor Span
|
|
408
|
+
|
|
409
|
+
/**
|
|
410
|
+
* @brief Non-owning, mutable, N-dimensional view.
|
|
411
|
+
* @tparam value_type_ Element type.
|
|
412
|
+
* @tparam max_rank_ Maximum number of dimensions.
|
|
413
|
+
*/
|
|
414
|
+
template <typename value_type_, std::size_t max_rank_ = 8>
|
|
415
|
+
struct tensor_span {
|
|
416
|
+
using value_type = value_type_;
|
|
417
|
+
using size_type = std::size_t;
|
|
418
|
+
using difference_type = std::ptrdiff_t;
|
|
419
|
+
|
|
420
|
+
private:
|
|
421
|
+
char *data_ = nullptr;
|
|
422
|
+
shape_storage_<max_rank_> shape_;
|
|
423
|
+
|
|
424
|
+
public:
|
|
425
|
+
constexpr tensor_span() noexcept = default;
|
|
426
|
+
|
|
427
|
+
constexpr tensor_span(char *data, shape_storage_<max_rank_> const &shape) noexcept : data_(data), shape_(shape) {}
|
|
428
|
+
|
|
429
|
+
/** @brief Convenience constructor for rank-2 spans from typed pointer, rows, and cols. */
|
|
430
|
+
tensor_span(value_type *data, size_type rows, size_type cols) noexcept
|
|
431
|
+
requires(max_rank_ >= 2)
|
|
432
|
+
: data_(reinterpret_cast<char *>(data)) {
|
|
433
|
+
std::size_t extents[2] = {rows, cols};
|
|
434
|
+
shape_ = make_contiguous_shape_<value_type_, max_rank_>(extents, 2);
|
|
435
|
+
}
|
|
436
|
+
|
|
437
|
+
/** @brief Number of dimensions. */
|
|
438
|
+
constexpr size_type rank() const noexcept { return shape_.rank; }
|
|
439
|
+
/** @brief Extent along the i-th dimension. */
|
|
440
|
+
constexpr size_type extent(size_type i) const noexcept { return shape_.extents[i]; }
|
|
441
|
+
/** @brief Stride in bytes along the i-th dimension (signed). */
|
|
442
|
+
constexpr difference_type stride_bytes(size_type i) const noexcept { return shape_.strides[i]; }
|
|
443
|
+
/** @brief Total number of elements. */
|
|
444
|
+
constexpr size_type numel() const noexcept { return shape_.numel(); }
|
|
445
|
+
/** @brief True if empty. */
|
|
446
|
+
constexpr bool empty() const noexcept { return data_ == nullptr || shape_.numel() == 0; }
|
|
447
|
+
|
|
448
|
+
/** @brief Raw byte pointer. */
|
|
449
|
+
constexpr char *byte_data() noexcept { return data_; }
|
|
450
|
+
constexpr char const *byte_data() const noexcept { return data_; }
|
|
451
|
+
/** @brief Typed pointer (assumes data is contiguous from this pointer). */
|
|
452
|
+
constexpr value_type *data() noexcept { return reinterpret_cast<value_type *>(data_); }
|
|
453
|
+
constexpr value_type const *data() const noexcept { return reinterpret_cast<value_type const *>(data_); }
|
|
454
|
+
/** @brief Access the shape storage. */
|
|
455
|
+
constexpr shape_storage_<max_rank_> const &shape() const noexcept { return shape_; }
|
|
456
|
+
|
|
457
|
+
/** @brief Implicit conversion to const view. */
|
|
458
|
+
constexpr operator tensor_view<value_type_, max_rank_>() const noexcept {
|
|
459
|
+
return {static_cast<char const *>(data_), shape_};
|
|
460
|
+
}
|
|
461
|
+
|
|
462
|
+
/** @brief Slice along leading dimension. */
|
|
463
|
+
template <std::integral index_type_>
|
|
464
|
+
tensor_span slice_leading(index_type_ idx) const noexcept {
|
|
465
|
+
nk_assert_(shape_.rank >= 1);
|
|
466
|
+
if (shape_.rank == 0) return {};
|
|
467
|
+
auto i = resolve_index_(idx, shape_.extents[0]);
|
|
468
|
+
auto offset = static_cast<difference_type>(i) * shape_.strides[0];
|
|
469
|
+
shape_storage_<max_rank_> sub;
|
|
470
|
+
sub.rank = shape_.rank - 1;
|
|
471
|
+
for (size_type d = 0; d < sub.rank; ++d) {
|
|
472
|
+
sub.extents[d] = shape_.extents[d + 1];
|
|
473
|
+
sub.strides[d] = shape_.strides[d + 1];
|
|
474
|
+
}
|
|
475
|
+
return {data_ + offset, sub};
|
|
476
|
+
}
|
|
477
|
+
|
|
478
|
+
/** @brief Mutable row access (alias for slice_leading). */
|
|
479
|
+
template <std::integral index_type_>
|
|
480
|
+
tensor_span row(index_type_ i) const noexcept {
|
|
481
|
+
return slice_leading(i);
|
|
482
|
+
}
|
|
483
|
+
|
|
484
|
+
/** @brief Flat logical scalar access. */
|
|
485
|
+
template <std::integral index_type_>
|
|
486
|
+
decltype(auto) operator[](index_type_ idx) noexcept {
|
|
487
|
+
return tensor_flat_lookup_(*this, idx);
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
/** @brief Const flat logical scalar access. */
|
|
491
|
+
template <std::integral index_type_>
|
|
492
|
+
decltype(auto) operator[](index_type_ idx) const noexcept {
|
|
493
|
+
return tensor_flat_lookup_(static_cast<tensor_view<value_type_, max_rank_>>(*this), idx);
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
/** @brief Exact multi-dimensional scalar lookup. */
|
|
497
|
+
template <std::integral... index_types_>
|
|
498
|
+
requires(sizeof...(index_types_) >= 2)
|
|
499
|
+
decltype(auto) operator[](index_types_... idxs) noexcept {
|
|
500
|
+
nk_assert_(shape_.rank == sizeof...(index_types_));
|
|
501
|
+
auto coords = resolve_tensor_indices_<value_type_>(shape_, std::index_sequence_for<index_types_...> {},
|
|
502
|
+
idxs...);
|
|
503
|
+
return tensor_lookup_resolved_(*this, std::span<std::size_t const, sizeof...(index_types_)>(coords));
|
|
504
|
+
}
|
|
505
|
+
|
|
506
|
+
/** @brief Const full-coordinate lookup. */
|
|
507
|
+
template <std::integral... index_types_>
|
|
508
|
+
requires(sizeof...(index_types_) >= 2)
|
|
509
|
+
decltype(auto) operator[](index_types_... idxs) const noexcept {
|
|
510
|
+
return static_cast<tensor_view<value_type_, max_rank_>>(*this)[idxs...];
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
/** @brief Trailing `slice` returns the same span. */
|
|
514
|
+
constexpr tensor_span operator[](tensor_slice_t) noexcept { return *this; }
|
|
515
|
+
constexpr tensor_view<value_type_, max_rank_> operator[](tensor_slice_t) const noexcept {
|
|
516
|
+
return static_cast<tensor_view<value_type_, max_rank_>>(*this);
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
/** @brief Prefix leading-axis slicing with a trailing `slice` marker. */
|
|
520
|
+
template <typename first_type_, typename second_type_, typename... rest_types_>
|
|
521
|
+
requires(trailing_tensor_slice_args_v<first_type_, second_type_, rest_types_...>)
|
|
522
|
+
tensor_span operator[](first_type_ first, second_type_ second, rest_types_... rest) noexcept {
|
|
523
|
+
return tensor_slice_suffix_(*this, first, second, rest...);
|
|
524
|
+
}
|
|
525
|
+
|
|
526
|
+
/** @brief Const prefix leading-axis slicing with a trailing `slice` marker. */
|
|
527
|
+
template <typename first_type_, typename second_type_, typename... rest_types_>
|
|
528
|
+
requires(trailing_tensor_slice_args_v<first_type_, second_type_, rest_types_...>)
|
|
529
|
+
tensor_view<value_type_, max_rank_> operator[](first_type_ first, second_type_ second,
|
|
530
|
+
rest_types_... rest) const noexcept {
|
|
531
|
+
return tensor_slice_suffix_(static_cast<tensor_view<value_type_, max_rank_>>(*this), first, second, rest...);
|
|
532
|
+
}
|
|
533
|
+
|
|
534
|
+
/** @brief Rank-0 mutable scalar access. */
|
|
535
|
+
decltype(auto) scalar_ref() noexcept {
|
|
536
|
+
nk_assert_(shape_.rank == 0);
|
|
537
|
+
nk_assert_(data_ != nullptr);
|
|
538
|
+
return *reinterpret_cast<value_type_ *>(data_);
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
/** @brief Rank-0 const scalar access. */
|
|
542
|
+
decltype(auto) scalar() const noexcept { return static_cast<tensor_view<value_type_, max_rank_>>(*this).scalar(); }
|
|
543
|
+
|
|
544
|
+
/** @brief Convert to vector_span (requires rank == 1). */
|
|
545
|
+
vector_span<value_type> as_vector() noexcept {
|
|
546
|
+
nk_assert_(shape_.rank == 1);
|
|
547
|
+
if (shape_.rank != 1) return {};
|
|
548
|
+
return {data_, shape_.extents[0], shape_.strides[0]};
|
|
549
|
+
}
|
|
550
|
+
|
|
551
|
+
/** @brief Convert to vector_view (requires rank == 1). */
|
|
552
|
+
vector_view<value_type> as_vector() const noexcept {
|
|
553
|
+
nk_assert_(shape_.rank == 1);
|
|
554
|
+
if (shape_.rank != 1) return {};
|
|
555
|
+
return {static_cast<char const *>(data_), shape_.extents[0], shape_.strides[0]};
|
|
556
|
+
}
|
|
557
|
+
|
|
558
|
+
/** @brief Reinterpret as a 2D matrix span. Requires rank >= 2. */
|
|
559
|
+
tensor_span<value_type_, 2> as_matrix() noexcept {
|
|
560
|
+
nk_assert_(shape_.rank >= 2);
|
|
561
|
+
if (shape_.rank < 2) return {};
|
|
562
|
+
shape_storage_<2> matrix_shape;
|
|
563
|
+
matrix_shape.rank = 2;
|
|
564
|
+
matrix_shape.extents[0] = shape_.extents[0];
|
|
565
|
+
matrix_shape.extents[1] = shape_.extents[1];
|
|
566
|
+
matrix_shape.strides[0] = shape_.strides[0];
|
|
567
|
+
matrix_shape.strides[1] = shape_.strides[1];
|
|
568
|
+
return {data_, matrix_shape};
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
/** @brief Reinterpret as a 2D const matrix view. Requires rank >= 2. */
|
|
572
|
+
tensor_view<value_type_, 2> as_matrix() const noexcept {
|
|
573
|
+
nk_assert_(shape_.rank >= 2);
|
|
574
|
+
if (shape_.rank < 2) return {};
|
|
575
|
+
shape_storage_<2> matrix_shape;
|
|
576
|
+
matrix_shape.rank = 2;
|
|
577
|
+
matrix_shape.extents[0] = shape_.extents[0];
|
|
578
|
+
matrix_shape.extents[1] = shape_.extents[1];
|
|
579
|
+
matrix_shape.strides[0] = shape_.strides[0];
|
|
580
|
+
matrix_shape.strides[1] = shape_.strides[1];
|
|
581
|
+
return {static_cast<char const *>(data_), matrix_shape};
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
/** @brief Check if contiguous in memory. */
|
|
585
|
+
constexpr bool is_contiguous() const noexcept { return is_tensor_contiguous_<value_type>(shape_); }
|
|
586
|
+
|
|
587
|
+
/** @brief Transpose: reverse the order of all dimensions (swap extents and strides). */
|
|
588
|
+
constexpr tensor_span transpose() noexcept {
|
|
589
|
+
if constexpr (dimensions_per_value<value_type>() > 1) {
|
|
590
|
+
if (shape_.rank >= 2) return {};
|
|
591
|
+
}
|
|
592
|
+
if (shape_.rank < 2) return *this;
|
|
593
|
+
auto transposed = shape_;
|
|
594
|
+
for (size_type i = 0; i < transposed.rank / 2; ++i) {
|
|
595
|
+
std::swap(transposed.extents[i], transposed.extents[transposed.rank - 1 - i]);
|
|
596
|
+
std::swap(transposed.strides[i], transposed.strides[transposed.rank - 1 - i]);
|
|
597
|
+
}
|
|
598
|
+
return {data_, transposed};
|
|
599
|
+
}
|
|
600
|
+
|
|
601
|
+
/** @brief Reshape to new extents (requires contiguous layout and matching element count).
|
|
602
|
+
* Returns an empty span if not contiguous or element counts don't match. */
|
|
603
|
+
tensor_span reshape(std::initializer_list<size_type> new_extents) noexcept {
|
|
604
|
+
auto new_rank = new_extents.size();
|
|
605
|
+
if (!is_contiguous() || new_rank > max_rank_ || new_rank == 0) return {};
|
|
606
|
+
auto new_shape = make_contiguous_shape_<value_type, max_rank_>(new_extents.begin(), new_rank);
|
|
607
|
+
if (storage_values_for_shape_<value_type>(new_shape) != storage_values_for_shape_<value_type>(shape_))
|
|
608
|
+
return {};
|
|
609
|
+
return {data_, new_shape};
|
|
610
|
+
}
|
|
611
|
+
|
|
612
|
+
/** @brief Range of mutable sub-spans along the leading dimension. */
|
|
613
|
+
struct rows_spans_t {
|
|
614
|
+
tensor_span parent;
|
|
615
|
+
axis_iterator<tensor_span> begin() noexcept { return {parent, 0}; }
|
|
616
|
+
axis_iterator<tensor_span> end() noexcept { return {parent, parent.extent(0)}; }
|
|
617
|
+
};
|
|
618
|
+
|
|
619
|
+
rows_spans_t rows() noexcept { return {*this}; }
|
|
620
|
+
|
|
621
|
+
/** @brief Range of immutable sub-views along the leading dimension. */
|
|
622
|
+
struct rows_views_t {
|
|
623
|
+
tensor_view<value_type_, max_rank_> parent;
|
|
624
|
+
axis_iterator<tensor_view<value_type_, max_rank_>> begin() const noexcept { return {parent, 0}; }
|
|
625
|
+
axis_iterator<tensor_view<value_type_, max_rank_>> end() const noexcept { return {parent, parent.extent(0)}; }
|
|
626
|
+
};
|
|
627
|
+
|
|
628
|
+
rows_views_t rows() const noexcept {
|
|
629
|
+
tensor_view<value_type_, max_rank_> v = *this;
|
|
630
|
+
return {v};
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
/** @brief Flatten to 1D span (requires contiguous layout). Returns empty span if not contiguous. */
|
|
634
|
+
tensor_span flatten() noexcept { return reshape({numel()}); }
|
|
635
|
+
|
|
636
|
+
/** @brief Remove dimensions of size 1. */
|
|
637
|
+
tensor_span squeeze() noexcept {
|
|
638
|
+
auto result = shape_;
|
|
639
|
+
size_type new_rank = 0;
|
|
640
|
+
for (size_type i = 0; i < shape_.rank; ++i) {
|
|
641
|
+
if (shape_.extents[i] != 1) {
|
|
642
|
+
result.extents[new_rank] = shape_.extents[i];
|
|
643
|
+
result.strides[new_rank] = shape_.strides[i];
|
|
644
|
+
++new_rank;
|
|
645
|
+
}
|
|
646
|
+
}
|
|
647
|
+
if (new_rank == 0) {
|
|
648
|
+
new_rank = 1;
|
|
649
|
+
result.extents[0] = 1;
|
|
650
|
+
result.strides[0] = static_cast<difference_type>(sizeof(value_type));
|
|
651
|
+
}
|
|
652
|
+
result.rank = new_rank;
|
|
653
|
+
return {data_, result};
|
|
654
|
+
}
|
|
655
|
+
};
|
|
656
|
+
|
|
657
|
+
#pragma endregion - Tensor Span
|
|
658
|
+
|
|
659
|
+
template <typename value_type_, std::size_t max_rank_, std::size_t extent_>
|
|
660
|
+
decltype(auto) tensor_lookup_resolved_(tensor_view<value_type_, max_rank_> input,
|
|
661
|
+
std::span<std::size_t const, extent_> coords) noexcept {
|
|
662
|
+
nk_assert_(input.byte_data() != nullptr);
|
|
663
|
+
nk_assert_(coords.size() == input.rank());
|
|
664
|
+
if constexpr (dimensions_per_value<value_type_>() > 1) {
|
|
665
|
+
nk_assert_(packed_tensor_layout_supported_<value_type_>(input.shape()));
|
|
666
|
+
auto offset = std::ptrdiff_t {};
|
|
667
|
+
for (std::size_t i = 0; i + 1 < input.rank(); ++i)
|
|
668
|
+
offset += static_cast<std::ptrdiff_t>(coords[i]) * input.stride_bytes(i);
|
|
669
|
+
constexpr auto dims_per_value = dimensions_per_value<value_type_>();
|
|
670
|
+
auto last_index = coords[input.rank() - 1];
|
|
671
|
+
auto value_index = last_index / dims_per_value;
|
|
672
|
+
auto sub_index = last_index % dims_per_value;
|
|
673
|
+
using raw_type = typename raw_pod_type<value_type_>::type;
|
|
674
|
+
auto *base = const_cast<raw_type *>(reinterpret_cast<raw_type const *>(
|
|
675
|
+
input.byte_data() + offset +
|
|
676
|
+
static_cast<std::ptrdiff_t>(value_index) * input.stride_bytes(input.rank() - 1)));
|
|
677
|
+
return sub_byte_ref<value_type_>(base, sub_index).get();
|
|
678
|
+
}
|
|
679
|
+
else {
|
|
680
|
+
auto offset = input.shape().linearize(coords.data());
|
|
681
|
+
return *reinterpret_cast<value_type_ const *>(input.byte_data() + offset);
|
|
682
|
+
}
|
|
683
|
+
}
|
|
684
|
+
|
|
685
|
+
template <typename value_type_, std::size_t max_rank_, std::size_t extent_>
|
|
686
|
+
decltype(auto) tensor_lookup_resolved_(tensor_span<value_type_, max_rank_> input,
|
|
687
|
+
std::span<std::size_t const, extent_> coords) noexcept {
|
|
688
|
+
nk_assert_(input.byte_data() != nullptr);
|
|
689
|
+
nk_assert_(coords.size() == input.rank());
|
|
690
|
+
if constexpr (dimensions_per_value<value_type_>() > 1) {
|
|
691
|
+
nk_assert_(packed_tensor_layout_supported_<value_type_>(input.shape()));
|
|
692
|
+
auto offset = std::ptrdiff_t {};
|
|
693
|
+
for (std::size_t i = 0; i + 1 < input.rank(); ++i)
|
|
694
|
+
offset += static_cast<std::ptrdiff_t>(coords[i]) * input.stride_bytes(i);
|
|
695
|
+
constexpr auto dims_per_value = dimensions_per_value<value_type_>();
|
|
696
|
+
auto last_index = coords[input.rank() - 1];
|
|
697
|
+
auto value_index = last_index / dims_per_value;
|
|
698
|
+
auto sub_index = last_index % dims_per_value;
|
|
699
|
+
using raw_type = typename raw_pod_type<value_type_>::type;
|
|
700
|
+
auto *base = reinterpret_cast<raw_type *>(input.byte_data() + offset +
|
|
701
|
+
static_cast<std::ptrdiff_t>(value_index) *
|
|
702
|
+
input.stride_bytes(input.rank() - 1));
|
|
703
|
+
return sub_byte_ref<value_type_>(base, sub_index);
|
|
704
|
+
}
|
|
705
|
+
else {
|
|
706
|
+
auto offset = input.shape().linearize(coords.data());
|
|
707
|
+
return *reinterpret_cast<value_type_ *>(input.byte_data() + offset);
|
|
708
|
+
}
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
template <typename value_type_, std::size_t max_rank_, typename index_type_>
|
|
712
|
+
decltype(auto) tensor_flat_lookup_(tensor_view<value_type_, max_rank_> input, index_type_ idx) noexcept {
|
|
713
|
+
nk_assert_(input.byte_data() != nullptr);
|
|
714
|
+
if constexpr (dimensions_per_value<value_type_>() > 1) nk_assert_(input.rank() > 0);
|
|
715
|
+
auto flat = resolve_index_(idx, input.numel());
|
|
716
|
+
if constexpr (dimensions_per_value<value_type_>() == 1) {
|
|
717
|
+
if (input.rank() == 0) return input.scalar();
|
|
718
|
+
}
|
|
719
|
+
|
|
720
|
+
std::array<std::size_t, max_rank_> coords {};
|
|
721
|
+
for (std::size_t dim = input.rank(); dim > 0; --dim) {
|
|
722
|
+
auto axis = dim - 1;
|
|
723
|
+
auto extent = input.extent(axis);
|
|
724
|
+
coords[axis] = flat % extent;
|
|
725
|
+
flat /= extent;
|
|
726
|
+
}
|
|
727
|
+
return tensor_lookup_resolved_(input, std::span<std::size_t const>(coords.data(), input.rank()));
|
|
728
|
+
}
|
|
729
|
+
|
|
730
|
+
template <typename value_type_, std::size_t max_rank_, typename index_type_>
|
|
731
|
+
decltype(auto) tensor_flat_lookup_(tensor_span<value_type_, max_rank_> input, index_type_ idx) noexcept {
|
|
732
|
+
nk_assert_(input.byte_data() != nullptr);
|
|
733
|
+
if constexpr (dimensions_per_value<value_type_>() > 1) nk_assert_(input.rank() > 0);
|
|
734
|
+
auto flat = resolve_index_(idx, input.numel());
|
|
735
|
+
if constexpr (dimensions_per_value<value_type_>() == 1) {
|
|
736
|
+
if (input.rank() == 0) return input.scalar_ref();
|
|
737
|
+
}
|
|
738
|
+
|
|
739
|
+
std::array<std::size_t, max_rank_> coords {};
|
|
740
|
+
for (std::size_t dim = input.rank(); dim > 0; --dim) {
|
|
741
|
+
auto axis = dim - 1;
|
|
742
|
+
auto extent = input.extent(axis);
|
|
743
|
+
coords[axis] = flat % extent;
|
|
744
|
+
flat /= extent;
|
|
745
|
+
}
|
|
746
|
+
return tensor_lookup_resolved_(input, std::span<std::size_t const>(coords.data(), input.rank()));
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
template <typename tensor_type_>
|
|
750
|
+
tensor_type_ tensor_slice_suffix_(tensor_type_ input, tensor_slice_t) noexcept {
|
|
751
|
+
return input;
|
|
752
|
+
}
|
|
753
|
+
|
|
754
|
+
template <typename tensor_type_, std::integral index_type_, typename... rest_types_>
|
|
755
|
+
tensor_type_ tensor_slice_suffix_(tensor_type_ input, index_type_ idx, rest_types_... rest) noexcept {
|
|
756
|
+
if constexpr (dimensions_per_value<typename tensor_type_::value_type>() > 1) {
|
|
757
|
+
if constexpr (sizeof...(rest_types_) == 1)
|
|
758
|
+
if (input.rank() <= 1) return {};
|
|
759
|
+
}
|
|
760
|
+
if (input.rank() == 0) return {};
|
|
761
|
+
return tensor_slice_suffix_(input.slice_leading(idx), rest...);
|
|
762
|
+
}
|
|
763
|
+
|
|
764
|
+
template <typename tensor_type_, typename... rest_types_>
|
|
765
|
+
tensor_type_ tensor_slice_suffix_(tensor_type_ input, all_t, rest_types_... rest) noexcept {
|
|
766
|
+
// `all` keeps the leading dimension intact — apply remaining args to inner dimensions.
|
|
767
|
+
if (input.rank() == 0) return {};
|
|
768
|
+
using size_type = typename tensor_type_::size_type;
|
|
769
|
+
using difference_type = typename tensor_type_::difference_type;
|
|
770
|
+
using shape_type = std::remove_cvref_t<decltype(input.shape())>;
|
|
771
|
+
|
|
772
|
+
auto leading_extent = input.extent(0);
|
|
773
|
+
auto leading_stride = input.stride_bytes(0);
|
|
774
|
+
|
|
775
|
+
// Slice the first row to discover the resulting sub-shape.
|
|
776
|
+
auto first_row = input.slice_leading(static_cast<size_type>(0));
|
|
777
|
+
auto inner = tensor_slice_suffix_(first_row, rest...);
|
|
778
|
+
|
|
779
|
+
// Build the output shape: leading dimension + inner dimensions.
|
|
780
|
+
shape_type result_shape;
|
|
781
|
+
result_shape.rank = 1 + inner.rank();
|
|
782
|
+
result_shape.extents[0] = leading_extent;
|
|
783
|
+
result_shape.strides[0] = leading_stride;
|
|
784
|
+
for (size_type d = 0; d < inner.rank(); ++d) {
|
|
785
|
+
result_shape.extents[1 + d] = inner.extent(d);
|
|
786
|
+
result_shape.strides[1 + d] = inner.stride_bytes(d);
|
|
787
|
+
}
|
|
788
|
+
|
|
789
|
+
// The data pointer is the inner slice's offset relative to the first row,
|
|
790
|
+
// applied to the original data pointer.
|
|
791
|
+
using byte_ptr = decltype(input.byte_data());
|
|
792
|
+
auto inner_byte_offset = inner.byte_data() - first_row.byte_data();
|
|
793
|
+
return {const_cast<byte_ptr>(input.byte_data() + inner_byte_offset), result_shape};
|
|
794
|
+
}
|
|
795
|
+
|
|
796
|
+
template <typename tensor_type_, typename... rest_types_>
|
|
797
|
+
tensor_type_ tensor_slice_suffix_(tensor_type_ input, range r, rest_types_... rest) noexcept {
|
|
798
|
+
if (input.rank() == 0) return {};
|
|
799
|
+
using size_type = typename tensor_type_::size_type;
|
|
800
|
+
using difference_type = typename tensor_type_::difference_type;
|
|
801
|
+
using shape_type = std::remove_cvref_t<decltype(input.shape())>;
|
|
802
|
+
|
|
803
|
+
auto leading_extent = input.extent(0);
|
|
804
|
+
auto leading_stride = input.stride_bytes(0);
|
|
805
|
+
auto start = resolve_index_(r.start, leading_extent);
|
|
806
|
+
auto stop = resolve_index_(r.stop, leading_extent);
|
|
807
|
+
auto step = r.step;
|
|
808
|
+
if (start >= stop || step <= 0) return {};
|
|
809
|
+
|
|
810
|
+
auto range_extent = static_cast<size_type>((stop - start + static_cast<size_type>(step) - 1) /
|
|
811
|
+
static_cast<size_type>(step));
|
|
812
|
+
auto range_stride = leading_stride * static_cast<difference_type>(step);
|
|
813
|
+
auto data_offset = static_cast<difference_type>(start) * leading_stride;
|
|
814
|
+
|
|
815
|
+
if constexpr (sizeof...(rest_types_) == 1 &&
|
|
816
|
+
std::is_same_v<std::tuple_element_t<0, std::tuple<std::remove_cvref_t<rest_types_>...>>,
|
|
817
|
+
tensor_slice_t>) {
|
|
818
|
+
// Fast path: range followed by just `slice` — no inner recursion needed.
|
|
819
|
+
shape_type result_shape;
|
|
820
|
+
result_shape.rank = input.rank();
|
|
821
|
+
result_shape.extents[0] = range_extent;
|
|
822
|
+
result_shape.strides[0] = range_stride;
|
|
823
|
+
for (size_type d = 1; d < input.rank(); ++d) {
|
|
824
|
+
result_shape.extents[d] = input.extent(d);
|
|
825
|
+
result_shape.strides[d] = input.stride_bytes(d);
|
|
826
|
+
}
|
|
827
|
+
using byte_ptr = decltype(input.byte_data());
|
|
828
|
+
return {const_cast<byte_ptr>(input.byte_data() + data_offset), result_shape};
|
|
829
|
+
}
|
|
830
|
+
else {
|
|
831
|
+
// General path: recurse into inner dimensions (like `all_t` but with narrowed leading).
|
|
832
|
+
auto first_row = input.slice_leading(static_cast<size_type>(start));
|
|
833
|
+
auto inner = tensor_slice_suffix_(first_row, rest...);
|
|
834
|
+
|
|
835
|
+
shape_type result_shape;
|
|
836
|
+
result_shape.rank = 1 + inner.rank();
|
|
837
|
+
result_shape.extents[0] = range_extent;
|
|
838
|
+
result_shape.strides[0] = range_stride;
|
|
839
|
+
for (size_type d = 0; d < inner.rank(); ++d) {
|
|
840
|
+
result_shape.extents[1 + d] = inner.extent(d);
|
|
841
|
+
result_shape.strides[1 + d] = inner.stride_bytes(d);
|
|
842
|
+
}
|
|
843
|
+
|
|
844
|
+
using byte_ptr = decltype(input.byte_data());
|
|
845
|
+
auto inner_byte_offset = inner.byte_data() - first_row.byte_data();
|
|
846
|
+
return {const_cast<byte_ptr>(input.byte_data() + data_offset + inner_byte_offset), result_shape};
|
|
847
|
+
}
|
|
848
|
+
}
|
|
849
|
+
|
|
850
|
+
#pragma region - Axis Iterator
|
|
851
|
+
|
|
852
|
+
/**
 * @brief Random-access iterator over slices along the leading dimension.
 * @tparam view_type_ Either `tensor_view` or `tensor_span`.
 *
 * For a rank-2 matrix, iterating yields rank-1 row views/spans.
 * Dereference calls `parent_.slice_leading(index_)` to produce each sub-view on the fly, so elements
 * are returned by value: `value_type` and `reference` are both `view_type_`, and `pointer` is `void`.
 * Iterators compare by position only; comparing iterators over different parents is meaningless.
 */
template <typename view_type_>
class axis_iterator {
    std::size_t index_ = 0;
    view_type_ parent_ {};

  public:
    using iterator_category = std::random_access_iterator_tag;
    using value_type = view_type_;
    using difference_type = std::ptrdiff_t;
    using reference = view_type_;
    using pointer = void;

    constexpr axis_iterator() noexcept = default;

    /** @brief Position the iterator at leading-axis index `index` of `parent`. */
    constexpr axis_iterator(view_type_ const &parent, std::size_t index) noexcept : index_(index), parent_(parent) {}

    constexpr reference operator*() const noexcept {
        return parent_.slice_leading(static_cast<difference_type>(index_));
    }
    /** @brief Sub-view `n` positions away from this one. */
    constexpr reference operator[](difference_type n) const noexcept { return *(*this + n); }

    constexpr axis_iterator &operator++() noexcept {
        ++index_;
        return *this;
    }
    constexpr axis_iterator operator++(int) noexcept {
        auto tmp = *this;
        ++index_;
        return tmp;
    }
    constexpr axis_iterator &operator--() noexcept {
        --index_;
        return *this;
    }
    constexpr axis_iterator operator--(int) noexcept {
        auto tmp = *this;
        --index_;
        return tmp;
    }

    constexpr axis_iterator &operator+=(difference_type n) noexcept {
        index_ = static_cast<std::size_t>(static_cast<difference_type>(index_) + n);
        return *this;
    }
    constexpr axis_iterator &operator-=(difference_type n) noexcept {
        index_ = static_cast<std::size_t>(static_cast<difference_type>(index_) - n);
        return *this;
    }

    constexpr axis_iterator operator+(difference_type n) const noexcept {
        auto copy = *this;
        copy += n;
        return copy;
    }
    friend constexpr axis_iterator operator+(difference_type n, axis_iterator const &it) noexcept { return it + n; }
    constexpr axis_iterator operator-(difference_type n) const noexcept {
        auto copy = *this;
        copy -= n;
        return copy;
    }
    constexpr difference_type operator-(axis_iterator const &other) const noexcept {
        return static_cast<difference_type>(index_) - static_cast<difference_type>(other.index_);
    }

    constexpr bool operator==(axis_iterator const &other) const noexcept { return index_ == other.index_; }
    constexpr bool operator!=(axis_iterator const &other) const noexcept { return index_ != other.index_; }
    constexpr bool operator<(axis_iterator const &other) const noexcept { return index_ < other.index_; }
    constexpr bool operator>(axis_iterator const &other) const noexcept { return other.index_ < index_; }
    constexpr bool operator<=(axis_iterator const &other) const noexcept { return !(other.index_ < index_); }
    constexpr bool operator>=(axis_iterator const &other) const noexcept { return !(index_ < other.index_); }
};
|
|
918
|
+
|
|
919
|
+
#pragma endregion - Axis Iterator
|
|
920
|
+
|
|
921
|
+
#pragma region - Tensor
|
|
922
|
+
|
|
923
|
+
/**
|
|
924
|
+
* @brief Owning, non-resizable, N-dimensional tensor.
|
|
925
|
+
* @tparam value_type_ Element type.
|
|
926
|
+
* @tparam allocator_type_ Allocator.
|
|
927
|
+
* @tparam max_rank_ Maximum number of dimensions.
|
|
928
|
+
*
|
|
929
|
+
* Fixed-size at construction. Use `try_zeros()` factory for non-throwing construction.
|
|
930
|
+
*/
|
|
931
|
+
template <typename value_type_, typename allocator_type_ = aligned_allocator<value_type_>, std::size_t max_rank_ = 8>
|
|
932
|
+
struct tensor {
|
|
933
|
+
using value_type = value_type_;
|
|
934
|
+
using allocator_type = allocator_type_;
|
|
935
|
+
using alloc_traits = std::allocator_traits<allocator_type_>;
|
|
936
|
+
using size_type = std::size_t;
|
|
937
|
+
using difference_type = std::ptrdiff_t;
|
|
938
|
+
using pointer = value_type_ *;
|
|
939
|
+
|
|
940
|
+
using view_type = tensor_view<value_type_, max_rank_>;
|
|
941
|
+
using span_type = tensor_span<value_type_, max_rank_>;
|
|
942
|
+
|
|
943
|
+
private:
|
|
944
|
+
pointer data_ = nullptr;
|
|
945
|
+
shape_storage_<max_rank_> shape_;
|
|
946
|
+
[[no_unique_address]] allocator_type_ alloc_;
|
|
947
|
+
|
|
948
|
+
public:
|
|
949
|
+
tensor() noexcept = default;
|
|
950
|
+
|
|
951
|
+
explicit tensor(allocator_type_ const &alloc) noexcept : alloc_(alloc) {}
|
|
952
|
+
|
|
953
|
+
~tensor() noexcept {
|
|
954
|
+
if (data_) alloc_traits::deallocate(alloc_, data_, storage_values_for_shape_<value_type_>(shape_));
|
|
955
|
+
}
|
|
956
|
+
|
|
957
|
+
tensor(tensor &&other) noexcept
|
|
958
|
+
: data_(std::exchange(other.data_, nullptr)), shape_(std::exchange(other.shape_, {})),
|
|
959
|
+
alloc_(std::move(other.alloc_)) {}
|
|
960
|
+
|
|
961
|
+
tensor &operator=(tensor &&other) noexcept {
|
|
962
|
+
if (this != &other) {
|
|
963
|
+
if (data_) alloc_traits::deallocate(alloc_, data_, storage_values_for_shape_<value_type_>(shape_));
|
|
964
|
+
if constexpr (alloc_traits::propagate_on_container_move_assignment::value) alloc_ = std::move(other.alloc_);
|
|
965
|
+
data_ = std::exchange(other.data_, nullptr);
|
|
966
|
+
shape_ = std::exchange(other.shape_, {});
|
|
967
|
+
}
|
|
968
|
+
return *this;
|
|
969
|
+
}
|
|
970
|
+
|
|
971
|
+
tensor(tensor const &) = delete;
|
|
972
|
+
tensor &operator=(tensor const &) = delete;
|
|
973
|
+
|
|
974
|
+
/**
|
|
975
|
+
* @brief Factory: allocate a zero-initialized tensor with the given extents.
|
|
976
|
+
* @param extents Extents (one per dimension), e.g. `{3, 4}`.
|
|
977
|
+
* @param alloc Allocator instance.
|
|
978
|
+
* @return Non-empty tensor on success, empty on failure.
|
|
979
|
+
*/
|
|
980
|
+
[[nodiscard]] static tensor try_zeros(std::initializer_list<size_type> extents,
|
|
981
|
+
allocator_type_ alloc = {}) noexcept {
|
|
982
|
+
tensor t(alloc);
|
|
983
|
+
auto rank = extents.size();
|
|
984
|
+
if (rank > max_rank_) return t;
|
|
985
|
+
t.shape_ = make_contiguous_shape_<value_type_, max_rank_>(extents.begin(), rank);
|
|
986
|
+
auto storage_values = storage_values_for_shape_<value_type_>(t.shape_);
|
|
987
|
+
if (storage_values == 0) return t;
|
|
988
|
+
pointer ptr = alloc_traits::allocate(t.alloc_, storage_values);
|
|
989
|
+
if (!ptr) return t;
|
|
990
|
+
if constexpr (is_memset_zero_safe_v<value_type_>)
|
|
991
|
+
std::memset(static_cast<void *>(ptr), 0, storage_values * sizeof(value_type_));
|
|
992
|
+
else
|
|
993
|
+
for (size_type i = 0; i < storage_values; ++i) ptr[i] = value_type_ {};
|
|
994
|
+
t.data_ = ptr;
|
|
995
|
+
return t;
|
|
996
|
+
}
|
|
997
|
+
|
|
998
|
+
/**
|
|
999
|
+
* @brief Factory: allocate a tensor filled with ones.
|
|
1000
|
+
* @param extents Extents (one per dimension), e.g. `{3, 4}`.
|
|
1001
|
+
* @param alloc Allocator instance.
|
|
1002
|
+
* @return Non-empty tensor on success, empty on failure.
|
|
1003
|
+
*/
|
|
1004
|
+
[[nodiscard]] static tensor try_ones(std::initializer_list<size_type> extents,
|
|
1005
|
+
allocator_type_ alloc = {}) noexcept {
|
|
1006
|
+
return try_full(extents, value_type_ {1}, alloc);
|
|
1007
|
+
}
|
|
1008
|
+
|
|
1009
|
+
/**
|
|
1010
|
+
* @brief Factory: allocate a tensor filled with a given value.
|
|
1011
|
+
* @param extents Extents (one per dimension), e.g. `{3, 4}`.
|
|
1012
|
+
* @param val Fill value.
|
|
1013
|
+
* @param alloc Allocator instance.
|
|
1014
|
+
* @return Non-empty tensor on success, empty on failure.
|
|
1015
|
+
*/
|
|
1016
|
+
[[nodiscard]] static tensor try_full(std::initializer_list<size_type> extents, value_type_ val,
|
|
1017
|
+
allocator_type_ alloc = {}) noexcept {
|
|
1018
|
+
tensor t(alloc);
|
|
1019
|
+
auto rank = extents.size();
|
|
1020
|
+
if (rank > max_rank_) return t;
|
|
1021
|
+
t.shape_ = make_contiguous_shape_<value_type_, max_rank_>(extents.begin(), rank);
|
|
1022
|
+
auto storage_values = storage_values_for_shape_<value_type_>(t.shape_);
|
|
1023
|
+
if (storage_values == 0) return t;
|
|
1024
|
+
pointer ptr = alloc_traits::allocate(t.alloc_, storage_values);
|
|
1025
|
+
if (!ptr) return t;
|
|
1026
|
+
for (size_type i = 0; i < storage_values; ++i) ptr[i] = val;
|
|
1027
|
+
t.data_ = ptr;
|
|
1028
|
+
return t;
|
|
1029
|
+
}
|
|
1030
|
+
|
|
1031
|
+
/**
|
|
1032
|
+
* @brief Factory: allocate an uninitialized tensor.
|
|
1033
|
+
* @param extents Extents (one per dimension), e.g. `{3, 4}`.
|
|
1034
|
+
* @param alloc Allocator instance.
|
|
1035
|
+
* @return Non-empty tensor on success, empty on failure.
|
|
1036
|
+
*/
|
|
1037
|
+
[[nodiscard]] static tensor try_empty(std::initializer_list<size_type> extents,
|
|
1038
|
+
allocator_type_ alloc = {}) noexcept {
|
|
1039
|
+
tensor t(alloc);
|
|
1040
|
+
auto rank = extents.size();
|
|
1041
|
+
if (rank > max_rank_) return t;
|
|
1042
|
+
t.shape_ = make_contiguous_shape_<value_type_, max_rank_>(extents.begin(), rank);
|
|
1043
|
+
auto storage_values = storage_values_for_shape_<value_type_>(t.shape_);
|
|
1044
|
+
if (storage_values == 0) return t;
|
|
1045
|
+
pointer ptr = alloc_traits::allocate(t.alloc_, storage_values);
|
|
1046
|
+
if (!ptr) return t;
|
|
1047
|
+
t.data_ = ptr;
|
|
1048
|
+
return t;
|
|
1049
|
+
}
|
|
1050
|
+
|
|
1051
|
+
/** @brief Factory: zero-initialized tensor from pointer + rank. */
|
|
1052
|
+
[[nodiscard]] static tensor try_zeros(size_type const *extents, size_type rank,
|
|
1053
|
+
allocator_type_ alloc = {}) noexcept {
|
|
1054
|
+
tensor t(alloc);
|
|
1055
|
+
if (rank > max_rank_) return t;
|
|
1056
|
+
t.shape_ = make_contiguous_shape_<value_type_, max_rank_>(extents, rank);
|
|
1057
|
+
auto storage_values = storage_values_for_shape_<value_type_>(t.shape_);
|
|
1058
|
+
if (storage_values == 0) return t;
|
|
1059
|
+
pointer ptr = alloc_traits::allocate(t.alloc_, storage_values);
|
|
1060
|
+
if (!ptr) return t;
|
|
1061
|
+
if constexpr (is_memset_zero_safe_v<value_type_>)
|
|
1062
|
+
std::memset(static_cast<void *>(ptr), 0, storage_values * sizeof(value_type_));
|
|
1063
|
+
else
|
|
1064
|
+
for (size_type i = 0; i < storage_values; ++i) ptr[i] = value_type_ {};
|
|
1065
|
+
t.data_ = ptr;
|
|
1066
|
+
return t;
|
|
1067
|
+
}
|
|
1068
|
+
|
|
1069
|
+
/** @brief Factory: uninitialized tensor from pointer + rank. */
|
|
1070
|
+
[[nodiscard]] static tensor try_empty(size_type const *extents, size_type rank,
|
|
1071
|
+
allocator_type_ alloc = {}) noexcept {
|
|
1072
|
+
tensor t(alloc);
|
|
1073
|
+
if (rank > max_rank_) return t;
|
|
1074
|
+
t.shape_ = make_contiguous_shape_<value_type_, max_rank_>(extents, rank);
|
|
1075
|
+
auto storage_values = storage_values_for_shape_<value_type_>(t.shape_);
|
|
1076
|
+
if (storage_values == 0) return t;
|
|
1077
|
+
pointer ptr = alloc_traits::allocate(t.alloc_, storage_values);
|
|
1078
|
+
if (!ptr) return t;
|
|
1079
|
+
t.data_ = ptr;
|
|
1080
|
+
return t;
|
|
1081
|
+
}
|
|
1082
|
+
|
|
1083
|
+
/** @brief Factory: filled tensor from pointer + rank. */
|
|
1084
|
+
[[nodiscard]] static tensor try_full(size_type const *extents, size_type rank, value_type_ val,
|
|
1085
|
+
allocator_type_ alloc = {}) noexcept {
|
|
1086
|
+
tensor t(alloc);
|
|
1087
|
+
if (rank > max_rank_) return t;
|
|
1088
|
+
t.shape_ = make_contiguous_shape_<value_type_, max_rank_>(extents, rank);
|
|
1089
|
+
auto storage_values = storage_values_for_shape_<value_type_>(t.shape_);
|
|
1090
|
+
if (storage_values == 0) return t;
|
|
1091
|
+
pointer ptr = alloc_traits::allocate(t.alloc_, storage_values);
|
|
1092
|
+
if (!ptr) return t;
|
|
1093
|
+
for (size_type i = 0; i < storage_values; ++i) ptr[i] = val;
|
|
1094
|
+
t.data_ = ptr;
|
|
1095
|
+
return t;
|
|
1096
|
+
}
|
|
1097
|
+
|
|
1098
|
+
/**
|
|
1099
|
+
* @brief Factory: create a rank-1 tensor from an initializer list of values.
|
|
1100
|
+
* @param values Values to fill the tensor with.
|
|
1101
|
+
* @param alloc Allocator instance.
|
|
1102
|
+
* @return Non-empty tensor on success, empty on failure.
|
|
1103
|
+
*/
|
|
1104
|
+
[[nodiscard]] static tensor try_from(std::initializer_list<value_type_> values,
|
|
1105
|
+
allocator_type_ alloc = {}) noexcept {
|
|
1106
|
+
tensor t = try_empty({values.size()}, alloc);
|
|
1107
|
+
if (t.empty()) return t;
|
|
1108
|
+
size_type index = 0;
|
|
1109
|
+
for (auto const &value : values) t.data_[index++] = value;
|
|
1110
|
+
return t;
|
|
1111
|
+
}
|
|
1112
|
+
|
|
1113
|
+
/**
|
|
1114
|
+
* @brief Factory: create a rank-2 tensor from a nested initializer list.
|
|
1115
|
+
* @param rows Each inner list is a row. All rows must have the same length.
|
|
1116
|
+
* @param alloc Allocator instance.
|
|
1117
|
+
* @return Non-empty tensor on success, empty on ragged input or allocation failure.
|
|
1118
|
+
*/
|
|
1119
|
+
[[nodiscard]] static tensor try_from(std::initializer_list<std::initializer_list<value_type_>> rows,
|
|
1120
|
+
allocator_type_ alloc = {}) noexcept
|
|
1121
|
+
requires(max_rank_ >= 2)
|
|
1122
|
+
{
|
|
1123
|
+
auto num_rows = rows.size();
|
|
1124
|
+
if (num_rows == 0) return tensor(alloc);
|
|
1125
|
+
auto num_cols = rows.begin()->size();
|
|
1126
|
+
for (auto const &row : rows)
|
|
1127
|
+
if (row.size() != num_cols) return tensor(alloc);
|
|
1128
|
+
tensor t = try_empty({num_rows, num_cols}, alloc);
|
|
1129
|
+
if (t.empty()) return t;
|
|
1130
|
+
size_type index = 0;
|
|
1131
|
+
for (auto const &row : rows)
|
|
1132
|
+
for (auto const &value : row) t.data_[index++] = value;
|
|
1133
|
+
return t;
|
|
1134
|
+
}
|
|
1135
|
+
|
|
1136
|
+
/**
|
|
1137
|
+
* @brief Factory: adopt raw memory.
|
|
1138
|
+
*/
|
|
1139
|
+
[[nodiscard]] static tensor from_raw(pointer ptr, shape_storage_<max_rank_> const &shape,
|
|
1140
|
+
allocator_type_ alloc = {}) noexcept {
|
|
1141
|
+
tensor t(alloc);
|
|
1142
|
+
t.data_ = ptr;
|
|
1143
|
+
t.shape_ = shape;
|
|
1144
|
+
return t;
|
|
1145
|
+
}
|
|
1146
|
+
|
|
1147
|
+
/** @brief Number of dimensions. */
|
|
1148
|
+
constexpr size_type rank() const noexcept { return shape_.rank; }
|
|
1149
|
+
|
|
1150
|
+
/** @brief Extent along dimension i. */
|
|
1151
|
+
constexpr size_type extent(size_type i) const noexcept { return shape_.extents[i]; }
|
|
1152
|
+
|
|
1153
|
+
/** @brief Stride in bytes along dimension i (signed). */
|
|
1154
|
+
constexpr difference_type stride_bytes(size_type i) const noexcept { return shape_.strides[i]; }
|
|
1155
|
+
|
|
1156
|
+
/** @brief Total number of elements. */
|
|
1157
|
+
constexpr size_type numel() const noexcept { return shape_.numel(); }
|
|
1158
|
+
|
|
1159
|
+
/** @brief True if empty. */
|
|
1160
|
+
constexpr bool empty() const noexcept { return data_ == nullptr; }
|
|
1161
|
+
|
|
1162
|
+
/** @brief Typed pointer to data. */
|
|
1163
|
+
pointer data() noexcept { return data_; }
|
|
1164
|
+
value_type const *data() const noexcept { return data_; }
|
|
1165
|
+
|
|
1166
|
+
/** @brief Shape storage. */
|
|
1167
|
+
constexpr shape_storage_<max_rank_> const &shape() const noexcept { return shape_; }
|
|
1168
|
+
|
|
1169
|
+
/** @brief Allocator. */
|
|
1170
|
+
allocator_type get_allocator() const noexcept { return alloc_; }
|
|
1171
|
+
|
|
1172
|
+
/** @brief Create an immutable view. */
|
|
1173
|
+
view_type view() const noexcept { return {reinterpret_cast<char const *>(data_), shape_}; }
|
|
1174
|
+
|
|
1175
|
+
/** @brief Create a mutable span. */
|
|
1176
|
+
span_type span() noexcept { return {reinterpret_cast<char *>(data_), shape_}; }
|
|
1177
|
+
|
|
1178
|
+
/** @brief Range of immutable row views (slices along leading dimension). */
|
|
1179
|
+
struct rows_views_t {
|
|
1180
|
+
view_type parent;
|
|
1181
|
+
axis_iterator<view_type> begin() const noexcept { return {parent, 0}; }
|
|
1182
|
+
axis_iterator<view_type> end() const noexcept { return {parent, parent.extent(0)}; }
|
|
1183
|
+
};
|
|
1184
|
+
|
|
1185
|
+
/** @brief Range of mutable row spans (slices along leading dimension). */
|
|
1186
|
+
struct rows_spans_t {
|
|
1187
|
+
span_type parent;
|
|
1188
|
+
axis_iterator<span_type> begin() noexcept { return {parent, 0}; }
|
|
1189
|
+
axis_iterator<span_type> end() noexcept { return {parent, parent.extent(0)}; }
|
|
1190
|
+
};
|
|
1191
|
+
|
|
1192
|
+
/** @brief Iterate rows as immutable views. */
|
|
1193
|
+
rows_views_t rows_views() const noexcept { return {view()}; }
|
|
1194
|
+
|
|
1195
|
+
/** @brief Iterate rows as mutable spans. */
|
|
1196
|
+
rows_spans_t rows_spans() noexcept { return {span()}; }
|
|
1197
|
+
|
|
1198
|
+
/** @brief Iterate rows as immutable views (convenience alias for rows_views). */
|
|
1199
|
+
typename view_type::rows_views_t rows() const noexcept { return view().rows(); }
|
|
1200
|
+
|
|
1201
|
+
/** @brief Iterate rows as mutable spans (convenience alias for rows_spans). */
|
|
1202
|
+
typename span_type::rows_spans_t rows() noexcept { return span().rows(); }
|
|
1203
|
+
|
|
1204
|
+
/** @brief Reinterpret as a 2D immutable matrix view. Requires rank >= 2. */
|
|
1205
|
+
tensor_view<value_type_, 2> as_matrix_view() const noexcept { return view().as_matrix(); }
|
|
1206
|
+
|
|
1207
|
+
/** @brief Reinterpret as a 2D mutable matrix span. Requires rank >= 2. */
|
|
1208
|
+
tensor_span<value_type_, 2> as_matrix_span() noexcept { return span().as_matrix(); }
|
|
1209
|
+
|
|
1210
|
+
/** @brief Transpose: reverse dimension order (immutable view). */
|
|
1211
|
+
view_type transpose() const noexcept { return view().transpose(); }
|
|
1212
|
+
|
|
1213
|
+
/** @brief Transpose: reverse dimension order (mutable span). */
|
|
1214
|
+
span_type transpose() noexcept { return span().transpose(); }
|
|
1215
|
+
|
|
1216
|
+
/** @brief Reshape (immutable view). Requires contiguous layout and matching element count. */
|
|
1217
|
+
view_type reshape(std::initializer_list<size_type> new_extents) const noexcept {
|
|
1218
|
+
return view().reshape(new_extents);
|
|
1219
|
+
}
|
|
1220
|
+
|
|
1221
|
+
/** @brief Reshape (mutable span). Requires contiguous layout and matching element count. */
|
|
1222
|
+
span_type reshape(std::initializer_list<size_type> new_extents) noexcept { return span().reshape(new_extents); }
|
|
1223
|
+
|
|
1224
|
+
/** @brief Check if contiguous in memory. Always true for freshly-constructed tensors. */
|
|
1225
|
+
constexpr bool is_contiguous() const noexcept { return view().is_contiguous(); }
|
|
1226
|
+
|
|
1227
|
+
/** @brief Slice along leading dimension (immutable view). */
|
|
1228
|
+
template <std::integral index_type_>
|
|
1229
|
+
view_type slice_leading(index_type_ idx) const noexcept {
|
|
1230
|
+
return view().slice_leading(idx);
|
|
1231
|
+
}
|
|
1232
|
+
|
|
1233
|
+
/** @brief Slice along leading dimension (mutable span). */
|
|
1234
|
+
template <std::integral index_type_>
|
|
1235
|
+
span_type slice_leading(index_type_ idx) noexcept {
|
|
1236
|
+
return span().slice_leading(idx);
|
|
1237
|
+
}
|
|
1238
|
+
|
|
1239
|
+
/** @brief Row access (immutable view, alias for slice_leading). */
|
|
1240
|
+
template <std::integral index_type_>
|
|
1241
|
+
view_type row(index_type_ i) const noexcept {
|
|
1242
|
+
return view().slice_leading(i);
|
|
1243
|
+
}
|
|
1244
|
+
|
|
1245
|
+
/** @brief Row access (mutable span, alias for slice_leading). */
|
|
1246
|
+
template <std::integral index_type_>
|
|
1247
|
+
span_type row(index_type_ i) noexcept {
|
|
1248
|
+
return span().slice_leading(i);
|
|
1249
|
+
}
|
|
1250
|
+
|
|
1251
|
+
/** @brief Flat logical scalar access. */
|
|
1252
|
+
template <std::integral index_type_>
|
|
1253
|
+
decltype(auto) operator[](index_type_ idx) noexcept {
|
|
1254
|
+
return span()[idx];
|
|
1255
|
+
}
|
|
1256
|
+
|
|
1257
|
+
/** @brief Const flat logical scalar access. */
|
|
1258
|
+
template <std::integral index_type_>
|
|
1259
|
+
decltype(auto) operator[](index_type_ idx) const noexcept {
|
|
1260
|
+
return view()[idx];
|
|
1261
|
+
}
|
|
1262
|
+
|
|
1263
|
+
/** @brief Exact multi-dimensional scalar lookup. */
|
|
1264
|
+
template <std::integral... index_types_>
|
|
1265
|
+
requires(sizeof...(index_types_) >= 2)
|
|
1266
|
+
decltype(auto) operator[](index_types_... idxs) noexcept {
|
|
1267
|
+
return span()[idxs...];
|
|
1268
|
+
}
|
|
1269
|
+
|
|
1270
|
+
/** @brief Const multidimensional lookup. */
|
|
1271
|
+
template <std::integral... index_types_>
|
|
1272
|
+
requires(sizeof...(index_types_) >= 2)
|
|
1273
|
+
decltype(auto) operator[](index_types_... idxs) const noexcept {
|
|
1274
|
+
return view()[idxs...];
|
|
1275
|
+
}
|
|
1276
|
+
|
|
1277
|
+
/** @brief Trailing `slice` returns the same tensor view/span category. */
|
|
1278
|
+
span_type operator[](tensor_slice_t) noexcept { return span(); }
|
|
1279
|
+
view_type operator[](tensor_slice_t) const noexcept { return view(); }
|
|
1280
|
+
|
|
1281
|
+
/** @brief Prefix leading-axis slicing with a trailing `slice` marker. */
|
|
1282
|
+
template <typename first_type_, typename second_type_, typename... rest_types_>
|
|
1283
|
+
requires(trailing_tensor_slice_args_v<first_type_, second_type_, rest_types_...>)
|
|
1284
|
+
span_type operator[](first_type_ first, second_type_ second, rest_types_... rest) noexcept {
|
|
1285
|
+
return tensor_slice_suffix_(span(), first, second, rest...);
|
|
1286
|
+
}
|
|
1287
|
+
|
|
1288
|
+
/** @brief Const prefix leading-axis slicing with a trailing `slice` marker. */
|
|
1289
|
+
template <typename first_type_, typename second_type_, typename... rest_types_>
|
|
1290
|
+
requires(trailing_tensor_slice_args_v<first_type_, second_type_, rest_types_...>)
|
|
1291
|
+
view_type operator[](first_type_ first, second_type_ second, rest_types_... rest) const noexcept {
|
|
1292
|
+
return tensor_slice_suffix_(view(), first, second, rest...);
|
|
1293
|
+
}
|
|
1294
|
+
|
|
1295
|
+
/** @brief Rank-0 mutable scalar access. */
|
|
1296
|
+
decltype(auto) scalar_ref() noexcept { return span().scalar_ref(); }
|
|
1297
|
+
|
|
1298
|
+
/** @brief Rank-0 const scalar access. */
|
|
1299
|
+
decltype(auto) scalar() const noexcept { return view().scalar(); }
|
|
1300
|
+
|
|
1301
|
+
/** @brief Convert to vector_view (requires rank == 1). */
|
|
1302
|
+
vector_view<value_type> as_vector_view() const noexcept { return view().as_vector(); }
|
|
1303
|
+
|
|
1304
|
+
/** @brief Convert to vector_span (requires rank == 1). */
|
|
1305
|
+
vector_span<value_type> as_vector_span() noexcept { return span().as_vector(); }
|
|
1306
|
+
|
|
1307
|
+
/** @brief Flatten (immutable view). Requires contiguous layout. */
|
|
1308
|
+
view_type flatten() const noexcept { return view().flatten(); }
|
|
1309
|
+
|
|
1310
|
+
/** @brief Flatten (mutable span). Requires contiguous layout. */
|
|
1311
|
+
span_type flatten() noexcept { return span().flatten(); }
|
|
1312
|
+
|
|
1313
|
+
/** @brief Squeeze (immutable view). Removes size-1 dimensions. */
|
|
1314
|
+
view_type squeeze() const noexcept { return view().squeeze(); }
|
|
1315
|
+
|
|
1316
|
+
/** @brief Squeeze (mutable span). Removes size-1 dimensions. */
|
|
1317
|
+
span_type squeeze() noexcept { return span().squeeze(); }
|
|
1318
|
+
};
|
|
1319
|
+
|
|
1320
|
+
/** @brief Non-member swap. */
|
|
1321
|
+
template <typename V, typename A, std::size_t R>
|
|
1322
|
+
void swap(tensor<V, A, R> &a, tensor<V, A, R> &b) noexcept {
|
|
1323
|
+
auto tmp = std::move(a);
|
|
1324
|
+
a = std::move(b);
|
|
1325
|
+
b = std::move(tmp);
|
|
1326
|
+
}
|
|
1327
|
+
|
|
1328
|
+
#pragma endregion - Tensor
|
|
1329
|
+
|
|
1330
|
+
#pragma region - Matrix Aliases
|
|
1331
|
+
|
|
1332
|
+
/** @brief 2D owning matrix (max_rank = 2, smaller shape_storage). */
|
|
1333
|
+
template <typename value_type_, typename allocator_type_ = aligned_allocator<value_type_>>
|
|
1334
|
+
using matrix = tensor<value_type_, allocator_type_, 2>;
|
|
1335
|
+
|
|
1336
|
+
/** @brief 2D immutable view. */
|
|
1337
|
+
template <typename value_type_>
|
|
1338
|
+
using matrix_view = tensor_view<value_type_, 2>;
|
|
1339
|
+
|
|
1340
|
+
/** @brief 2D mutable span. */
|
|
1341
|
+
template <typename value_type_>
|
|
1342
|
+
using matrix_span = tensor_span<value_type_, 2>;
|
|
1343
|
+
|
|
1344
|
+
#pragma endregion - Matrix Aliases
|
|
1345
|
+
|
|
1346
|
+
} // namespace ashvardanian::numkong
|
|
1347
|
+
|
|
1348
|
+
namespace ashvardanian::numkong {
|
|
1349
|
+
|
|
1350
|
+
#pragma region - Enums and Result Types
|
|
1351
|
+
|
|
1352
|
+
/** @brief Reduction policy: drop the reduced axis, or keep it with extent 1. */
enum keep_dims_t : bool {
    collapse_dims_k = false,
    keep_dims_k = true,
};

/** @brief Output of `moments()`: running sum (Σxᵢ) and sum of squares (Σxᵢ²), value-initialized. */
template <typename sum_type_, typename sumsq_type_>
struct moments_result {
    sum_type_ sum {};
    sumsq_type_ sumsq {};
};

/** @brief Output of `minmax()`: extreme values and the indices where they occur, all zero/value-initialized. */
template <typename minmax_value_type_>
struct minmax_result {
    minmax_value_type_ min_value {};
    std::size_t min_index {0};
    minmax_value_type_ max_value {};
    std::size_t max_index {0};
};
|
|
1370
|
+
|
|
1371
|
+
#pragma endregion - Enums and Result Types
|
|
1372
|
+
|
|
1373
|
+
#pragma region - Helpers
|
|
1374
|
+
|
|
1375
|
+
/** @brief Compute output shape with one axis removed (or set to 1 if keep_dims). */
|
|
1376
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1377
|
+
shape_storage_<max_rank_> reduced_shape_(shape_storage_<max_rank_> const &in, std::size_t axis,
|
|
1378
|
+
keep_dims_t keep_dims) noexcept {
|
|
1379
|
+
std::size_t out_extents[max_rank_];
|
|
1380
|
+
std::size_t out_rank = 0;
|
|
1381
|
+
for (std::size_t i = 0; i < in.rank; ++i) {
|
|
1382
|
+
if (i == axis) {
|
|
1383
|
+
if (keep_dims) out_extents[out_rank++] = 1;
|
|
1384
|
+
}
|
|
1385
|
+
else { out_extents[out_rank++] = in.extents[i]; }
|
|
1386
|
+
}
|
|
1387
|
+
return make_contiguous_shape_<value_type_, max_rank_>(out_extents, out_rank);
|
|
1388
|
+
}
|
|
1389
|
+
|
|
1390
|
+
/** @brief Validate that two views have matching shapes. */
|
|
1391
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1392
|
+
bool shapes_match_(tensor_view<value_type_, max_rank_> a, tensor_view<value_type_, max_rank_> b) noexcept {
|
|
1393
|
+
if (a.rank() != b.rank()) return false;
|
|
1394
|
+
for (std::size_t i = 0; i < a.rank(); ++i)
|
|
1395
|
+
if (a.extent(i) != b.extent(i)) return false;
|
|
1396
|
+
return true;
|
|
1397
|
+
}
|
|
1398
|
+
|
|
1399
|
+
/** @brief Validate shape match between view and span. */
|
|
1400
|
+
template <typename in_type_, typename out_type_, std::size_t max_rank_>
|
|
1401
|
+
bool shapes_match_out_(tensor_view<in_type_, max_rank_> a, tensor_span<out_type_, max_rank_> out) noexcept {
|
|
1402
|
+
if (a.rank() != out.rank()) return false;
|
|
1403
|
+
for (std::size_t i = 0; i < a.rank(); ++i)
|
|
1404
|
+
if (a.extent(i) != out.extent(i)) return false;
|
|
1405
|
+
return true;
|
|
1406
|
+
}
|
|
1407
|
+
|
|
1408
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1409
|
+
bool tensor_layout_supported_(tensor_view<value_type_, max_rank_> input) noexcept {
|
|
1410
|
+
return packed_tensor_layout_supported_<value_type_>(input.shape());
|
|
1411
|
+
}
|
|
1412
|
+
|
|
1413
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1414
|
+
bool tensor_layout_supported_(tensor_span<value_type_, max_rank_> input) noexcept {
|
|
1415
|
+
return packed_tensor_layout_supported_<value_type_>(input.shape());
|
|
1416
|
+
}
|
|
1417
|
+
|
|
1418
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1419
|
+
bool shape_matches_(shape_storage_<max_rank_> const &expected, tensor_span<value_type_, max_rank_> actual) noexcept {
|
|
1420
|
+
if (expected.rank != actual.rank()) return false;
|
|
1421
|
+
for (std::size_t i = 0; i < expected.rank; ++i)
|
|
1422
|
+
if (expected.extents[i] != actual.extent(i)) return false;
|
|
1423
|
+
return true;
|
|
1424
|
+
}
|
|
1425
|
+
|
|
1426
|
+
/**
 * @brief A rank-1 lane re-expressed with a non-negative byte stride.
 *
 * `data` points at the lowest-addressed element and `stride_bytes` is the positive step.
 * `reversed` records that the logical order runs from high to low addresses.
 * Defaults describe an empty, dense lane. `max_rank_` is unused here and only mirrors the
 * producing view's type.
 */
template <typename value_type_, std::size_t max_rank_>
struct normalized_rank1_lane_ {
    value_type_ const *data {nullptr};
    std::size_t count {0};
    std::size_t stride_bytes {sizeof(value_type_)};
    bool reversed {false};
};
|
|
1433
|
+
|
|
1434
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1435
|
+
bool can_reduce_rank1_with_kernel_(tensor_view<value_type_, max_rank_> input) noexcept {
|
|
1436
|
+
if (input.rank() != 1 || input.byte_data() == nullptr || !tensor_layout_supported_(input)) return false;
|
|
1437
|
+
if constexpr (dimensions_per_value<value_type_>() > 1) return input.is_contiguous();
|
|
1438
|
+
return input.stride_bytes(0) != 0;
|
|
1439
|
+
}
|
|
1440
|
+
|
|
1441
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1442
|
+
bool can_apply_rank1_data_kernel_(tensor_view<value_type_, max_rank_> input) noexcept {
|
|
1443
|
+
if (input.rank() != 1 || input.byte_data() == nullptr || !tensor_layout_supported_(input)) return false;
|
|
1444
|
+
return input.is_contiguous();
|
|
1445
|
+
}
|
|
1446
|
+
|
|
1447
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1448
|
+
bool can_apply_rank1_data_kernel_(tensor_span<value_type_, max_rank_> output) noexcept {
|
|
1449
|
+
if (output.rank() != 1 || output.byte_data() == nullptr || !tensor_layout_supported_(output)) return false;
|
|
1450
|
+
return output.is_contiguous();
|
|
1451
|
+
}
|
|
1452
|
+
|
|
1453
|
+
template <typename value_type_, std::size_t max_rank_>
|
|
1454
|
+
normalized_rank1_lane_<value_type_, max_rank_> normalize_rank1_lane_(
|
|
1455
|
+
tensor_view<value_type_, max_rank_> input) noexcept {
|
|
1456
|
+
normalized_rank1_lane_<value_type_, max_rank_> lane;
|
|
1457
|
+
if (input.rank() != 1 || input.byte_data() == nullptr) return lane;
|
|
1458
|
+
lane.count = input.extent(0);
|
|
1459
|
+
if (lane.count == 0) return lane;
|
|
1460
|
+
auto stride = input.stride_bytes(0);
|
|
1461
|
+
if constexpr (dimensions_per_value<value_type_>() > 1) {
|
|
1462
|
+
if (!input.is_contiguous()) return {};
|
|
1463
|
+
lane.data = input.data();
|
|
1464
|
+
lane.stride_bytes = sizeof(value_type_);
|
|
1465
|
+
lane.reversed = false;
|
|
1466
|
+
return lane;
|
|
1467
|
+
}
|
|
1468
|
+
if (stride >= 0) {
|
|
1469
|
+
lane.data = input.data();
|
|
1470
|
+
lane.stride_bytes = static_cast<std::size_t>(stride);
|
|
1471
|
+
lane.reversed = false;
|
|
1472
|
+
}
|
|
1473
|
+
else {
|
|
1474
|
+
lane.data = reinterpret_cast<value_type_ const *>(input.byte_data() + (lane.count - 1) * stride);
|
|
1475
|
+
lane.stride_bytes = static_cast<std::size_t>(-stride);
|
|
1476
|
+
lane.reversed = true;
|
|
1477
|
+
}
|
|
1478
|
+
return lane;
|
|
1479
|
+
}
|
|
1480
|
+
|
|
1481
|
+
template <typename value_type_, std::size_t max_rank_, typename lane_fn_>
|
|
1482
|
+
bool for_each_axis_lane_(tensor_view<value_type_, max_rank_> input, std::size_t axis, lane_fn_ &&lane_fn) noexcept {
|
|
1483
|
+
if (axis >= input.rank() || !tensor_layout_supported_(input) || input.byte_data() == nullptr) return false;
|
|
1484
|
+
|
|
1485
|
+
shape_storage_<max_rank_> lane_shape;
|
|
1486
|
+
lane_shape.rank = 1;
|
|
1487
|
+
lane_shape.extents[0] = input.extent(axis);
|
|
1488
|
+
lane_shape.strides[0] = input.stride_bytes(axis);
|
|
1489
|
+
|
|
1490
|
+
std::size_t remaining_dims[max_rank_] = {};
|
|
1491
|
+
std::size_t remaining_count = 0;
|
|
1492
|
+
for (std::size_t dim = 0; dim < input.rank(); ++dim) {
|
|
1493
|
+
if (dim != axis) remaining_dims[remaining_count++] = dim;
|
|
1494
|
+
}
|
|
1495
|
+
|
|
1496
|
+
if (remaining_count == 0) return lane_fn(tensor_view<value_type_, max_rank_> {input.byte_data(), lane_shape}, 0);
|
|
1497
|
+
|
|
1498
|
+
std::size_t coords[max_rank_] = {};
|
|
1499
|
+
std::size_t total_lanes = 1;
|
|
1500
|
+
for (std::size_t i = 0; i < remaining_count; ++i) total_lanes *= input.extent(remaining_dims[i]);
|
|
1501
|
+
|
|
1502
|
+
for (std::size_t lane_index = 0; lane_index < total_lanes; ++lane_index) {
|
|
1503
|
+
auto offset = std::ptrdiff_t {};
|
|
1504
|
+
for (std::size_t i = 0; i < remaining_count; ++i)
|
|
1505
|
+
offset += static_cast<std::ptrdiff_t>(coords[i]) * input.stride_bytes(remaining_dims[i]);
|
|
1506
|
+
if (!lane_fn(tensor_view<value_type_, max_rank_> {input.byte_data() + offset, lane_shape}, lane_index))
|
|
1507
|
+
return false;
|
|
1508
|
+
|
|
1509
|
+
for (std::size_t i = remaining_count; i > 0; --i) {
|
|
1510
|
+
auto coord_index = i - 1;
|
|
1511
|
+
auto dim = remaining_dims[coord_index];
|
|
1512
|
+
if (++coords[coord_index] < input.extent(dim)) break;
|
|
1513
|
+
coords[coord_index] = 0;
|
|
1514
|
+
}
|
|
1515
|
+
}
|
|
1516
|
+
return true;
|
|
1517
|
+
}
|
|
1518
|
+
|
|
1519
|
+
/** @brief Unary elementwise traversal: validates shapes, recurses on rank≥2, calls leaf on rank-1 slices. */
|
|
1520
|
+
template <typename value_type_, std::size_t max_rank_, typename leaf_fn_>
|
|
1521
|
+
bool elementwise_into_(tensor_view<value_type_, max_rank_> input, tensor_span<value_type_, max_rank_> output,
|
|
1522
|
+
leaf_fn_ &&leaf) noexcept {
|
|
1523
|
+
if (!shapes_match_out_(input, output) || !tensor_layout_supported_(input) || !tensor_layout_supported_(output))
|
|
1524
|
+
return false;
|
|
1525
|
+
if (input.empty()) return true;
|
|
1526
|
+
if (input.rank() >= 2) {
|
|
1527
|
+
for (std::size_t i = 0; i < input.extent(0); ++i) {
|
|
1528
|
+
auto idx = static_cast<std::ptrdiff_t>(i);
|
|
1529
|
+
if (!elementwise_into_<value_type_, max_rank_>(input.slice_leading(idx), output.slice_leading(idx), leaf))
|
|
1530
|
+
return false;
|
|
1531
|
+
}
|
|
1532
|
+
return true;
|
|
1533
|
+
}
|
|
1534
|
+
if (!can_apply_rank1_data_kernel_(input) || !can_apply_rank1_data_kernel_(output)) return false;
|
|
1535
|
+
leaf(input, output);
|
|
1536
|
+
return true;
|
|
1537
|
+
}
|
|
1538
|
+
|
|
1539
|
+
/** @brief Binary elementwise traversal: validates shapes, recurses on rank≥2, calls leaf on rank-1 slices. */
|
|
1540
|
+
template <typename value_type_, std::size_t max_rank_, typename leaf_fn_>
|
|
1541
|
+
bool elementwise_into_(tensor_view<value_type_, max_rank_> lhs, tensor_view<value_type_, max_rank_> rhs,
|
|
1542
|
+
tensor_span<value_type_, max_rank_> output, leaf_fn_ &&leaf) noexcept {
|
|
1543
|
+
if (!shapes_match_(lhs, rhs) || !shapes_match_out_(lhs, output) || !tensor_layout_supported_(lhs) ||
|
|
1544
|
+
!tensor_layout_supported_(rhs) || !tensor_layout_supported_(output))
|
|
1545
|
+
return false;
|
|
1546
|
+
if (lhs.empty()) return true;
|
|
1547
|
+
if (lhs.rank() >= 2) {
|
|
1548
|
+
for (std::size_t i = 0; i < lhs.extent(0); ++i) {
|
|
1549
|
+
auto idx = static_cast<std::ptrdiff_t>(i);
|
|
1550
|
+
if (!elementwise_into_<value_type_, max_rank_>(lhs.slice_leading(idx), rhs.slice_leading(idx),
|
|
1551
|
+
output.slice_leading(idx), leaf))
|
|
1552
|
+
return false;
|
|
1553
|
+
}
|
|
1554
|
+
return true;
|
|
1555
|
+
}
|
|
1556
|
+
if (!can_apply_rank1_data_kernel_(lhs) || !can_apply_rank1_data_kernel_(rhs) ||
|
|
1557
|
+
!can_apply_rank1_data_kernel_(output))
|
|
1558
|
+
return false;
|
|
1559
|
+
leaf(lhs, rhs, output);
|
|
1560
|
+
return true;
|
|
1561
|
+
}
|
|
1562
|
+
|
|
1563
|
+
/** @brief Ternary elementwise traversal: validates shapes, recurses on rank≥2, calls leaf on rank-1 slices. */
|
|
1564
|
+
template <typename value_type_, std::size_t max_rank_, typename leaf_fn_>
|
|
1565
|
+
bool elementwise_into_(tensor_view<value_type_, max_rank_> a, tensor_view<value_type_, max_rank_> b,
|
|
1566
|
+
tensor_view<value_type_, max_rank_> c, tensor_span<value_type_, max_rank_> output,
|
|
1567
|
+
leaf_fn_ &&leaf) noexcept {
|
|
1568
|
+
if (!shapes_match_(a, b) || !shapes_match_(a, c) || !shapes_match_out_(a, output) || !tensor_layout_supported_(a) ||
|
|
1569
|
+
!tensor_layout_supported_(b) || !tensor_layout_supported_(c) || !tensor_layout_supported_(output))
|
|
1570
|
+
return false;
|
|
1571
|
+
if (a.empty()) return true;
|
|
1572
|
+
if (a.rank() >= 2) {
|
|
1573
|
+
for (std::size_t i = 0; i < a.extent(0); ++i) {
|
|
1574
|
+
auto idx = static_cast<std::ptrdiff_t>(i);
|
|
1575
|
+
if (!elementwise_into_<value_type_, max_rank_>(a.slice_leading(idx), b.slice_leading(idx),
|
|
1576
|
+
c.slice_leading(idx), output.slice_leading(idx), leaf))
|
|
1577
|
+
return false;
|
|
1578
|
+
}
|
|
1579
|
+
return true;
|
|
1580
|
+
}
|
|
1581
|
+
if (!can_apply_rank1_data_kernel_(a) || !can_apply_rank1_data_kernel_(b) || !can_apply_rank1_data_kernel_(c) ||
|
|
1582
|
+
!can_apply_rank1_data_kernel_(output))
|
|
1583
|
+
return false;
|
|
1584
|
+
leaf(a, b, c, output);
|
|
1585
|
+
return true;
|
|
1586
|
+
}
|
|
1587
|
+
|
|
1588
|
+
#pragma endregion - Helpers
|
|
1589
|
+
|
|
1590
|
+
} // namespace ashvardanian::numkong
|
|
1591
|
+
|
|
1592
|
+
#endif // NK_TENSOR_HPP
|