gstaichi 0.1.21.dev0__cp310-cp310-win_amd64.whl → 0.1.25.dev0__cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +9 -0
- {taichi → gstaichi}/__init__.py +9 -13
- {taichi → gstaichi}/_funcs.py +8 -8
- {taichi → gstaichi}/_kernels.py +19 -19
- gstaichi/_lib/__init__.py +3 -0
- taichi/_lib/core/taichi_python.cp310-win_amd64.pyd → gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
- taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -522
- {taichi → gstaichi}/_lib/runtime/runtime_cuda.bc +0 -0
- {taichi → gstaichi}/_lib/runtime/runtime_x64.bc +0 -0
- {taichi → gstaichi}/_lib/utils.py +15 -15
- {taichi → gstaichi}/_logging.py +1 -1
- {taichi → gstaichi}/_main.py +24 -31
- gstaichi/_snode/__init__.py +5 -0
- {taichi → gstaichi}/_snode/fields_builder.py +27 -29
- {taichi → gstaichi}/_snode/snode_tree.py +5 -5
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- {taichi → gstaichi}/_version_check.py +8 -5
- gstaichi/ad/__init__.py +3 -0
- {taichi → gstaichi}/ad/_ad.py +26 -26
- {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
- {taichi → gstaichi}/examples/minimal.py +1 -1
- {taichi → gstaichi}/experimental.py +1 -1
- gstaichi/lang/__init__.py +50 -0
- {taichi → gstaichi}/lang/_ndarray.py +30 -26
- {taichi → gstaichi}/lang/_ndrange.py +8 -8
- gstaichi/lang/_template_mapper.py +199 -0
- {taichi → gstaichi}/lang/_texture.py +19 -19
- {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
- {taichi → gstaichi}/lang/any_array.py +13 -13
- {taichi → gstaichi}/lang/argpack.py +29 -29
- gstaichi/lang/ast/__init__.py +5 -0
- {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
- {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- {taichi → gstaichi}/lang/ast/checkers.py +5 -5
- gstaichi/lang/ast/transform.py +9 -0
- {taichi → gstaichi}/lang/common_ops.py +12 -12
- gstaichi/lang/exception.py +80 -0
- {taichi → gstaichi}/lang/expr.py +22 -22
- {taichi → gstaichi}/lang/field.py +29 -27
- {taichi → gstaichi}/lang/impl.py +116 -121
- {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
- {taichi → gstaichi}/lang/kernel_impl.py +330 -363
- {taichi → gstaichi}/lang/matrix.py +119 -115
- {taichi → gstaichi}/lang/matrix_ops.py +6 -6
- {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
- {taichi → gstaichi}/lang/mesh.py +22 -22
- {taichi → gstaichi}/lang/misc.py +39 -68
- {taichi → gstaichi}/lang/ops.py +146 -141
- {taichi → gstaichi}/lang/runtime_ops.py +2 -2
- {taichi → gstaichi}/lang/shell.py +3 -3
- {taichi → gstaichi}/lang/simt/__init__.py +1 -1
- {taichi → gstaichi}/lang/simt/block.py +7 -7
- {taichi → gstaichi}/lang/simt/grid.py +1 -1
- {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
- {taichi → gstaichi}/lang/simt/warp.py +1 -1
- {taichi → gstaichi}/lang/snode.py +46 -44
- {taichi → gstaichi}/lang/source_builder.py +13 -13
- {taichi → gstaichi}/lang/struct.py +33 -33
- {taichi → gstaichi}/lang/util.py +24 -24
- gstaichi/linalg/__init__.py +8 -0
- {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
- {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
- {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
- {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
- {taichi → gstaichi}/math/__init__.py +1 -1
- {taichi → gstaichi}/math/_complex.py +21 -20
- {taichi → gstaichi}/math/mathimpl.py +56 -56
- gstaichi/profiler/__init__.py +6 -0
- {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
- {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
- {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
- {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
- {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
- {taichi → gstaichi}/tools/__init__.py +4 -4
- {taichi → gstaichi}/tools/diagnose.py +10 -17
- gstaichi/types/__init__.py +19 -0
- {taichi → gstaichi}/types/annotations.py +1 -1
- {taichi → gstaichi}/types/compound_types.py +8 -8
- {taichi → gstaichi}/types/enums.py +1 -1
- {taichi → gstaichi}/types/ndarray_type.py +7 -7
- {taichi → gstaichi}/types/primitive_types.py +17 -14
- {taichi → gstaichi}/types/quant.py +9 -9
- {taichi → gstaichi}/types/texture_type.py +5 -5
- {taichi → gstaichi}/types/utils.py +1 -1
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-link.lib +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools.lib +0 -0
- {gstaichi-0.1.21.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/METADATA +13 -16
- gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
- gstaichi-0.1.21.dev0.data/data/include/GLFW/glfw3.h +0 -6389
- gstaichi-0.1.21.dev0.data/data/include/GLFW/glfw3native.h +0 -594
- gstaichi-0.1.21.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +0 -3
- gstaichi-0.1.21.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -65
- gstaichi-0.1.21.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -19
- gstaichi-0.1.21.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -107
- gstaichi-0.1.21.dev0.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.1.21.dev0.dist-info/RECORD +0 -198
- gstaichi-0.1.21.dev0.dist-info/entry_points.txt +0 -2
- gstaichi-0.1.21.dev0.dist-info/top_level.txt +0 -1
- taichi/CHANGELOG.md +0 -17
- taichi/_lib/__init__.py +0 -3
- taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
- taichi/_lib/c_api/include/taichi/taichi.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
- taichi/_lib/c_api/include/taichi/taichi_cuda.h +0 -36
- taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
- taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
- taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
- taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
- taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- taichi/_snode/__init__.py +0 -5
- taichi/_ti_module/__init__.py +0 -3
- taichi/_ti_module/cppgen.py +0 -309
- taichi/_ti_module/module.py +0 -145
- taichi/_version.py +0 -1
- taichi/ad/__init__.py +0 -3
- taichi/aot/__init__.py +0 -12
- taichi/aot/_export.py +0 -28
- taichi/aot/conventions/__init__.py +0 -3
- taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
- taichi/aot/conventions/gfxruntime140/dr.py +0 -244
- taichi/aot/conventions/gfxruntime140/sr.py +0 -613
- taichi/aot/module.py +0 -253
- taichi/aot/utils.py +0 -151
- taichi/graph/__init__.py +0 -3
- taichi/graph/_graph.py +0 -292
- taichi/lang/__init__.py +0 -50
- taichi/lang/ast/__init__.py +0 -5
- taichi/lang/ast/transform.py +0 -9
- taichi/lang/exception.py +0 -80
- taichi/linalg/__init__.py +0 -8
- taichi/profiler/__init__.py +0 -6
- taichi/shaders/Circles_vk.frag +0 -29
- taichi/shaders/Circles_vk.vert +0 -45
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +0 -9
- taichi/shaders/Lines_vk.vert +0 -11
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +0 -71
- taichi/shaders/Mesh_vk.vert +0 -68
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +0 -95
- taichi/shaders/Particles_vk.vert +0 -73
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +0 -9
- taichi/shaders/SceneLines_vk.vert +0 -12
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +0 -21
- taichi/shaders/SetImage_vk.vert +0 -15
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +0 -16
- taichi/shaders/Triangles_vk.vert +0 -29
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/types/__init__.py +0 -19
- {taichi → gstaichi}/__main__.py +0 -0
- {taichi → gstaichi}/_lib/core/__init__.py +0 -0
- {taichi → gstaichi}/_lib/core/py.typed +0 -0
- {taichi/_lib/c_api → gstaichi/_lib}/runtime/slim_libdevice.10.bc +0 -0
- {taichi → gstaichi}/algorithms/__init__.py +0 -0
- {taichi → gstaichi}/assets/.git +0 -0
- {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
- {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
- {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
- {taichi → gstaichi}/sparse/__init__.py +0 -0
- {taichi → gstaichi}/tools/np2ply.py +0 -0
- {taichi → gstaichi}/tools/vtk.py +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/instrument.hpp +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.h +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/linker.hpp +0 -0
- {gstaichi-0.1.21.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
- {gstaichi-0.1.21.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/WHEEL +0 -0
- {gstaichi-0.1.21.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -1,1401 +0,0 @@
|
|
1
|
-
// C++ wrapper of Taichi C-API
|
2
|
-
#pragma once
|
3
|
-
#include <iostream>
|
4
|
-
#include <algorithm>
|
5
|
-
#include <cassert>
|
6
|
-
#include <cstddef>
|
7
|
-
#include <cstring>
|
8
|
-
#include <cassert>
|
9
|
-
#include <list>
|
10
|
-
#include <vector>
|
11
|
-
#include <map>
|
12
|
-
#include <string>
|
13
|
-
#include <utility>
|
14
|
-
#include <taichi/taichi.h>
|
15
|
-
|
16
|
-
namespace ti {
|
17
|
-
|
18
|
-
/// Version number packed into one integer as
/// `(major * 1000 + minor) * 1000 + patch`.
///
/// Every constructor and accessor is `constexpr`, so a `Version` can be built
/// and inspected in constant expressions (for example inside `static_assert`).
/// Copy and move operations are the implicitly generated ones.
struct Version {
  /// The packed version number.
  uint32_t version;

  /// Wraps an already packed version number.
  constexpr explicit Version(uint32_t version) noexcept : version(version) {
  }

  /// Packs the three components into one number. The components are not range
  /// checked: a `minor` or `patch` of 1000 or more spills into the next higher
  /// component.
  constexpr Version(uint32_t major, uint32_t minor, uint32_t patch) noexcept
      : version((major * 1000 + minor) * 1000 + patch) {
  }

  /// Returns the major component (everything above the low six decimal
  /// digits).
  constexpr uint32_t major() const noexcept {
    return version / 1000000;
  }

  /// Returns the minor component (decimal digits four to six from the right).
  constexpr uint32_t minor() const noexcept {
    return (version / 1000) % 1000;
  }

  /// Returns the patch component (the low three decimal digits).
  constexpr uint32_t patch() const noexcept {
    return version % 1000;
  }
};
|
41
|
-
/// Returns the value reported by `ti_get_version()` wrapped in a `Version`.
inline Version get_version() {
  const uint32_t packed = ti_get_version();
  return Version(packed);
}
|
44
|
-
|
45
|
-
inline std::vector<TiArch> get_available_archs() {
|
46
|
-
uint32_t narch = 0;
|
47
|
-
ti_get_available_archs(&narch, nullptr);
|
48
|
-
std::vector<TiArch> archs(narch);
|
49
|
-
ti_get_available_archs(&narch, archs.data());
|
50
|
-
return archs;
|
51
|
-
}
|
52
|
-
inline std::vector<TiArch> get_available_archs(
|
53
|
-
const std::vector<TiArch> &expect_archs) {
|
54
|
-
std::vector<TiArch> actual_archs = get_available_archs();
|
55
|
-
std::vector<TiArch> out_archs;
|
56
|
-
for (TiArch arch : actual_archs) {
|
57
|
-
auto it = std::find(expect_archs.begin(), expect_archs.end(), arch);
|
58
|
-
if (it != expect_archs.end()) {
|
59
|
-
out_archs.emplace_back(arch);
|
60
|
-
}
|
61
|
-
}
|
62
|
-
return out_archs;
|
63
|
-
}
|
64
|
-
/// Returns true when `arch` is among the architectures reported by
/// `get_available_archs()`.
inline bool is_arch_available(TiArch arch) {
  const std::vector<TiArch> available = get_available_archs();
  return std::find(available.begin(), available.end(), arch) !=
         available.end();
}
|
73
|
-
|
74
|
-
struct Error {
|
75
|
-
TiError error;
|
76
|
-
std::string message;
|
77
|
-
|
78
|
-
inline operator TiError() const { // NOLINT
|
79
|
-
return error;
|
80
|
-
}
|
81
|
-
};
|
82
|
-
|
83
|
-
/// Fetches the error code and message reported by `ti_get_last_error`.
///
/// The C function is called twice: first with a null buffer to learn the
/// message size, then with a buffer of that size. The final character of the
/// buffer is dropped (presumably the NUL terminator counted by the reported
/// size - confirm against the C API).
///
/// @return The error code from the second call and the message text. The
///         message is empty when the reported size is zero.
inline Error get_last_error() {
  uint64_t message_size = 0;
  ti_get_last_error(&message_size, nullptr);

  std::string message(static_cast<size_t>(message_size), '\0');
  // `&message[0]` is valid even for an empty string (it refers to the
  // terminating null character), so no special case is needed here.
  TiError error = ti_get_last_error(&message_size, &message[0]);

  // A reported size of zero leaves nothing to strip. Without this guard
  // `size() - 1` wraps around to SIZE_MAX and `resize` throws
  // `std::length_error`.
  if (!message.empty()) {
    message.resize(message.size() - 1);
  }
  return Error{error, std::move(message)};
}
|
91
|
-
inline void check_last_error() {
|
92
|
-
Error error = get_last_error();
|
93
|
-
if (error != TI_ERROR_SUCCESS) {
|
94
|
-
#ifdef TI_WITH_EXCEPTIONS
|
95
|
-
throw std::runtime_error(error.message);
|
96
|
-
#else
|
97
|
-
assert(false);
|
98
|
-
#endif // TI_WITH_EXCEPTIONS
|
99
|
-
}
|
100
|
-
}
|
101
|
-
inline void set_last_error(TiError error) {
|
102
|
-
ti_set_last_error(error, nullptr);
|
103
|
-
}
|
104
|
-
inline void set_last_error(TiError error, const std::string &message) {
|
105
|
-
ti_set_last_error(error, message.c_str());
|
106
|
-
}
|
107
|
-
inline void set_last_error(const Error &error) {
|
108
|
-
set_last_error(error.error, error.message);
|
109
|
-
}
|
110
|
-
|
111
|
-
namespace detail {
|
112
|
-
|
113
|
-
// Template type to data type enum.
|
114
|
-
template <typename T>
|
115
|
-
struct templ2dtype {};
|
116
|
-
template <>
|
117
|
-
struct templ2dtype<int8_t> {
|
118
|
-
static const TiDataType value = TI_DATA_TYPE_I8;
|
119
|
-
};
|
120
|
-
template <>
|
121
|
-
struct templ2dtype<int16_t> {
|
122
|
-
static const TiDataType value = TI_DATA_TYPE_I16;
|
123
|
-
};
|
124
|
-
template <>
|
125
|
-
struct templ2dtype<int32_t> {
|
126
|
-
static const TiDataType value = TI_DATA_TYPE_I32;
|
127
|
-
};
|
128
|
-
template <>
|
129
|
-
struct templ2dtype<uint8_t> {
|
130
|
-
static const TiDataType value = TI_DATA_TYPE_U8;
|
131
|
-
};
|
132
|
-
template <>
|
133
|
-
struct templ2dtype<uint16_t> {
|
134
|
-
static const TiDataType value = TI_DATA_TYPE_U16;
|
135
|
-
};
|
136
|
-
template <>
|
137
|
-
struct templ2dtype<uint32_t> {
|
138
|
-
static const TiDataType value = TI_DATA_TYPE_U32;
|
139
|
-
};
|
140
|
-
template <>
|
141
|
-
struct templ2dtype<float> {
|
142
|
-
static const TiDataType value = TI_DATA_TYPE_F32;
|
143
|
-
};
|
144
|
-
template <>
|
145
|
-
struct templ2dtype<double> {
|
146
|
-
static const TiDataType value = TI_DATA_TYPE_F64;
|
147
|
-
};
|
148
|
-
|
149
|
-
// Stores `value` (converted to `T`) into `storage` and returns the value
// `storage` held before.
//
// `value` is a forwarding reference, so it is passed on with `std::forward`:
// an rvalue argument is moved from, while an lvalue argument is copied and
// left untouched. The C-style cast is kept on purpose so that every conversion
// accepted by existing call sites keeps compiling.
template <typename T, typename U>
T exchange(T &storage, U &&value) {
  T previous = std::move(storage);
  storage = (T)std::forward<U>(value);
  return previous;
}
|
155
|
-
|
156
|
-
template <typename THandle>
|
157
|
-
THandle move_handle(THandle &handle) {
|
158
|
-
THandle out = std::move(handle);
|
159
|
-
handle = TI_NULL_HANDLE;
|
160
|
-
return out;
|
161
|
-
}
|
162
|
-
|
163
|
-
} // namespace detail
|
164
|
-
|
165
|
-
class MemorySlice {
|
166
|
-
TiRuntime runtime_{TI_NULL_HANDLE};
|
167
|
-
TiMemorySlice slice_{};
|
168
|
-
|
169
|
-
public:
|
170
|
-
MemorySlice() = default;
|
171
|
-
MemorySlice(TiRuntime runtime, const TiMemorySlice &slice)
|
172
|
-
: runtime_(runtime), slice_(slice) {
|
173
|
-
}
|
174
|
-
MemorySlice(const MemorySlice &) = default;
|
175
|
-
MemorySlice(MemorySlice &&) = default;
|
176
|
-
MemorySlice &operator=(const MemorySlice &) = default;
|
177
|
-
MemorySlice &operator=(MemorySlice &&) = default;
|
178
|
-
|
179
|
-
inline void copy_to(const MemorySlice &dst) const {
|
180
|
-
if (runtime_ != dst.runtime_) {
|
181
|
-
ti_set_last_error(
|
182
|
-
TI_ERROR_INVALID_ARGUMENT,
|
183
|
-
"cannot copy device memory between different runtime instances");
|
184
|
-
return;
|
185
|
-
}
|
186
|
-
if (slice_.size != dst.slice_.size) {
|
187
|
-
ti_set_last_error(
|
188
|
-
TI_ERROR_INVALID_ARGUMENT,
|
189
|
-
"copy source and destination slice must have the same size");
|
190
|
-
return;
|
191
|
-
}
|
192
|
-
ti_copy_memory_device_to_device(runtime_, &dst.slice_, &slice_);
|
193
|
-
}
|
194
|
-
|
195
|
-
inline TiMemory memory() const {
|
196
|
-
return slice_.memory;
|
197
|
-
}
|
198
|
-
inline uint64_t offset() const {
|
199
|
-
return slice_.offset;
|
200
|
-
}
|
201
|
-
inline uint64_t size() const {
|
202
|
-
return slice_.size;
|
203
|
-
}
|
204
|
-
inline const TiMemorySlice &slice() const {
|
205
|
-
return slice_;
|
206
|
-
}
|
207
|
-
inline operator const TiMemorySlice &() const { // NOLINT
|
208
|
-
return slice_;
|
209
|
-
}
|
210
|
-
};
|
211
|
-
|
212
|
-
class Memory {
|
213
|
-
protected:
|
214
|
-
TiRuntime runtime_{TI_NULL_HANDLE};
|
215
|
-
TiMemory memory_{TI_NULL_HANDLE};
|
216
|
-
size_t size_{0};
|
217
|
-
bool should_destroy_{false};
|
218
|
-
|
219
|
-
public:
|
220
|
-
constexpr bool is_valid() const {
|
221
|
-
return runtime_ != nullptr;
|
222
|
-
}
|
223
|
-
inline void destroy() {
|
224
|
-
if (should_destroy_) {
|
225
|
-
ti_free_memory(runtime_, memory_);
|
226
|
-
memory_ = TI_NULL_HANDLE;
|
227
|
-
should_destroy_ = false;
|
228
|
-
}
|
229
|
-
}
|
230
|
-
|
231
|
-
Memory() {
|
232
|
-
}
|
233
|
-
Memory(const Memory &) = delete;
|
234
|
-
Memory(Memory &&b)
|
235
|
-
: runtime_(detail::move_handle(b.runtime_)),
|
236
|
-
memory_(detail::move_handle(b.memory_)),
|
237
|
-
size_(detail::exchange(b.size_, 0)),
|
238
|
-
should_destroy_(detail::exchange(b.should_destroy_, false)) {
|
239
|
-
}
|
240
|
-
Memory(TiRuntime runtime, TiMemory memory, size_t size, bool should_destroy)
|
241
|
-
: runtime_(runtime),
|
242
|
-
memory_(memory),
|
243
|
-
size_(size),
|
244
|
-
should_destroy_(should_destroy) {
|
245
|
-
}
|
246
|
-
~Memory() {
|
247
|
-
destroy();
|
248
|
-
}
|
249
|
-
|
250
|
-
Memory &operator=(const Memory &) = delete;
|
251
|
-
Memory &operator=(Memory &&b) {
|
252
|
-
destroy();
|
253
|
-
runtime_ = detail::move_handle(b.runtime_);
|
254
|
-
memory_ = detail::move_handle(b.memory_);
|
255
|
-
size_ = detail::exchange(b.size_, 0);
|
256
|
-
should_destroy_ = detail::exchange(b.should_destroy_, false);
|
257
|
-
return *this;
|
258
|
-
}
|
259
|
-
|
260
|
-
inline Memory borrow() const {
|
261
|
-
return Memory(runtime_, memory_, size_, false);
|
262
|
-
}
|
263
|
-
|
264
|
-
void *map() const {
|
265
|
-
return ti_map_memory(runtime_, memory_);
|
266
|
-
}
|
267
|
-
void unmap() const {
|
268
|
-
ti_unmap_memory(runtime_, memory_);
|
269
|
-
}
|
270
|
-
|
271
|
-
inline void read(void *dst, size_t size) const {
|
272
|
-
void *src = map();
|
273
|
-
if (src != nullptr) {
|
274
|
-
std::memcpy(dst, src, size);
|
275
|
-
}
|
276
|
-
unmap();
|
277
|
-
}
|
278
|
-
inline void write(const void *src, size_t size) const {
|
279
|
-
void *dst = map();
|
280
|
-
if (dst != nullptr) {
|
281
|
-
std::memcpy(dst, src, size);
|
282
|
-
}
|
283
|
-
unmap();
|
284
|
-
}
|
285
|
-
|
286
|
-
inline void copy_to(const ti::Memory &dst) const {
|
287
|
-
slice().copy_to(dst.slice());
|
288
|
-
}
|
289
|
-
|
290
|
-
inline MemorySlice slice(size_t offset, size_t size) const {
|
291
|
-
if (offset + size > size_) {
|
292
|
-
ti_set_last_error(TI_ERROR_ARGUMENT_OUT_OF_RANGE, "size");
|
293
|
-
return {};
|
294
|
-
}
|
295
|
-
TiMemorySlice slice{};
|
296
|
-
slice.memory = memory_;
|
297
|
-
slice.offset = offset;
|
298
|
-
slice.size = size;
|
299
|
-
return MemorySlice(runtime_, slice);
|
300
|
-
}
|
301
|
-
inline MemorySlice slice() const {
|
302
|
-
return slice(0, size_);
|
303
|
-
}
|
304
|
-
|
305
|
-
constexpr size_t size() const {
|
306
|
-
return size_;
|
307
|
-
}
|
308
|
-
constexpr TiMemory memory() const {
|
309
|
-
return memory_;
|
310
|
-
}
|
311
|
-
constexpr operator TiMemory() const { // NOLINT
|
312
|
-
return memory_;
|
313
|
-
}
|
314
|
-
};
|
315
|
-
|
316
|
-
template <typename T>
|
317
|
-
class NdArray {
|
318
|
-
protected:
|
319
|
-
Memory memory_{};
|
320
|
-
TiNdArray ndarray_{};
|
321
|
-
size_t elem_count_{};
|
322
|
-
size_t scalar_count_{};
|
323
|
-
|
324
|
-
public:
|
325
|
-
constexpr bool is_valid() const {
|
326
|
-
return memory_.is_valid();
|
327
|
-
}
|
328
|
-
inline void destroy() {
|
329
|
-
memory_.destroy();
|
330
|
-
}
|
331
|
-
|
332
|
-
NdArray() : elem_count_(1), scalar_count_(1) {
|
333
|
-
}
|
334
|
-
NdArray(const NdArray<T> &) = delete;
|
335
|
-
NdArray(NdArray<T> &&b)
|
336
|
-
: memory_(std::move(b.memory_)),
|
337
|
-
ndarray_(detail::exchange(b.ndarray_, TiNdArray{})),
|
338
|
-
elem_count_(detail::exchange(b.elem_count_, 1)),
|
339
|
-
scalar_count_(detail::exchange(b.scalar_count_, 1)) {
|
340
|
-
}
|
341
|
-
NdArray(Memory &&memory, const TiNdArray &ndarray)
|
342
|
-
: memory_(std::move(memory)),
|
343
|
-
ndarray_(ndarray),
|
344
|
-
elem_count_(1),
|
345
|
-
scalar_count_(1) {
|
346
|
-
if (ndarray.memory != memory_) {
|
347
|
-
ti_set_last_error(TI_ERROR_INVALID_ARGUMENT, "ndarray.memory != memory");
|
348
|
-
}
|
349
|
-
for (uint32_t i = 0; i < ndarray_.shape.dim_count; ++i) {
|
350
|
-
elem_count_ *= ndarray_.shape.dims[i];
|
351
|
-
}
|
352
|
-
scalar_count_ *= elem_count_;
|
353
|
-
for (uint32_t i = 0; i < ndarray_.elem_shape.dim_count; ++i) {
|
354
|
-
scalar_count_ *= ndarray_.elem_shape.dims[i];
|
355
|
-
}
|
356
|
-
}
|
357
|
-
~NdArray() {
|
358
|
-
destroy();
|
359
|
-
}
|
360
|
-
|
361
|
-
NdArray<T> &operator=(const NdArray<T> &) = delete;
|
362
|
-
NdArray<T> &operator=(NdArray<T> &&b) {
|
363
|
-
destroy();
|
364
|
-
memory_ = std::move(b.memory_);
|
365
|
-
ndarray_ = detail::exchange(b.ndarray_, TiNdArray{});
|
366
|
-
elem_count_ = detail::exchange(b.elem_count_, 1);
|
367
|
-
scalar_count_ = detail::exchange(b.scalar_count_, 1);
|
368
|
-
return *this;
|
369
|
-
}
|
370
|
-
|
371
|
-
inline NdArray<T> borrow() const {
|
372
|
-
return NdArray<T>(memory_.borrow(), ndarray_);
|
373
|
-
}
|
374
|
-
|
375
|
-
inline void *map() const {
|
376
|
-
return memory_.map();
|
377
|
-
}
|
378
|
-
inline void unmap() const {
|
379
|
-
return memory_.unmap();
|
380
|
-
}
|
381
|
-
|
382
|
-
inline size_t scalar_count() const {
|
383
|
-
return scalar_count_;
|
384
|
-
}
|
385
|
-
inline size_t elem_count() const {
|
386
|
-
return elem_count_;
|
387
|
-
}
|
388
|
-
|
389
|
-
inline void read(T *dst, size_t count) const {
|
390
|
-
if (count > scalar_count_) {
|
391
|
-
ti_set_last_error(
|
392
|
-
TI_ERROR_ARGUMENT_OUT_OF_RANGE,
|
393
|
-
"ndarray read ouf of range; please ensure you specified the correct "
|
394
|
-
"number of elements (rather than size-in-bytes) to be read");
|
395
|
-
return;
|
396
|
-
}
|
397
|
-
memory_.read(dst, count * sizeof(T));
|
398
|
-
}
|
399
|
-
inline void read(std::vector<T> &dst) const {
|
400
|
-
read(dst.data(), dst.size());
|
401
|
-
}
|
402
|
-
template <typename U>
|
403
|
-
inline void read(std::vector<U> &dst) const {
|
404
|
-
static_assert(sizeof(U) % sizeof(T) == 0,
|
405
|
-
"sizeof(U) must be a multiple of sizeof(T)");
|
406
|
-
read((T *)dst.data(), dst.size() * (sizeof(U) / sizeof(T)));
|
407
|
-
}
|
408
|
-
inline void write(const T *src, size_t count) const {
|
409
|
-
if (count > scalar_count_) {
|
410
|
-
ti_set_last_error(
|
411
|
-
TI_ERROR_ARGUMENT_OUT_OF_RANGE,
|
412
|
-
"ndarray write ouf of range; please ensure you specified the correct "
|
413
|
-
"number of elements (rather than size-in-bytes) to be written");
|
414
|
-
return;
|
415
|
-
}
|
416
|
-
memory_.write(src, count * sizeof(T));
|
417
|
-
}
|
418
|
-
inline void write(const std::vector<T> &src) const {
|
419
|
-
write(src.data(), src.size());
|
420
|
-
}
|
421
|
-
template <typename U>
|
422
|
-
inline void write(const std::vector<U> &src) const {
|
423
|
-
static_assert(sizeof(U) % sizeof(T) == 0,
|
424
|
-
"sizeof(U) must be a multiple of sizeof(T)");
|
425
|
-
write((const T *)src.data(), src.size() * (sizeof(U) / sizeof(T)));
|
426
|
-
}
|
427
|
-
|
428
|
-
template <typename U>
|
429
|
-
inline void copy_to(const ti::NdArray<U> &dst) const {
|
430
|
-
memory().copy_to(dst.memory());
|
431
|
-
}
|
432
|
-
|
433
|
-
inline MemorySlice slice() const {
|
434
|
-
return memory_.slice();
|
435
|
-
}
|
436
|
-
|
437
|
-
constexpr TiDataType elem_type() const {
|
438
|
-
return ndarray_.elem_type;
|
439
|
-
}
|
440
|
-
constexpr const TiNdShape &shape() const {
|
441
|
-
return ndarray_.shape;
|
442
|
-
}
|
443
|
-
constexpr const TiNdShape &elem_shape() const {
|
444
|
-
return ndarray_.elem_shape;
|
445
|
-
}
|
446
|
-
constexpr const Memory &memory() const {
|
447
|
-
return memory_;
|
448
|
-
}
|
449
|
-
constexpr const TiNdArray &ndarray() const {
|
450
|
-
return ndarray_;
|
451
|
-
}
|
452
|
-
constexpr operator TiNdArray() const { // NOLINT
|
453
|
-
return ndarray_;
|
454
|
-
}
|
455
|
-
};
|
456
|
-
|
457
|
-
class ImageSlice {
|
458
|
-
TiRuntime runtime_{TI_NULL_HANDLE};
|
459
|
-
TiImageSlice slice_{};
|
460
|
-
|
461
|
-
public:
|
462
|
-
ImageSlice() = default;
|
463
|
-
ImageSlice(TiRuntime runtime, const TiImageSlice &slice)
|
464
|
-
: runtime_(runtime), slice_(slice) {
|
465
|
-
}
|
466
|
-
ImageSlice(const ImageSlice &) = default;
|
467
|
-
ImageSlice(ImageSlice &&) = default;
|
468
|
-
ImageSlice &operator=(const ImageSlice &) = default;
|
469
|
-
ImageSlice &operator=(ImageSlice &&) = default;
|
470
|
-
|
471
|
-
inline void copy_to(const ImageSlice &dst) const {
|
472
|
-
if (runtime_ != dst.runtime_) {
|
473
|
-
ti_set_last_error(
|
474
|
-
TI_ERROR_INVALID_ARGUMENT,
|
475
|
-
"cannot copy device memory between different runtime instances");
|
476
|
-
return;
|
477
|
-
}
|
478
|
-
if (slice_.extent.width != dst.slice_.extent.width ||
|
479
|
-
slice_.extent.height != dst.slice_.extent.height ||
|
480
|
-
slice_.extent.depth != dst.slice_.extent.depth ||
|
481
|
-
slice_.extent.array_layer_count !=
|
482
|
-
dst.slice_.extent.array_layer_count) {
|
483
|
-
ti_set_last_error(
|
484
|
-
TI_ERROR_INVALID_ARGUMENT,
|
485
|
-
"copy source and destination slice must have the same size");
|
486
|
-
return;
|
487
|
-
}
|
488
|
-
ti_copy_image_device_to_device(runtime_, &dst.slice_, &slice_);
|
489
|
-
}
|
490
|
-
|
491
|
-
inline TiImage image() const {
|
492
|
-
return slice_.image;
|
493
|
-
}
|
494
|
-
inline const TiImageOffset &offset() const {
|
495
|
-
return slice_.offset;
|
496
|
-
}
|
497
|
-
inline const TiImageExtent &extent() const {
|
498
|
-
return slice_.extent;
|
499
|
-
}
|
500
|
-
inline const uint32_t &mip_level() const {
|
501
|
-
return slice_.mip_level;
|
502
|
-
}
|
503
|
-
inline const TiImageSlice &slice() const {
|
504
|
-
return slice_;
|
505
|
-
}
|
506
|
-
inline operator TiImageSlice() const { // NOLINT
|
507
|
-
return slice_;
|
508
|
-
}
|
509
|
-
};
|
510
|
-
|
511
|
-
class Image {
|
512
|
-
protected:
|
513
|
-
TiRuntime runtime_{TI_NULL_HANDLE};
|
514
|
-
TiImage image_{TI_NULL_HANDLE};
|
515
|
-
TiImageDimension dimension_{TI_IMAGE_DIMENSION_MAX_ENUM};
|
516
|
-
TiImageExtent extent_{0, 0, 0};
|
517
|
-
uint32_t mip_level_count_;
|
518
|
-
TiFormat format_{TI_FORMAT_UNKNOWN};
|
519
|
-
bool should_destroy_{false};
|
520
|
-
|
521
|
-
public:
|
522
|
-
constexpr bool is_valid() const {
|
523
|
-
return image_ != nullptr;
|
524
|
-
}
|
525
|
-
inline void destroy() {
|
526
|
-
if (should_destroy_) {
|
527
|
-
ti_free_image(runtime_, image_);
|
528
|
-
image_ = TI_NULL_HANDLE;
|
529
|
-
should_destroy_ = false;
|
530
|
-
}
|
531
|
-
}
|
532
|
-
|
533
|
-
Image() {
|
534
|
-
}
|
535
|
-
Image(const Image &b) = delete;
|
536
|
-
Image(Image &&b)
|
537
|
-
: runtime_(detail::move_handle(b.runtime_)),
|
538
|
-
image_(detail::move_handle(b.image_)),
|
539
|
-
dimension_(detail::exchange(b.dimension_, TI_IMAGE_DIMENSION_MAX_ENUM)),
|
540
|
-
extent_(detail::exchange(b.extent_, TiImageExtent{0, 0, 0})),
|
541
|
-
mip_level_count_(detail::exchange(b.mip_level_count_, 0)),
|
542
|
-
format_(detail::exchange(b.format_, TI_FORMAT_UNKNOWN)),
|
543
|
-
should_destroy_(detail::exchange(b.should_destroy_, false)) {
|
544
|
-
}
|
545
|
-
Image(TiRuntime runtime,
|
546
|
-
TiImage image,
|
547
|
-
TiImageDimension dimension,
|
548
|
-
const TiImageExtent &extent,
|
549
|
-
uint32_t mip_level_count,
|
550
|
-
TiFormat format,
|
551
|
-
bool should_destroy)
|
552
|
-
: runtime_(runtime),
|
553
|
-
image_(image),
|
554
|
-
dimension_(dimension),
|
555
|
-
extent_(extent),
|
556
|
-
mip_level_count_(mip_level_count),
|
557
|
-
format_(format),
|
558
|
-
should_destroy_(should_destroy) {
|
559
|
-
}
|
560
|
-
~Image() {
|
561
|
-
destroy();
|
562
|
-
}
|
563
|
-
|
564
|
-
Image &operator=(const Image &) = delete;
|
565
|
-
Image &operator=(Image &&b) {
|
566
|
-
destroy();
|
567
|
-
runtime_ = detail::move_handle(b.runtime_);
|
568
|
-
image_ = detail::move_handle(b.image_);
|
569
|
-
dimension_ = detail::exchange(b.dimension_, TI_IMAGE_DIMENSION_MAX_ENUM);
|
570
|
-
extent_ = detail::exchange(b.extent_, TiImageExtent{0, 0, 0});
|
571
|
-
mip_level_count_ = detail::exchange(b.mip_level_count_, 0);
|
572
|
-
format_ = detail::exchange(b.format_, TI_FORMAT_UNKNOWN);
|
573
|
-
should_destroy_ = detail::exchange(b.should_destroy_, false);
|
574
|
-
return *this;
|
575
|
-
}
|
576
|
-
|
577
|
-
inline Image borrow() const {
|
578
|
-
return Image(runtime_, image_, dimension_, extent_, mip_level_count_,
|
579
|
-
format_, false);
|
580
|
-
}
|
581
|
-
|
582
|
-
inline void copy_to(const Image &dst) const {
|
583
|
-
slice().copy_to(dst.slice());
|
584
|
-
}
|
585
|
-
|
586
|
-
inline void transition_to(TiImageLayout layout) const {
|
587
|
-
ti_transition_image(runtime_, image_, layout);
|
588
|
-
}
|
589
|
-
|
590
|
-
inline ImageSlice slice(const TiImageOffset &offset,
|
591
|
-
const TiImageExtent &extent,
|
592
|
-
uint32_t mip_level) const {
|
593
|
-
if (offset.x + extent.width > extent_.width ||
|
594
|
-
offset.y + extent.height > extent_.height ||
|
595
|
-
offset.z + extent.depth > extent_.depth ||
|
596
|
-
offset.array_layer_offset + extent.array_layer_count >
|
597
|
-
extent_.array_layer_count) {
|
598
|
-
ti_set_last_error(TI_ERROR_ARGUMENT_OUT_OF_RANGE, "extent");
|
599
|
-
return {};
|
600
|
-
}
|
601
|
-
TiImageSlice slice{};
|
602
|
-
slice.image = image_;
|
603
|
-
slice.extent = extent;
|
604
|
-
slice.offset = offset;
|
605
|
-
slice.mip_level = mip_level;
|
606
|
-
return ImageSlice(runtime_, slice);
|
607
|
-
}
|
608
|
-
inline ImageSlice slice() const {
|
609
|
-
return slice(TiImageOffset{}, extent_, 0);
|
610
|
-
}
|
611
|
-
|
612
|
-
constexpr TiImageDimension dimension() const {
|
613
|
-
return dimension_;
|
614
|
-
}
|
615
|
-
constexpr const TiImageExtent &extent() const {
|
616
|
-
return extent_;
|
617
|
-
}
|
618
|
-
constexpr uint32_t mip_level_count() const {
|
619
|
-
return mip_level_count_;
|
620
|
-
}
|
621
|
-
constexpr TiFormat format() const {
|
622
|
-
return format_;
|
623
|
-
}
|
624
|
-
constexpr TiImage image() const {
|
625
|
-
return image_;
|
626
|
-
}
|
627
|
-
constexpr operator TiImage() const { // NOLINT
|
628
|
-
return image_;
|
629
|
-
}
|
630
|
-
};
|
631
|
-
|
632
|
-
class Texture {
|
633
|
-
protected:
|
634
|
-
Image image_{};
|
635
|
-
TiTexture texture_{};
|
636
|
-
|
637
|
-
public:
|
638
|
-
constexpr bool is_valid() const {
|
639
|
-
return image_.is_valid();
|
640
|
-
}
|
641
|
-
inline void destroy() {
|
642
|
-
image_.destroy();
|
643
|
-
}
|
644
|
-
|
645
|
-
Texture() {
|
646
|
-
}
|
647
|
-
Texture(const Texture &b) = delete;
|
648
|
-
Texture(Texture &&b)
|
649
|
-
: image_(std::move(b.image_)), texture_(std::move(b.texture_)) {
|
650
|
-
}
|
651
|
-
Texture(Image &&image, const TiTexture &texture)
|
652
|
-
: image_(std::move(image)), texture_(texture) {
|
653
|
-
if (texture.image != image_.image()) {
|
654
|
-
ti_set_last_error(TI_ERROR_INVALID_ARGUMENT, "texture.image != image");
|
655
|
-
}
|
656
|
-
}
|
657
|
-
~Texture() {
|
658
|
-
destroy();
|
659
|
-
}
|
660
|
-
|
661
|
-
Texture &operator=(const Texture &) = delete;
|
662
|
-
Texture &operator=(Texture &&b) {
|
663
|
-
destroy();
|
664
|
-
image_ = std::move(b.image_);
|
665
|
-
texture_ = std::move(b.texture_);
|
666
|
-
return *this;
|
667
|
-
}
|
668
|
-
|
669
|
-
inline Texture borrow() const {
|
670
|
-
return Texture(image_.borrow(), texture_);
|
671
|
-
}
|
672
|
-
|
673
|
-
inline void copy_to(const Texture &dst) const {
|
674
|
-
slice().copy_to(dst.slice());
|
675
|
-
}
|
676
|
-
|
677
|
-
inline ImageSlice slice() const {
|
678
|
-
return image_.slice();
|
679
|
-
}
|
680
|
-
|
681
|
-
constexpr const Image &image() const {
|
682
|
-
return image_;
|
683
|
-
}
|
684
|
-
constexpr TiTexture texture() const {
|
685
|
-
return texture_;
|
686
|
-
}
|
687
|
-
constexpr operator TiTexture() const { // NOLINT
|
688
|
-
return texture_;
|
689
|
-
}
|
690
|
-
};
|
691
|
-
|
692
|
-
template <typename T>
|
693
|
-
struct DataTypeToEnum {
|
694
|
-
static constexpr TiDataType value = TI_DATA_TYPE_UNKNOWN;
|
695
|
-
};
|
696
|
-
#define DEFINE_DATA_TYPE_ENUM(type, enumv) \
|
697
|
-
template <> \
|
698
|
-
struct DataTypeToEnum<type> { \
|
699
|
-
static constexpr TiDataType value = TI_DATA_TYPE_##enumv; \
|
700
|
-
};
|
701
|
-
|
702
|
-
DEFINE_DATA_TYPE_ENUM(int32_t, I32);
|
703
|
-
DEFINE_DATA_TYPE_ENUM(float, F32);
|
704
|
-
DEFINE_DATA_TYPE_ENUM(uint16_t, U16);
|
705
|
-
DEFINE_DATA_TYPE_ENUM(int16_t, I16);
|
706
|
-
DEFINE_DATA_TYPE_ENUM(uint8_t, U8);
|
707
|
-
DEFINE_DATA_TYPE_ENUM(int8_t, I8);
|
708
|
-
DEFINE_DATA_TYPE_ENUM(uint64_t, U64);
|
709
|
-
DEFINE_DATA_TYPE_ENUM(int64_t, I64);
|
710
|
-
#undef DEFINE_DATA_TYPE_ENUM
|
711
|
-
|
712
|
-
class ArgumentEntry {
|
713
|
-
friend class ComputeGraph;
|
714
|
-
TiArgument *arg_;
|
715
|
-
|
716
|
-
public:
|
717
|
-
ArgumentEntry() = delete;
|
718
|
-
ArgumentEntry(const ArgumentEntry &) = delete;
|
719
|
-
ArgumentEntry(ArgumentEntry &&b) : arg_(b.arg_) {
|
720
|
-
}
|
721
|
-
explicit ArgumentEntry(TiArgument *arg) : arg_(arg) {
|
722
|
-
}
|
723
|
-
|
724
|
-
inline void set_f16(float value) {
|
725
|
-
arg_->type = TI_ARGUMENT_TYPE_SCALAR;
|
726
|
-
arg_->value.scalar.type = TI_DATA_TYPE_F16;
|
727
|
-
std::memcpy(&arg_->value.scalar.value.x32, &value, sizeof(value));
|
728
|
-
}
|
729
|
-
inline void set_u16(uint16_t value) {
|
730
|
-
arg_->type = TI_ARGUMENT_TYPE_SCALAR;
|
731
|
-
arg_->value.scalar.type = TI_DATA_TYPE_U16;
|
732
|
-
std::memcpy(&arg_->value.scalar.value.x16, &value, sizeof(value));
|
733
|
-
}
|
734
|
-
inline void set_i16(int16_t value) {
|
735
|
-
arg_->type = TI_ARGUMENT_TYPE_SCALAR;
|
736
|
-
arg_->value.scalar.type = TI_DATA_TYPE_I16;
|
737
|
-
std::memcpy(&arg_->value.scalar.value.x16, &value, sizeof(value));
|
738
|
-
}
|
739
|
-
|
740
|
-
inline ArgumentEntry &operator=(const TiArgument &b) {
|
741
|
-
*arg_ = b;
|
742
|
-
return *this;
|
743
|
-
}
|
744
|
-
inline ArgumentEntry &operator=(int32_t i32) {
|
745
|
-
arg_->type = TI_ARGUMENT_TYPE_I32;
|
746
|
-
arg_->value.i32 = i32;
|
747
|
-
return *this;
|
748
|
-
}
|
749
|
-
inline ArgumentEntry &operator=(float f32) {
|
750
|
-
arg_->type = TI_ARGUMENT_TYPE_F32;
|
751
|
-
arg_->value.f32 = f32;
|
752
|
-
return *this;
|
753
|
-
}
|
754
|
-
inline ArgumentEntry &operator=(uint16_t u16) {
|
755
|
-
this->set_u16(u16);
|
756
|
-
return *this;
|
757
|
-
}
|
758
|
-
inline ArgumentEntry &operator=(int16_t i16) {
|
759
|
-
this->set_i16(i16);
|
760
|
-
return *this;
|
761
|
-
}
|
762
|
-
inline ArgumentEntry &operator=(const TiNdArray &ndarray) {
|
763
|
-
arg_->type = TI_ARGUMENT_TYPE_NDARRAY;
|
764
|
-
arg_->value.ndarray = ndarray;
|
765
|
-
return *this;
|
766
|
-
}
|
767
|
-
inline ArgumentEntry &operator=(const TiTexture &texture) {
|
768
|
-
arg_->type = TI_ARGUMENT_TYPE_TEXTURE;
|
769
|
-
arg_->value.texture = texture;
|
770
|
-
return *this;
|
771
|
-
}
|
772
|
-
template <typename T>
|
773
|
-
inline ArgumentEntry &operator=(const std::vector<T> &matrix) {
|
774
|
-
arg_->type = TI_ARGUMENT_TYPE_TENSOR;
|
775
|
-
std::memcpy(arg_->value.tensor.contents.data.x8, matrix.data(),
|
776
|
-
matrix.size() * sizeof(T));
|
777
|
-
arg_->value.tensor.contents.length = matrix.size();
|
778
|
-
arg_->value.tensor.type = DataTypeToEnum<T>::value;
|
779
|
-
return *this;
|
780
|
-
}
|
781
|
-
template <typename T>
|
782
|
-
inline ArgumentEntry &operator=(const std::vector<std::vector<T>> &matrix) {
|
783
|
-
arg_->type = TI_ARGUMENT_TYPE_TENSOR;
|
784
|
-
uint32_t size = 0, bias = 0;
|
785
|
-
for (const auto &row : matrix) {
|
786
|
-
std::memcpy((arg_->value.tensor.contents.data.x8 + bias), row.data(),
|
787
|
-
row.size() * sizeof(T));
|
788
|
-
size += row.size();
|
789
|
-
bias += row.size() * sizeof(T);
|
790
|
-
}
|
791
|
-
arg_->value.tensor.contents.length = size;
|
792
|
-
arg_->value.tensor.type = DataTypeToEnum<T>::value;
|
793
|
-
return *this;
|
794
|
-
}
|
795
|
-
};
|
796
|
-
|
797
|
-
class ComputeGraph {
|
798
|
-
protected:
|
799
|
-
TiRuntime runtime_{TI_NULL_HANDLE};
|
800
|
-
TiComputeGraph compute_graph_{TI_NULL_HANDLE};
|
801
|
-
std::list<std::string> arg_names_{}; // For stable addresses.
|
802
|
-
std::vector<TiNamedArgument> args_{};
|
803
|
-
|
804
|
-
public:
|
805
|
-
constexpr bool is_valid() const {
|
806
|
-
return compute_graph_ != nullptr;
|
807
|
-
}
|
808
|
-
|
809
|
-
ComputeGraph() {
|
810
|
-
}
|
811
|
-
ComputeGraph(const ComputeGraph &) = delete;
|
812
|
-
ComputeGraph(ComputeGraph &&b)
|
813
|
-
: runtime_(detail::move_handle(b.runtime_)),
|
814
|
-
compute_graph_(detail::move_handle(b.compute_graph_)),
|
815
|
-
arg_names_(std::move(b.arg_names_)),
|
816
|
-
args_(std::move(b.args_)) {
|
817
|
-
}
|
818
|
-
ComputeGraph(TiRuntime runtime, TiComputeGraph compute_graph)
|
819
|
-
: runtime_(runtime), compute_graph_(compute_graph) {
|
820
|
-
}
|
821
|
-
~ComputeGraph() {
|
822
|
-
}
|
823
|
-
|
824
|
-
ComputeGraph &operator=(const ComputeGraph &) = delete;
|
825
|
-
ComputeGraph &operator=(ComputeGraph &&b) {
|
826
|
-
runtime_ = detail::move_handle(b.runtime_);
|
827
|
-
compute_graph_ = detail::move_handle(b.compute_graph_);
|
828
|
-
arg_names_ = std::move(b.arg_names_);
|
829
|
-
args_ = std::move(b.args_);
|
830
|
-
return *this;
|
831
|
-
}
|
832
|
-
|
833
|
-
inline ArgumentEntry at(const char *name) {
|
834
|
-
size_t i = 0;
|
835
|
-
auto it = arg_names_.begin();
|
836
|
-
for (; it != arg_names_.end(); ++it) {
|
837
|
-
if (*it == name) {
|
838
|
-
break;
|
839
|
-
}
|
840
|
-
++i;
|
841
|
-
}
|
842
|
-
|
843
|
-
TiArgument *out;
|
844
|
-
if (it != arg_names_.end()) {
|
845
|
-
out = &args_.at(i).argument;
|
846
|
-
} else {
|
847
|
-
arg_names_.emplace_back(name);
|
848
|
-
args_.emplace_back();
|
849
|
-
args_.back().name = arg_names_.back().c_str();
|
850
|
-
out = &args_.back().argument;
|
851
|
-
}
|
852
|
-
|
853
|
-
return ArgumentEntry(out);
|
854
|
-
};
|
855
|
-
inline ArgumentEntry at(const std::string &name) {
|
856
|
-
return at(name.c_str());
|
857
|
-
}
|
858
|
-
inline ArgumentEntry operator[](const char *name) {
|
859
|
-
return at(name);
|
860
|
-
}
|
861
|
-
inline ArgumentEntry operator[](const std::string &name) {
|
862
|
-
return at(name);
|
863
|
-
}
|
864
|
-
|
865
|
-
void launch(uint32_t argument_count, const TiNamedArgument *arguments) const {
|
866
|
-
ti_launch_compute_graph(runtime_, compute_graph_, argument_count,
|
867
|
-
arguments);
|
868
|
-
}
|
869
|
-
void launch() const {
|
870
|
-
launch(args_.size(), args_.data());
|
871
|
-
}
|
872
|
-
void launch(const std::vector<TiNamedArgument> &arguments) const {
|
873
|
-
launch(arguments.size(), arguments.data());
|
874
|
-
}
|
875
|
-
|
876
|
-
constexpr TiComputeGraph compute_graph() const {
|
877
|
-
return compute_graph_;
|
878
|
-
}
|
879
|
-
constexpr operator TiComputeGraph() const { // NOLINT
|
880
|
-
return compute_graph_;
|
881
|
-
}
|
882
|
-
};
|
883
|
-
|
884
|
-
class Kernel {
|
885
|
-
protected:
|
886
|
-
TiRuntime runtime_{TI_NULL_HANDLE};
|
887
|
-
TiKernel kernel_{TI_NULL_HANDLE};
|
888
|
-
std::vector<TiArgument> args_{};
|
889
|
-
|
890
|
-
public:
|
891
|
-
constexpr bool is_valid() const {
|
892
|
-
return kernel_ != nullptr;
|
893
|
-
}
|
894
|
-
|
895
|
-
Kernel() {
|
896
|
-
}
|
897
|
-
Kernel(const Kernel &) = delete;
|
898
|
-
Kernel(Kernel &&b)
|
899
|
-
: runtime_(detail::move_handle(b.runtime_)),
|
900
|
-
kernel_(detail::move_handle(b.kernel_)),
|
901
|
-
args_(std::move(b.args_)) {
|
902
|
-
}
|
903
|
-
Kernel(TiRuntime runtime, TiKernel kernel)
|
904
|
-
: runtime_(runtime), kernel_(kernel) {
|
905
|
-
}
|
906
|
-
|
907
|
-
Kernel &operator=(const Kernel &) = delete;
|
908
|
-
Kernel &operator=(Kernel &&b) {
|
909
|
-
runtime_ = detail::move_handle(b.runtime_);
|
910
|
-
kernel_ = detail::move_handle(b.kernel_);
|
911
|
-
args_ = std::move(b.args_);
|
912
|
-
return *this;
|
913
|
-
}
|
914
|
-
|
915
|
-
ArgumentEntry at(uint32_t i) {
|
916
|
-
if (i < args_.size()) {
|
917
|
-
return ArgumentEntry(&args_.at(i));
|
918
|
-
} else {
|
919
|
-
args_.resize(i + 1);
|
920
|
-
return ArgumentEntry(&args_.at(i));
|
921
|
-
}
|
922
|
-
}
|
923
|
-
ArgumentEntry operator[](uint32_t i) {
|
924
|
-
return at(i);
|
925
|
-
}
|
926
|
-
|
927
|
-
template <typename T>
|
928
|
-
void push_arg(const std::vector<T> &v) {
|
929
|
-
int idx = args_.size();
|
930
|
-
args_.resize(idx + 1);
|
931
|
-
args_[idx].type = TI_ARGUMENT_TYPE_TENSOR;
|
932
|
-
std::memcpy(args_[idx].value.tensor.contents.data.x32, v.data(),
|
933
|
-
v.size() * sizeof(T));
|
934
|
-
args_[idx].value.tensor.contents.length = v.size();
|
935
|
-
args_[idx].value.tensor.type = DataTypeToEnum<T>::value;
|
936
|
-
}
|
937
|
-
|
938
|
-
template <typename T>
|
939
|
-
void push_arg(const T &arg) {
|
940
|
-
int idx = args_.size();
|
941
|
-
args_.resize(idx + 1);
|
942
|
-
at(idx) = arg;
|
943
|
-
}
|
944
|
-
|
945
|
-
void clear_args() {
|
946
|
-
args_.clear();
|
947
|
-
}
|
948
|
-
|
949
|
-
void launch(uint32_t argument_count, const TiArgument *arguments) const {
|
950
|
-
ti_launch_kernel(runtime_, kernel_, argument_count, arguments);
|
951
|
-
}
|
952
|
-
void launch() const {
|
953
|
-
launch(args_.size(), args_.data());
|
954
|
-
}
|
955
|
-
void launch(const std::vector<TiArgument> &arguments) const {
|
956
|
-
launch(arguments.size(), arguments.data());
|
957
|
-
}
|
958
|
-
|
959
|
-
constexpr TiKernel kernel() const {
|
960
|
-
return kernel_;
|
961
|
-
}
|
962
|
-
constexpr operator TiKernel() const { // NOLINT
|
963
|
-
return kernel_;
|
964
|
-
}
|
965
|
-
};
|
966
|
-
|
967
|
-
class AotModule {
|
968
|
-
protected:
|
969
|
-
TiRuntime runtime_{TI_NULL_HANDLE};
|
970
|
-
TiAotModule aot_module_{TI_NULL_HANDLE};
|
971
|
-
bool should_destroy_{false};
|
972
|
-
|
973
|
-
public:
|
974
|
-
constexpr bool is_valid() const {
|
975
|
-
return aot_module_ != nullptr;
|
976
|
-
}
|
977
|
-
inline void destroy() {
|
978
|
-
if (should_destroy_) {
|
979
|
-
ti_destroy_aot_module(aot_module_);
|
980
|
-
aot_module_ = TI_NULL_HANDLE;
|
981
|
-
should_destroy_ = false;
|
982
|
-
}
|
983
|
-
}
|
984
|
-
|
985
|
-
AotModule() {
|
986
|
-
}
|
987
|
-
AotModule(const AotModule &) = delete;
|
988
|
-
AotModule(AotModule &&b)
|
989
|
-
: runtime_(detail::move_handle(b.runtime_)),
|
990
|
-
aot_module_(detail::move_handle(b.aot_module_)),
|
991
|
-
should_destroy_(detail::exchange(b.should_destroy_, false)) {
|
992
|
-
}
|
993
|
-
AotModule(TiRuntime runtime, TiAotModule aot_module, bool should_destroy)
|
994
|
-
: runtime_(runtime),
|
995
|
-
aot_module_(aot_module),
|
996
|
-
should_destroy_(should_destroy) {
|
997
|
-
}
|
998
|
-
~AotModule() {
|
999
|
-
destroy();
|
1000
|
-
}
|
1001
|
-
|
1002
|
-
AotModule &operator=(const AotModule &) = delete;
|
1003
|
-
AotModule &operator=(AotModule &&b) {
|
1004
|
-
runtime_ = detail::move_handle(b.runtime_);
|
1005
|
-
aot_module_ = detail::move_handle(b.aot_module_);
|
1006
|
-
should_destroy_ = detail::exchange(b.should_destroy_, false);
|
1007
|
-
return *this;
|
1008
|
-
}
|
1009
|
-
|
1010
|
-
inline AotModule borrow() const {
|
1011
|
-
return AotModule(runtime_, aot_module_, false);
|
1012
|
-
}
|
1013
|
-
|
1014
|
-
Kernel get_kernel(const char *name) const {
|
1015
|
-
TiKernel kernel_ = ti_get_aot_module_kernel(aot_module_, name);
|
1016
|
-
return Kernel(runtime_, kernel_);
|
1017
|
-
}
|
1018
|
-
ComputeGraph get_compute_graph(const char *name) const {
|
1019
|
-
TiComputeGraph compute_graph_ =
|
1020
|
-
ti_get_aot_module_compute_graph(aot_module_, name);
|
1021
|
-
return ComputeGraph(runtime_, compute_graph_);
|
1022
|
-
}
|
1023
|
-
|
1024
|
-
constexpr TiAotModule aot_module() const {
|
1025
|
-
return aot_module_;
|
1026
|
-
}
|
1027
|
-
constexpr operator TiAotModule() const { // NOLINT
|
1028
|
-
return aot_module_;
|
1029
|
-
}
|
1030
|
-
};
|
1031
|
-
|
1032
|
-
class CapabilityLevelConfigBuilder;
|
1033
|
-
class CapabilityLevelConfig {
|
1034
|
-
public:
|
1035
|
-
std::vector<TiCapabilityLevelInfo> cap_level_infos;
|
1036
|
-
|
1037
|
-
CapabilityLevelConfig() {
|
1038
|
-
}
|
1039
|
-
explicit CapabilityLevelConfig(
|
1040
|
-
std::vector<TiCapabilityLevelInfo> &&capabilities)
|
1041
|
-
: cap_level_infos(std::move(capabilities)) {
|
1042
|
-
}
|
1043
|
-
|
1044
|
-
static CapabilityLevelConfigBuilder builder();
|
1045
|
-
|
1046
|
-
uint32_t get(TiCapability capability) const {
|
1047
|
-
for (size_t i = 0; i < cap_level_infos.size(); ++i) {
|
1048
|
-
const TiCapabilityLevelInfo &cap_level_info = cap_level_infos.at(i);
|
1049
|
-
if (cap_level_info.capability == capability) {
|
1050
|
-
return cap_level_info.level;
|
1051
|
-
}
|
1052
|
-
}
|
1053
|
-
return 0;
|
1054
|
-
}
|
1055
|
-
|
1056
|
-
void set(TiCapability capability, uint32_t level) {
|
1057
|
-
std::vector<TiCapabilityLevelInfo>::iterator it = cap_level_infos.begin();
|
1058
|
-
for (; it != cap_level_infos.end(); ++it) {
|
1059
|
-
if (it->capability == capability) {
|
1060
|
-
it->level = level;
|
1061
|
-
return;
|
1062
|
-
}
|
1063
|
-
}
|
1064
|
-
TiCapabilityLevelInfo cap_level_info{};
|
1065
|
-
cap_level_info.capability = capability;
|
1066
|
-
cap_level_info.level = level;
|
1067
|
-
cap_level_infos.emplace_back(std::move(cap_level_info));
|
1068
|
-
}
|
1069
|
-
};
|
1070
|
-
|
1071
|
-
class CapabilityLevelConfigBuilder {
|
1072
|
-
typedef CapabilityLevelConfigBuilder Self;
|
1073
|
-
std::map<TiCapability, uint32_t> cap_level_infos_;
|
1074
|
-
|
1075
|
-
public:
|
1076
|
-
CapabilityLevelConfigBuilder() {
|
1077
|
-
}
|
1078
|
-
CapabilityLevelConfigBuilder(const Self &) = delete;
|
1079
|
-
Self &operator=(const Self &) = delete;
|
1080
|
-
|
1081
|
-
Self &spirv_version(uint32_t major, uint32_t minor) {
|
1082
|
-
if (major == 1) {
|
1083
|
-
if (minor == 3) {
|
1084
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_VERSION] = 0x10300;
|
1085
|
-
} else if (minor == 4) {
|
1086
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_VERSION] = 0x10400;
|
1087
|
-
} else if (minor == 5) {
|
1088
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_VERSION] = 0x10500;
|
1089
|
-
} else {
|
1090
|
-
ti_set_last_error(TI_ERROR_ARGUMENT_OUT_OF_RANGE, "minor");
|
1091
|
-
}
|
1092
|
-
} else {
|
1093
|
-
ti_set_last_error(TI_ERROR_ARGUMENT_OUT_OF_RANGE, "major");
|
1094
|
-
}
|
1095
|
-
return *this;
|
1096
|
-
}
|
1097
|
-
Self &spirv_has_int8(bool value = true) {
|
1098
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_INT8] = value ? TI_TRUE : TI_FALSE;
|
1099
|
-
return *this;
|
1100
|
-
}
|
1101
|
-
Self &spirv_has_int16(bool value = true) {
|
1102
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_INT16] =
|
1103
|
-
value ? TI_TRUE : TI_FALSE;
|
1104
|
-
return *this;
|
1105
|
-
}
|
1106
|
-
Self &spirv_has_int64(bool value = true) {
|
1107
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_INT64] =
|
1108
|
-
value ? TI_TRUE : TI_FALSE;
|
1109
|
-
return *this;
|
1110
|
-
}
|
1111
|
-
Self &spirv_has_float16(bool value = true) {
|
1112
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_FLOAT16] =
|
1113
|
-
value ? TI_TRUE : TI_FALSE;
|
1114
|
-
return *this;
|
1115
|
-
}
|
1116
|
-
Self &spirv_has_float64(bool value = true) {
|
1117
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_FLOAT64] =
|
1118
|
-
value ? TI_TRUE : TI_FALSE;
|
1119
|
-
return *this;
|
1120
|
-
}
|
1121
|
-
Self &spirv_has_atomic_int64(bool value = true) {
|
1122
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_ATOMIC_INT64] =
|
1123
|
-
value ? TI_TRUE : TI_FALSE;
|
1124
|
-
return *this;
|
1125
|
-
}
|
1126
|
-
Self &spirv_has_atomic_float16(bool value = true) {
|
1127
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16] =
|
1128
|
-
value ? TI_TRUE : TI_FALSE;
|
1129
|
-
return *this;
|
1130
|
-
}
|
1131
|
-
Self &spirv_has_atomic_float16_add(bool value = true) {
|
1132
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16_ADD] =
|
1133
|
-
value ? TI_TRUE : TI_FALSE;
|
1134
|
-
return *this;
|
1135
|
-
}
|
1136
|
-
Self &spirv_has_atomic_float16_minmax(bool value = true) {
|
1137
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16_MINMAX] =
|
1138
|
-
value ? TI_TRUE : TI_FALSE;
|
1139
|
-
return *this;
|
1140
|
-
}
|
1141
|
-
Self &spirv_has_atomic_float64(bool value = true) {
|
1142
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64] =
|
1143
|
-
value ? TI_TRUE : TI_FALSE;
|
1144
|
-
return *this;
|
1145
|
-
}
|
1146
|
-
Self &spirv_has_atomic_float64_add(bool value = true) {
|
1147
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD] =
|
1148
|
-
value ? TI_TRUE : TI_FALSE;
|
1149
|
-
return *this;
|
1150
|
-
}
|
1151
|
-
Self &spirv_has_variable_ptr(bool value = true) {
|
1152
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_VARIABLE_PTR] =
|
1153
|
-
value ? TI_TRUE : TI_FALSE;
|
1154
|
-
return *this;
|
1155
|
-
}
|
1156
|
-
Self &spirv_has_physical_storage_buffer(bool value = true) {
|
1157
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_PHYSICAL_STORAGE_BUFFER] =
|
1158
|
-
value ? TI_TRUE : TI_FALSE;
|
1159
|
-
return *this;
|
1160
|
-
}
|
1161
|
-
Self &spirv_has_subgroup_basic(bool value = true) {
|
1162
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_BASIC] =
|
1163
|
-
value ? TI_TRUE : TI_FALSE;
|
1164
|
-
return *this;
|
1165
|
-
}
|
1166
|
-
Self &spirv_has_subgroup_vote(bool value = true) {
|
1167
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_VOTE] =
|
1168
|
-
value ? TI_TRUE : TI_FALSE;
|
1169
|
-
return *this;
|
1170
|
-
}
|
1171
|
-
Self &spirv_has_subgroup_arithmetic(bool value = true) {
|
1172
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_ARITHMETIC] =
|
1173
|
-
value ? TI_TRUE : TI_FALSE;
|
1174
|
-
return *this;
|
1175
|
-
}
|
1176
|
-
Self &spirv_has_subgroup_ballot(bool value = true) {
|
1177
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_BALLOT] =
|
1178
|
-
value ? TI_TRUE : TI_FALSE;
|
1179
|
-
return *this;
|
1180
|
-
}
|
1181
|
-
Self &spirv_has_non_semantic_info(bool value = true) {
|
1182
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_NON_SEMANTIC_INFO] =
|
1183
|
-
value ? TI_TRUE : TI_FALSE;
|
1184
|
-
return *this;
|
1185
|
-
}
|
1186
|
-
Self &spirv_has_no_integer_wrap_decoration(bool value = true) {
|
1187
|
-
cap_level_infos_[TI_CAPABILITY_SPIRV_HAS_NO_INTEGER_WRAP_DECORATION] =
|
1188
|
-
value ? TI_TRUE : TI_FALSE;
|
1189
|
-
return *this;
|
1190
|
-
}
|
1191
|
-
|
1192
|
-
CapabilityLevelConfig build() {
|
1193
|
-
std::vector<TiCapabilityLevelInfo> out{};
|
1194
|
-
for (const auto &pair : cap_level_infos_) {
|
1195
|
-
TiCapabilityLevelInfo cap_level_info{};
|
1196
|
-
cap_level_info.capability = pair.first;
|
1197
|
-
cap_level_info.level = pair.second;
|
1198
|
-
out.emplace_back(std::move(cap_level_info));
|
1199
|
-
}
|
1200
|
-
return CapabilityLevelConfig{std::move(out)};
|
1201
|
-
}
|
1202
|
-
};
|
1203
|
-
|
1204
|
-
// Returns a fresh, empty CapabilityLevelConfigBuilder.
// The braced `return {};` initializes the returned builder in place. The
// builder declares its copy constructor as deleted, so spellings that would
// copy from a temporary are best avoided here.
inline CapabilityLevelConfigBuilder CapabilityLevelConfig::builder() {
|
1205
|
-
return {};
|
1206
|
-
}
|
1207
|
-
|
1208
|
-
class Runtime {
|
1209
|
-
protected:
|
1210
|
-
TiArch arch_{TI_ARCH_MAX_ENUM};
|
1211
|
-
TiRuntime runtime_{TI_NULL_HANDLE};
|
1212
|
-
bool should_destroy_{false};
|
1213
|
-
|
1214
|
-
public:
|
1215
|
-
constexpr bool is_valid() const {
|
1216
|
-
return runtime_ != nullptr;
|
1217
|
-
}
|
1218
|
-
inline void destroy() {
|
1219
|
-
if (should_destroy_) {
|
1220
|
-
ti_destroy_runtime(runtime_);
|
1221
|
-
runtime_ = TI_NULL_HANDLE;
|
1222
|
-
should_destroy_ = false;
|
1223
|
-
}
|
1224
|
-
}
|
1225
|
-
|
1226
|
-
Runtime() {
|
1227
|
-
}
|
1228
|
-
Runtime(const Runtime &) = delete;
|
1229
|
-
Runtime(Runtime &&b)
|
1230
|
-
: arch_(detail::exchange(b.arch_, TI_ARCH_MAX_ENUM)),
|
1231
|
-
runtime_(detail::move_handle(b.runtime_)),
|
1232
|
-
should_destroy_(detail::exchange(b.should_destroy_, false)) {
|
1233
|
-
}
|
1234
|
-
explicit Runtime(TiArch arch, uint32_t device_index = 0)
|
1235
|
-
: arch_(arch),
|
1236
|
-
runtime_(ti_create_runtime(arch, device_index)),
|
1237
|
-
should_destroy_(true) {
|
1238
|
-
}
|
1239
|
-
Runtime(TiArch arch, TiRuntime runtime, bool should_destroy)
|
1240
|
-
: arch_(arch), runtime_(runtime), should_destroy_(should_destroy) {
|
1241
|
-
}
|
1242
|
-
~Runtime() {
|
1243
|
-
destroy();
|
1244
|
-
}
|
1245
|
-
|
1246
|
-
Runtime &operator=(const Runtime &) = delete;
|
1247
|
-
Runtime &operator=(Runtime &&b) {
|
1248
|
-
arch_ = detail::exchange(b.arch_, TI_ARCH_MAX_ENUM);
|
1249
|
-
runtime_ = detail::move_handle(b.runtime_);
|
1250
|
-
should_destroy_ = detail::exchange(b.should_destroy_, false);
|
1251
|
-
return *this;
|
1252
|
-
}
|
1253
|
-
|
1254
|
-
inline Runtime borrow() const {
|
1255
|
-
return Runtime(arch_, runtime_, false);
|
1256
|
-
}
|
1257
|
-
|
1258
|
-
void set_capabilities_ext(
|
1259
|
-
const std::vector<TiCapabilityLevelInfo> &capabilities) const {
|
1260
|
-
ti_set_runtime_capabilities_ext(runtime_, (uint32_t)capabilities.size(),
|
1261
|
-
capabilities.data());
|
1262
|
-
}
|
1263
|
-
void set_capabilities_ext(const CapabilityLevelConfig &capabilities) const {
|
1264
|
-
set_capabilities_ext(capabilities.cap_level_infos);
|
1265
|
-
}
|
1266
|
-
CapabilityLevelConfig get_capabilities() const {
|
1267
|
-
uint32_t n = 0;
|
1268
|
-
ti_get_runtime_capabilities(runtime_, &n, nullptr);
|
1269
|
-
std::vector<TiCapabilityLevelInfo> devcaps(n);
|
1270
|
-
ti_get_runtime_capabilities(runtime_, &n, devcaps.data());
|
1271
|
-
return CapabilityLevelConfig{std::move(devcaps)};
|
1272
|
-
}
|
1273
|
-
|
1274
|
-
Memory allocate_memory(const TiMemoryAllocateInfo &allocate_info) const {
|
1275
|
-
TiMemory memory = ti_allocate_memory(runtime_, &allocate_info);
|
1276
|
-
return Memory(runtime_, memory, allocate_info.size, true);
|
1277
|
-
}
|
1278
|
-
Memory allocate_memory(size_t size, bool host_access = false) const {
|
1279
|
-
TiMemoryAllocateInfo allocate_info{};
|
1280
|
-
allocate_info.size = size;
|
1281
|
-
allocate_info.host_read = host_access;
|
1282
|
-
allocate_info.host_write = host_access;
|
1283
|
-
allocate_info.usage = TI_MEMORY_USAGE_STORAGE_BIT;
|
1284
|
-
return allocate_memory(allocate_info);
|
1285
|
-
}
|
1286
|
-
template <typename T>
|
1287
|
-
NdArray<T> allocate_ndarray(const std::vector<uint32_t> &shape = {},
|
1288
|
-
const std::vector<uint32_t> &elem_shape = {},
|
1289
|
-
bool host_access = false) const {
|
1290
|
-
auto dtype = detail::templ2dtype<T>::value;
|
1291
|
-
return allocate_ndarray<T>(dtype, shape, elem_shape, host_access);
|
1292
|
-
}
|
1293
|
-
|
1294
|
-
template <typename T>
|
1295
|
-
NdArray<T> allocate_ndarray(TiDataType dtype,
|
1296
|
-
const std::vector<uint32_t> &shape = {},
|
1297
|
-
const std::vector<uint32_t> &elem_shape = {},
|
1298
|
-
bool host_access = false) const {
|
1299
|
-
size_t size = sizeof(T);
|
1300
|
-
TiNdArray ndarray{};
|
1301
|
-
for (size_t i = 0; i < shape.size(); ++i) {
|
1302
|
-
uint32_t x = shape.at(i);
|
1303
|
-
size *= x;
|
1304
|
-
ndarray.shape.dims[i] = x;
|
1305
|
-
}
|
1306
|
-
ndarray.shape.dim_count = shape.size();
|
1307
|
-
for (size_t i = 0; i < elem_shape.size(); ++i) {
|
1308
|
-
uint32_t x = elem_shape.at(i);
|
1309
|
-
size *= x;
|
1310
|
-
ndarray.elem_shape.dims[i] = x;
|
1311
|
-
}
|
1312
|
-
ndarray.elem_shape.dim_count = elem_shape.size();
|
1313
|
-
ndarray.elem_type = dtype;
|
1314
|
-
|
1315
|
-
ti::Memory memory = allocate_memory(size, host_access);
|
1316
|
-
ndarray.memory = memory.memory();
|
1317
|
-
return NdArray<T>(std::move(memory), ndarray);
|
1318
|
-
}
|
1319
|
-
|
1320
|
-
Image allocate_image(const TiImageAllocateInfo &allocate_info) const {
|
1321
|
-
TiImage image = ti_allocate_image(runtime_, &allocate_info);
|
1322
|
-
return Image(runtime_, image, allocate_info.dimension, allocate_info.extent,
|
1323
|
-
allocate_info.mip_level_count, allocate_info.format, true);
|
1324
|
-
}
|
1325
|
-
Texture allocate_texture2d(uint32_t width,
|
1326
|
-
uint32_t height,
|
1327
|
-
TiFormat format,
|
1328
|
-
TiSampler sampler) const {
|
1329
|
-
TiImageExtent extent{};
|
1330
|
-
extent.width = width;
|
1331
|
-
extent.height = height;
|
1332
|
-
extent.depth = 1;
|
1333
|
-
extent.array_layer_count = 1;
|
1334
|
-
|
1335
|
-
TiImageAllocateInfo allocate_info{};
|
1336
|
-
allocate_info.dimension = TI_IMAGE_DIMENSION_2D;
|
1337
|
-
allocate_info.extent = extent;
|
1338
|
-
allocate_info.mip_level_count = 1;
|
1339
|
-
allocate_info.format = format;
|
1340
|
-
allocate_info.usage =
|
1341
|
-
TI_IMAGE_USAGE_STORAGE_BIT | TI_IMAGE_USAGE_SAMPLED_BIT;
|
1342
|
-
|
1343
|
-
Image image = allocate_image(allocate_info);
|
1344
|
-
TiTexture texture{};
|
1345
|
-
texture.image = image.image();
|
1346
|
-
texture.dimension = TI_IMAGE_DIMENSION_2D;
|
1347
|
-
texture.extent = extent;
|
1348
|
-
texture.format = format;
|
1349
|
-
texture.sampler = sampler;
|
1350
|
-
return Texture(std::move(image), texture);
|
1351
|
-
}
|
1352
|
-
|
1353
|
-
AotModule load_aot_module(const char *path) const {
|
1354
|
-
TiAotModule aot_module_ = ti_load_aot_module(runtime_, path);
|
1355
|
-
return AotModule(runtime_, aot_module_, true);
|
1356
|
-
}
|
1357
|
-
AotModule load_aot_module(const std::string &path) const {
|
1358
|
-
return load_aot_module(path.c_str());
|
1359
|
-
}
|
1360
|
-
|
1361
|
-
AotModule create_aot_module(const void *tcm, size_t size) const {
|
1362
|
-
TiAotModule aot_module = ti_create_aot_module(runtime_, tcm, size);
|
1363
|
-
return AotModule(runtime_, aot_module, true);
|
1364
|
-
}
|
1365
|
-
AotModule create_aot_module(const std::vector<uint8_t> &tcm) const {
|
1366
|
-
return create_aot_module(tcm.data(), tcm.size());
|
1367
|
-
}
|
1368
|
-
|
1369
|
-
void copy_memory_device_to_device(const MemorySlice &dst_memory,
|
1370
|
-
const MemorySlice &src_memory) const {
|
1371
|
-
ti_copy_memory_device_to_device(runtime_, &dst_memory.slice(),
|
1372
|
-
&src_memory.slice());
|
1373
|
-
}
|
1374
|
-
void copy_image_device_to_device(const ImageSlice &dst_image,
|
1375
|
-
const ImageSlice &src_image) const {
|
1376
|
-
ti_copy_image_device_to_device(runtime_, &dst_image.slice(),
|
1377
|
-
&src_image.slice());
|
1378
|
-
}
|
1379
|
-
void transition_image(TiImage image, TiImageLayout layout) const {
|
1380
|
-
ti_transition_image(runtime_, image, layout);
|
1381
|
-
}
|
1382
|
-
|
1383
|
-
void flush() const {
|
1384
|
-
ti_flush(runtime_);
|
1385
|
-
}
|
1386
|
-
void wait() const {
|
1387
|
-
ti_wait(runtime_);
|
1388
|
-
}
|
1389
|
-
|
1390
|
-
constexpr TiArch arch() const {
|
1391
|
-
return arch_;
|
1392
|
-
}
|
1393
|
-
constexpr TiRuntime runtime() const {
|
1394
|
-
return runtime_;
|
1395
|
-
}
|
1396
|
-
constexpr operator TiRuntime() const { // NOLINT
|
1397
|
-
return runtime_;
|
1398
|
-
}
|
1399
|
-
};
|
1400
|
-
|
1401
|
-
} // namespace ti
|