gstaichi-0.1.18.dev1-cp310-cp310-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-diff.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-link.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-lint.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-opt.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-shared.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.1.18.dev1.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
- gstaichi-0.1.18.dev1.dist-info/RECORD +198 -0
- gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
- gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
- taichi/CHANGELOG.md +15 -0
- taichi/__init__.py +44 -0
- taichi/__main__.py +5 -0
- taichi/_funcs.py +706 -0
- taichi/_kernels.py +420 -0
- taichi/_lib/__init__.py +3 -0
- taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
- taichi/_lib/c_api/include/taichi/taichi.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_cuda.h +36 -0
- taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
- taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
- taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
- taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
- taichi/_lib/c_api/runtime/slim_libdevice.10.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
- taichi/_lib/core/__init__.py +0 -0
- taichi/_lib/core/py.typed +0 -0
- taichi/_lib/core/taichi_python.cp310-win_amd64.pyd +0 -0
- taichi/_lib/core/taichi_python.pyi +3077 -0
- taichi/_lib/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/runtime/runtime_x64.bc +0 -0
- taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- taichi/_lib/utils.py +249 -0
- taichi/_logging.py +131 -0
- taichi/_main.py +552 -0
- taichi/_snode/__init__.py +5 -0
- taichi/_snode/fields_builder.py +189 -0
- taichi/_snode/snode_tree.py +34 -0
- taichi/_ti_module/__init__.py +3 -0
- taichi/_ti_module/cppgen.py +309 -0
- taichi/_ti_module/module.py +145 -0
- taichi/_version.py +1 -0
- taichi/_version_check.py +100 -0
- taichi/ad/__init__.py +3 -0
- taichi/ad/_ad.py +530 -0
- taichi/algorithms/__init__.py +3 -0
- taichi/algorithms/_algorithms.py +117 -0
- taichi/aot/__init__.py +12 -0
- taichi/aot/_export.py +28 -0
- taichi/aot/conventions/__init__.py +3 -0
- taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
- taichi/aot/conventions/gfxruntime140/dr.py +244 -0
- taichi/aot/conventions/gfxruntime140/sr.py +613 -0
- taichi/aot/module.py +253 -0
- taichi/aot/utils.py +151 -0
- taichi/assets/.git +1 -0
- taichi/assets/Go-Regular.ttf +0 -0
- taichi/assets/static/imgs/ti_gallery.png +0 -0
- taichi/examples/minimal.py +28 -0
- taichi/experimental.py +16 -0
- taichi/graph/__init__.py +3 -0
- taichi/graph/_graph.py +292 -0
- taichi/lang/__init__.py +50 -0
- taichi/lang/_ndarray.py +348 -0
- taichi/lang/_ndrange.py +152 -0
- taichi/lang/_texture.py +172 -0
- taichi/lang/_wrap_inspect.py +189 -0
- taichi/lang/any_array.py +99 -0
- taichi/lang/argpack.py +411 -0
- taichi/lang/ast/__init__.py +5 -0
- taichi/lang/ast/ast_transformer.py +1806 -0
- taichi/lang/ast/ast_transformer_utils.py +328 -0
- taichi/lang/ast/checkers.py +106 -0
- taichi/lang/ast/symbol_resolver.py +57 -0
- taichi/lang/ast/transform.py +9 -0
- taichi/lang/common_ops.py +310 -0
- taichi/lang/exception.py +80 -0
- taichi/lang/expr.py +180 -0
- taichi/lang/field.py +464 -0
- taichi/lang/impl.py +1246 -0
- taichi/lang/kernel_arguments.py +157 -0
- taichi/lang/kernel_impl.py +1415 -0
- taichi/lang/matrix.py +1877 -0
- taichi/lang/matrix_ops.py +341 -0
- taichi/lang/matrix_ops_utils.py +190 -0
- taichi/lang/mesh.py +687 -0
- taichi/lang/misc.py +807 -0
- taichi/lang/ops.py +1489 -0
- taichi/lang/runtime_ops.py +13 -0
- taichi/lang/shell.py +35 -0
- taichi/lang/simt/__init__.py +5 -0
- taichi/lang/simt/block.py +94 -0
- taichi/lang/simt/grid.py +7 -0
- taichi/lang/simt/subgroup.py +191 -0
- taichi/lang/simt/warp.py +96 -0
- taichi/lang/snode.py +487 -0
- taichi/lang/source_builder.py +150 -0
- taichi/lang/struct.py +855 -0
- taichi/lang/util.py +381 -0
- taichi/linalg/__init__.py +8 -0
- taichi/linalg/matrixfree_cg.py +310 -0
- taichi/linalg/sparse_cg.py +59 -0
- taichi/linalg/sparse_matrix.py +303 -0
- taichi/linalg/sparse_solver.py +123 -0
- taichi/math/__init__.py +11 -0
- taichi/math/_complex.py +204 -0
- taichi/math/mathimpl.py +886 -0
- taichi/profiler/__init__.py +6 -0
- taichi/profiler/kernel_metrics.py +260 -0
- taichi/profiler/kernel_profiler.py +592 -0
- taichi/profiler/memory_profiler.py +15 -0
- taichi/profiler/scoped_profiler.py +36 -0
- taichi/shaders/Circles_vk.frag +29 -0
- taichi/shaders/Circles_vk.vert +45 -0
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +9 -0
- taichi/shaders/Lines_vk.vert +11 -0
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +71 -0
- taichi/shaders/Mesh_vk.vert +68 -0
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +95 -0
- taichi/shaders/Particles_vk.vert +73 -0
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +9 -0
- taichi/shaders/SceneLines_vk.vert +12 -0
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +21 -0
- taichi/shaders/SetImage_vk.vert +15 -0
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +16 -0
- taichi/shaders/Triangles_vk.vert +29 -0
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/sparse/__init__.py +3 -0
- taichi/sparse/_sparse_grid.py +77 -0
- taichi/tools/__init__.py +12 -0
- taichi/tools/diagnose.py +124 -0
- taichi/tools/np2ply.py +364 -0
- taichi/tools/vtk.py +38 -0
- taichi/types/__init__.py +19 -0
- taichi/types/annotations.py +47 -0
- taichi/types/compound_types.py +90 -0
- taichi/types/enums.py +49 -0
- taichi/types/ndarray_type.py +147 -0
- taichi/types/primitive_types.py +203 -0
- taichi/types/quant.py +88 -0
- taichi/types/texture_type.py +85 -0
- taichi/types/utils.py +13 -0
taichi/graph/_graph.py
ADDED
@@ -0,0 +1,292 @@
+# type: ignore
+
+import warnings
+from typing import Any, Dict, List
+
+from taichi._lib import core as _ti_core
+from taichi.aot.utils import produce_injected_args
+from taichi.lang import impl, kernel_impl
+from taichi.lang._ndarray import Ndarray
+from taichi.lang._texture import Texture
+from taichi.lang.exception import TaichiRuntimeError
+from taichi.lang.matrix import Matrix, MatrixType
+from taichi.types import enums
+from taichi.types.texture_type import FORMAT2TY_CH, TY_CH2FORMAT
+
+ArgKind = _ti_core.ArgKind
+
+
+def gen_cpp_kernel(kernel_fn, args):
+    kernel = kernel_fn._primal
+    assert isinstance(kernel, kernel_impl.Kernel)
+    injected_args = produce_injected_args(kernel, symbolic_args=args)
+    key = kernel.ensure_compiled(*injected_args)
+    return kernel.compiled_kernels[key]
+
+
+def flatten_args(args):
+    unzipped_args = []
+    # Tuple for matrix args
+    # FIXME remove this when native Matrix type is ready
+    for arg in args:
+        if isinstance(arg, list):
+            for sublist in arg:
+                unzipped_args.extend(sublist)
+        else:
+            unzipped_args.append(arg)
+    return unzipped_args
+
+
+class Sequential:
+    def __init__(self, seq):
+        self.seq_ = seq
+
+    def dispatch(self, kernel_fn, *args):
+        kernel_cpp = gen_cpp_kernel(kernel_fn, args)
+        unzipped_args = flatten_args(args)
+        self.seq_.dispatch(kernel_cpp, unzipped_args)
+
+
+class GraphBuilder:
+    def __init__(self):
+        self._graph_builder = _ti_core.GraphBuilder()
+
+    def dispatch(self, kernel_fn, *args):
+        kernel_cpp = gen_cpp_kernel(kernel_fn, args)
+        unzipped_args = flatten_args(args)
+        self._graph_builder.dispatch(kernel_cpp, unzipped_args)
+
+    def create_sequential(self):
+        return Sequential(self._graph_builder.create_sequential())
+
+    def append(self, node):
+        # TODO: support appending dispatch node as well.
+        assert isinstance(node, Sequential)
+        self._graph_builder.seq().append(node.seq_)
+
+    def compile(self):
+        return Graph(self._graph_builder.compile())
+
+
+class Graph:
+    def __init__(self, compiled_graph) -> None:
+        self._compiled_graph = compiled_graph
+
+    def run(self, args):
+        # Support native python numerical types (int, float), Ndarray.
+        # Taichi Matrix types are flattened into (int, float) arrays.
+        # TODO diminish the flatten behavior when Matrix becomes a Taichi native type.
+        flattened = {}
+        for k, v in args.items():
+            if isinstance(v, Ndarray):
+                flattened[k] = v.arr
+            elif isinstance(v, Texture):
+                flattened[k] = v.tex
+            elif isinstance(v, Matrix):
+                flattened[k] = v.entries
+            elif isinstance(v, (int, float)):
+                flattened[k] = v
+            else:
+                raise TaichiRuntimeError(
+                    f"Only python int, float, ti.Matrix and ti.Ndarray are supported as runtime arguments but got {type(v)}"
+                )
+        self._compiled_graph.jit_run(impl.get_runtime().prog.config(), flattened)
+
+
+def _deprecate_arg_args(kwargs: Dict[str, Any]):
+    if "field_dim" in kwargs:
+        warnings.warn(
+            "The field_dim argument for ndarray will be deprecated in v1.6.0, use ndim instead.",
+            DeprecationWarning,
+        )
+        if "ndim" in kwargs:
+            raise TaichiRuntimeError(
+                "field_dim is deprecated, please do not specify field_dim and ndim at the same time."
+            )
+        kwargs["ndim"] = kwargs["field_dim"]
+        del kwargs["field_dim"]
+    tag = kwargs["tag"]
+
+    if tag == ArgKind.SCALAR:
+        if "element_shape" in kwargs:
+            raise TaichiRuntimeError(
+                "The element_shape argument for scalar is deprecated in v1.6.0, and is removed in v1.7.0. "
+                "Please remove them."
+            )
+
+    if tag == ArgKind.NDARRAY:
+        if "element_shape" not in kwargs:
+            if "dtype" in kwargs:
+                dtype = kwargs["dtype"]
+                if isinstance(dtype, MatrixType):
+                    kwargs["dtype"] = dtype.dtype
+                    kwargs["element_shape"] = dtype.get_shape()
+                else:
+                    kwargs["element_shape"] = ()
+        else:
+            raise TaichiRuntimeError(
+                "The element_shape argument for ndarray is deprecated in v1.6.0, and it is removed in v1.7.0. "
+                "Please use vector or matrix data type instead."
+            )
+
+    if tag == ArgKind.RWTEXTURE or tag == ArgKind.TEXTURE:
+        if "dtype" in kwargs:
+            warnings.warn(
+                "The dtype argument for texture will be deprecated in v1.6.0, use format instead.",
+                DeprecationWarning,
+            )
+            del kwargs["dtype"]
+
+        if "shape" in kwargs:
+            raise TaichiRuntimeError(
+                "The shape argument for texture is deprecated in v1.6.0, and it is removed in v1.7.0. "
+                "Please use ndim instead. (Note that you no longer need the exact texture size.)"
+            )
+
+        if "channel_format" in kwargs or "num_channels" in kwargs:
+            if "fmt" in kwargs:
+                raise TaichiRuntimeError(
+                    "channel_format and num_channels are deprecated, please do not specify channel_format/num_channels and fmt at the same time."
+                )
+            if tag == ArgKind.RWTEXTURE:
+                fmt = TY_CH2FORMAT[(kwargs["channel_format"], kwargs["num_channels"])]
+                kwargs["fmt"] = fmt
+                raise TaichiRuntimeError(
+                    "The channel_format and num_channels arguments for texture are deprecated in v1.6.0, "
+                    "and they are removed in v1.7.0. Please use fmt instead."
+                )
+            else:
+                raise TaichiRuntimeError(
+                    "The channel_format and num_channels arguments are no longer required for non-RW textures "
+                    "since v1.6.0, and they are removed in v1.7.0. Please remove them."
+                )
+
+
+def _check_args(kwargs: Dict[str, Any], allowed_kwargs: List[str]):
+    for k, v in kwargs.items():
+        if k not in allowed_kwargs:
+            raise TaichiRuntimeError(
+                f"Invalid argument: {k}, you can only create a graph argument with: {allowed_kwargs}"
+            )
+        if k == "tag":
+            if not isinstance(v, ArgKind):
+                raise TaichiRuntimeError(f"tag must be a ArgKind variant, but found {type(v)}.")
+        if k == "name":
+            if not isinstance(v, str):
+                raise TaichiRuntimeError(f"name must be a string, but found {type(v)}.")
+
+
+def _make_arg_scalar(kwargs: Dict[str, Any]):
+    allowed_kwargs = [
+        "tag",
+        "name",
+        "dtype",
+    ]
+    _check_args(kwargs, allowed_kwargs)
+    name = kwargs["name"]
+    dtype = kwargs["dtype"]
+    if isinstance(dtype, MatrixType):
+        raise TaichiRuntimeError(f"Tag ArgKind.SCALAR must specify a scalar type, but found {type(dtype)}.")
+    return _ti_core.Arg(ArgKind.SCALAR, name, dtype, 0, [])
+
+
+def _make_arg_ndarray(kwargs: Dict[str, Any]):
+    allowed_kwargs = [
+        "tag",
+        "name",
+        "dtype",
+        "ndim",
+        "element_shape",
+    ]
+    _check_args(kwargs, allowed_kwargs)
+    name = kwargs["name"]
+    ndim = kwargs["ndim"]
+    dtype = kwargs["dtype"]
+    element_shape = kwargs["element_shape"]
+    if isinstance(dtype, MatrixType):
+        raise TaichiRuntimeError(f"Tag ArgKind.NDARRAY must specify a scalar type, but found {dtype}.")
+    return _ti_core.Arg(ArgKind.NDARRAY, name, dtype, ndim, element_shape)
+
+
+def _make_arg_matrix(kwargs: Dict[str, Any]):
+    allowed_kwargs = [
+        "tag",
+        "name",
+        "dtype",
+    ]
+    _check_args(kwargs, allowed_kwargs)
+    name = kwargs["name"]
+    dtype = kwargs["dtype"]
+    if not isinstance(dtype, MatrixType):
+        raise TaichiRuntimeError(f"Tag ArgKind.MATRIX must specify matrix type, but got {dtype}.")
+    return _ti_core.Arg(ArgKind.MATRIX, f"{name}", dtype.dtype, 0, [dtype.n, dtype.m])
+
+
+def _make_arg_texture(kwargs: Dict[str, Any]):
+    allowed_kwargs = [
+        "tag",
+        "name",
+        "ndim",
+    ]
+    _check_args(kwargs, allowed_kwargs)
+    name = kwargs["name"]
+    ndim = kwargs["ndim"]
+    return _ti_core.Arg(ArgKind.TEXTURE, name, impl.f32, 4, [2] * ndim)
+
+
+def _make_arg_rwtexture(kwargs: Dict[str, Any]):
+    allowed_kwargs = [
+        "tag",
+        "name",
+        "ndim",
+        "fmt",
+    ]
+    _check_args(kwargs, allowed_kwargs)
+    name = kwargs["name"]
+    ndim = kwargs["ndim"]
+    fmt = kwargs["fmt"]
+    if fmt == enums.Format.unknown:
+        raise TaichiRuntimeError(f"Tag ArgKind.RWTEXTURE must specify a valid color format, but found {fmt}.")
+    channel_format, num_channels = FORMAT2TY_CH[fmt]
+    return _ti_core.Arg(ArgKind.RWTEXTURE, name, channel_format, num_channels, [2] * ndim)
+
+
+def _make_arg(kwargs: Dict[str, Any]):
+    assert "tag" in kwargs
+    _deprecate_arg_args(kwargs)
+    proc = {
+        ArgKind.SCALAR: _make_arg_scalar,
+        ArgKind.NDARRAY: _make_arg_ndarray,
+        ArgKind.MATRIX: _make_arg_matrix,
+        ArgKind.TEXTURE: _make_arg_texture,
+        ArgKind.RWTEXTURE: _make_arg_rwtexture,
+    }
+    tag = kwargs["tag"]
+    return proc[tag](kwargs)
+
+
+def _kwarg_rewriter(args, kwargs):
+    for i, arg in enumerate(args):
+        rewrite_map = {
+            0: "tag",
+            1: "name",
+            2: "dtype",
+            3: "ndim",
+            4: "field_dim",
+            5: "element_shape",
+            6: "channel_format",
+            7: "shape",
+            8: "num_channels",
+        }
+        if i in rewrite_map:
+            kwargs[rewrite_map[i]] = arg
+        else:
+            raise TaichiRuntimeError(f"Unexpected {i}th positional argument")
+
+
+def Arg(*args, **kwargs):
+    _kwarg_rewriter(args, kwargs)
+    return _make_arg(kwargs)
+
+
+__all__ = ["GraphBuilder", "Graph", "Arg", "ArgKind"]
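The file above is the user-facing compute-graph API: Arg builds symbolic argument descriptors (validated per ArgKind by the _make_arg_* helpers), GraphBuilder.dispatch records kernel launches against those descriptors, compile() wraps the result in a Graph, and Graph.run feeds a dict of concrete values keyed by argument name. The sketch below is a minimal usage example, assuming the wheel re-exports this module as ti.graph the way upstream Taichi does and that a graph-capable backend such as Vulkan is available; the kernel scale and the names sym_x, sym_k, x are illustrative, not taken from the package.

import taichi as ti

ti.init(arch=ti.vulkan)  # assumes a graph-capable backend is present

@ti.kernel
def scale(x: ti.types.ndarray(ndim=1), k: ti.f32):
    for i in x:
        x[i] = x[i] * k

# Symbolic arguments describe the runtime inputs of each dispatch.
sym_x = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, "x", ti.f32, ndim=1)
sym_k = ti.graph.Arg(ti.graph.ArgKind.SCALAR, "k", ti.f32)

builder = ti.graph.GraphBuilder()
builder.dispatch(scale, sym_x, sym_k)  # record the launch against the symbolic args
graph = builder.compile()              # -> Graph wrapping the compiled graph

x = ti.ndarray(ti.f32, shape=8)
x.fill(1.0)
graph.run({"x": x, "k": 3.0})          # dict keys must match the Arg names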
taichi/lang/__init__.py
ADDED
@@ -0,0 +1,50 @@
+# type: ignore
+
+from taichi.lang import impl, simt
+from taichi.lang._ndarray import *
+from taichi.lang._ndrange import ndrange
+from taichi.lang._texture import Texture
+from taichi.lang.argpack import *
+from taichi.lang.exception import *
+from taichi.lang.field import *
+from taichi.lang.impl import *
+from taichi.lang.kernel_impl import *
+from taichi.lang.matrix import *
+from taichi.lang.mesh import *
+from taichi.lang.misc import *  # pylint: disable=W0622
+from taichi.lang.ops import *  # pylint: disable=W0622
+from taichi.lang.runtime_ops import *
+from taichi.lang.snode import *
+from taichi.lang.source_builder import *
+from taichi.lang.struct import *
+from taichi.types.enums import DeviceCapability, Format, Layout
+
+__all__ = [
+    s
+    for s in dir()
+    if not s.startswith("_")
+    and s
+    not in [
+        "any_array",
+        "ast",
+        "common_ops",
+        "enums",
+        "exception",
+        "expr",
+        "impl",
+        "inspect",
+        "kernel_arguments",
+        "kernel_impl",
+        "matrix",
+        "mesh",
+        "misc",
+        "ops",
+        "platform",
+        "runtime_ops",
+        "shell",
+        "snode",
+        "source_builder",
+        "struct",
+        "util",
+    ]
+]
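taichi/lang/__init__.py is re-export plumbing: it star-imports the language submodules and then computes __all__ at import time, keeping every public name visible after the wildcard imports except the listed submodule names, so from taichi.lang import * exposes kernels, fields, matrices and friends without leaking implementation modules. Below is a tiny self-contained sketch of the same dir()-based filter; the names are made up for illustration.

# Stand-alone illustration of the dir()-based export filter used above.
import math as _math  # underscore prefix: dropped by startswith("_")

helper = 1          # public name, kept in __all__
internal_mod = 2    # stands in for a submodule we want to hide

_EXCLUDED = ["internal_mod"]

# dir() here is the comprehension's iterable, so it is evaluated in module scope.
__all__ = [s for s in dir() if not s.startswith("_") and s not in _EXCLUDED]
assert __all__ == ["helper"]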
taichi/lang/_ndarray.py
ADDED
@@ -0,0 +1,348 @@
+# type: ignore
+
+import numpy as np
+
+from taichi._lib import core as _ti_core
+from taichi.lang import impl
+from taichi.lang.exception import TaichiIndexError
+from taichi.lang.util import cook_dtype, get_traceback, python_scope, to_numpy_type
+from taichi.types import primitive_types
+from taichi.types.enums import Layout
+from taichi.types.ndarray_type import NdarrayTypeMetadata
+from taichi.types.utils import is_real, is_signed
+
+
+class Ndarray:
+    """Taichi ndarray class.
+
+    Args:
+        dtype (DataType): Data type of each value.
+        shape (Tuple[int]): Shape of the Ndarray.
+    """
+
+    def __init__(self):
+        self.host_accessor = None
+        self.shape = None
+        self.element_type = None
+        self.dtype = None
+        self.arr = None
+        self.layout = Layout.AOS
+        self.grad = None
+
+    def get_type(self):
+        return NdarrayTypeMetadata(self.element_type, self.shape, self.grad is not None)
+
+    @property
+    def element_shape(self):
+        """Gets ndarray element shape.
+
+        Returns:
+            Tuple[Int]: Ndarray element shape.
+        """
+        raise NotImplementedError()
+
+    @python_scope
+    def __setitem__(self, key, value):
+        """Sets ndarray element in Python scope.
+
+        Args:
+            key (Union[List[int], int, None]): Coordinates of the ndarray element.
+            value (element type): Value to set.
+        """
+        raise NotImplementedError()
+
+    @python_scope
+    def __getitem__(self, key):
+        """Gets ndarray element in Python scope.
+
+        Args:
+            key (Union[List[int], int, None]): Coordinates of the ndarray element.
+
+        Returns:
+            element type: Value retrieved.
+        """
+        raise NotImplementedError()
+
+    @python_scope
+    def fill(self, val):
+        """Fills ndarray with a specific scalar value.
+
+        Args:
+            val (Union[int, float]): Value to fill.
+        """
+        if impl.current_cfg().arch != _ti_core.Arch.cuda and impl.current_cfg().arch != _ti_core.Arch.x64:
+            self._fill_by_kernel(val)
+        elif _ti_core.is_tensor(self.element_type):
+            self._fill_by_kernel(val)
+        elif self.dtype == primitive_types.f32:
+            impl.get_runtime().prog.fill_float(self.arr, val)
+        elif self.dtype == primitive_types.i32:
+            impl.get_runtime().prog.fill_int(self.arr, val)
+        elif self.dtype == primitive_types.u32:
+            impl.get_runtime().prog.fill_uint(self.arr, val)
+        else:
+            self._fill_by_kernel(val)
+
+    @python_scope
+    def _ndarray_to_numpy(self):
+        """Converts ndarray to a numpy array.
+
+        Returns:
+            numpy.ndarray: The result numpy array.
+        """
+        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
+        from taichi._kernels import ndarray_to_ext_arr  # pylint: disable=C0415
+
+        ndarray_to_ext_arr(self, arr)
+        impl.get_runtime().sync()
+        return arr
+
+    @python_scope
+    def _ndarray_matrix_to_numpy(self, as_vector):
+        """Converts matrix ndarray to a numpy array.
+
+        Returns:
+            numpy.ndarray: The result numpy array.
+        """
+        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
+        from taichi._kernels import ndarray_matrix_to_ext_arr  # pylint: disable=C0415
+
+        layout_is_aos = 1
+        ndarray_matrix_to_ext_arr(self, arr, layout_is_aos, as_vector)
+        impl.get_runtime().sync()
+        return arr
+
+    @python_scope
+    def _ndarray_from_numpy(self, arr):
+        """Loads all values from a numpy array.
+
+        Args:
+            arr (numpy.ndarray): The source numpy array.
+        """
+        if not isinstance(arr, np.ndarray):
+            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
+        if tuple(self.arr.total_shape()) != tuple(arr.shape):
+            raise ValueError(f"Mismatch shape: {tuple(self.arr.shape)} expected, but {tuple(arr.shape)} provided")
+        if not arr.flags.c_contiguous:
+            arr = np.ascontiguousarray(arr)
+
+        from taichi._kernels import ext_arr_to_ndarray  # pylint: disable=C0415
+
+        ext_arr_to_ndarray(arr, self)
+        impl.get_runtime().sync()
+
+    @python_scope
+    def _ndarray_matrix_from_numpy(self, arr, as_vector):
+        """Loads all values from a numpy array.
+
+        Args:
+            arr (numpy.ndarray): The source numpy array.
+        """
+        if not isinstance(arr, np.ndarray):
+            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
+        if tuple(self.arr.total_shape()) != tuple(arr.shape):
+            raise ValueError(
+                f"Mismatch shape: {tuple(self.arr.total_shape())} expected, but {tuple(arr.shape)} provided"
+            )
+        if not arr.flags.c_contiguous:
+            arr = np.ascontiguousarray(arr)
+
+        from taichi._kernels import ext_arr_to_ndarray_matrix  # pylint: disable=C0415
+
+        layout_is_aos = 1
+        ext_arr_to_ndarray_matrix(arr, self, layout_is_aos, as_vector)
+        impl.get_runtime().sync()
+
+    @python_scope
+    def _get_element_size(self):
+        """Returns the size of one element in bytes.
+
+        Returns:
+            Size in bytes.
+        """
+        return self.arr.element_size()
+
+    @python_scope
+    def _get_nelement(self):
+        """Returns the total number of elements.
+
+        Returns:
+            Total number of elements.
+        """
+        return self.arr.nelement()
+
+    @python_scope
+    def copy_from(self, other):
+        """Copies all elements from another ndarray.
+
+        The shape of the other ndarray needs to be the same as `self`.
+
+        Args:
+            other (Ndarray): The source ndarray.
+        """
+        assert isinstance(other, Ndarray)
+        assert tuple(self.arr.shape) == tuple(other.arr.shape)
+        from taichi._kernels import ndarray_to_ndarray  # pylint: disable=C0415
+
+        ndarray_to_ndarray(self, other)
+        impl.get_runtime().sync()
+
+    def _set_grad(self, grad):
+        """Sets the gradient ndarray.
+
+        Args:
+            grad (Ndarray): The gradient ndarray.
+        """
+        self.grad = grad
+
+    def __deepcopy__(self, memo=None):
+        """Copies all elements to a new ndarray.
+
+        Returns:
+            Ndarray: The result ndarray.
+        """
+        raise NotImplementedError()
+
+    def _fill_by_kernel(self, val):
+        """Fills ndarray with a specific scalar value using a ti.kernel.
+
+        Args:
+            val (Union[int, float]): Value to fill.
+        """
+        raise NotImplementedError()
+
+    @python_scope
+    def _pad_key(self, key):
+        if key is None:
+            key = ()
+        if not isinstance(key, (tuple, list)):
+            key = (key,)
+        if len(key) != len(self.arr.total_shape()):
+            raise TaichiIndexError(f"{len(self.arr.total_shape())}d ndarray indexed with {len(key)}d indices: {key}")
+        return key
+
+    @python_scope
+    def _initialize_host_accessor(self):
+        if self.host_accessor:
+            return
+        impl.get_runtime().materialize()
+        self.host_accessor = NdarrayHostAccessor(self.arr)
+
+
+class ScalarNdarray(Ndarray):
+    """Taichi ndarray with scalar elements.
+
+    Args:
+        dtype (DataType): Data type of each value.
+        shape (Tuple[int]): Shape of the ndarray.
+    """
+
+    def __init__(self, dtype, arr_shape):
+        super().__init__()
+        self.dtype = cook_dtype(dtype)
+        self.arr = impl.get_runtime().prog.create_ndarray(
+            self.dtype, arr_shape, layout=Layout.NULL, zero_fill=True, dbg_info=_ti_core.DebugInfo(get_traceback())
+        )
+        self.shape = tuple(self.arr.shape)
+        self.element_type = dtype
+
+    def __del__(self):
+        if (
+            impl is not None
+            and impl.get_runtime is not None
+            and impl.get_runtime() is not None
+            and impl.get_runtime().prog is not None
+        ):
+            impl.get_runtime().prog.delete_ndarray(self.arr)
+
+    @property
+    def element_shape(self):
+        return ()
+
+    @python_scope
+    def __setitem__(self, key, value):
+        self._initialize_host_accessor()
+        self.host_accessor.setter(value, *self._pad_key(key))
+
+    @python_scope
+    def __getitem__(self, key):
+        self._initialize_host_accessor()
+        return self.host_accessor.getter(*self._pad_key(key))
+
+    @python_scope
+    def to_numpy(self):
+        return self._ndarray_to_numpy()
+
+    @python_scope
+    def from_numpy(self, arr):
+        self._ndarray_from_numpy(arr)
+
+    def __deepcopy__(self, memo=None):
+        ret_arr = ScalarNdarray(self.dtype, self.shape)
+        ret_arr.copy_from(self)
+        return ret_arr
+
+    def _fill_by_kernel(self, val):
+        from taichi._kernels import fill_ndarray  # pylint: disable=C0415
+
+        fill_ndarray(self, val)
+
+    def __repr__(self):
+        return "<ti.ndarray>"
+
+
+class NdarrayHostAccessor:
+    def __init__(self, ndarray):
+        dtype = ndarray.element_data_type()
+        if is_real(dtype):
+
+            def getter(*key):
+                return ndarray.read_float(key)
+
+            def setter(value, *key):
+                ndarray.write_float(key, value)
+
+        else:
+            if is_signed(dtype):
+
+                def getter(*key):
+                    return ndarray.read_int(key)
+
+            else:
+
+                def getter(*key):
+                    return ndarray.read_uint(key)
+
+            def setter(value, *key):
+                ndarray.write_int(key, value)
+
+        self.getter = getter
+        self.setter = setter
+
+
+class NdarrayHostAccess:
+    """Class for accessing VectorNdarray/MatrixNdarray in Python scope.
+    Args:
+        arr (Union[VectorNdarray, MatrixNdarray]): See above.
+        indices_first (Tuple[Int]): Indices of first-level access (coordinates in the field).
+        indices_second (Tuple[Int]): Indices of second-level access (indices in the vector/matrix).
+    """
+
+    def __init__(self, arr, indices_first, indices_second):
+        self.ndarr = arr
+        self.arr = arr.arr
+        self.indices = indices_first + indices_second
+
+        def getter():
+            self.ndarr._initialize_host_accessor()
+            return self.ndarr.host_accessor.getter(*self.ndarr._pad_key(self.indices))
+
+        def setter(value):
+            self.ndarr._initialize_host_accessor()
+            self.ndarr.host_accessor.setter(value, *self.ndarr._pad_key(self.indices))
+
+        self.getter = getter
+        self.setter = setter
+
+
+__all__ = ["Ndarray", "ScalarNdarray"]
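Ndarray above is the abstract host-side handle, and ScalarNdarray is the concrete class behind ti.ndarray for scalar element types: construction allocates the buffer through the runtime program, fill takes a fast native path for f32/i32/u32 on x64/CUDA and otherwise falls back to a Taichi kernel, and Python-scope indexing goes through the lazily created NdarrayHostAccessor. The sketch below is a minimal usage example, assuming the standard ti.init and ti.ndarray entry points behave as in upstream Taichi and a CPU backend is available; the shapes and values are illustrative.

import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

# ti.ndarray with a scalar dtype constructs a ScalarNdarray under the hood.
a = ti.ndarray(ti.f32, shape=(4, 4))

a.fill(2.5)       # fast native fill for f32/i32/u32 on x64/CUDA, kernel fallback otherwise
a[0, 0] = 7.0     # Python-scope write via NdarrayHostAccessor
print(a[0, 0])    # 7.0

# Round-trip through numpy; the shapes must match exactly.
a.from_numpy(np.arange(16, dtype=np.float32).reshape(4, 4))
b = a.to_numpy()
print(b.sum())    # 120.0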