gstaichi 0.0.0__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +51 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +5 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +122 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +83 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +366 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +7 -0
- gstaichi/lang/ast/ast_transformer.py +1351 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1259 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1386 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +784 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +10 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +21 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.0.0.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-0.0.0.dist-info/METADATA +97 -0
- gstaichi-0.0.0.dist-info/RECORD +178 -0
- gstaichi-0.0.0.dist-info/WHEEL +5 -0
- gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,152 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import collections.abc
|
4
|
+
from typing import Iterable
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
from gstaichi.lang import ops
|
9
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError, GsTaichiTypeError
|
10
|
+
from gstaichi.lang.expr import Expr
|
11
|
+
from gstaichi.lang.matrix import Matrix
|
12
|
+
from gstaichi.types.utils import is_integral
|
13
|
+
|
14
|
+
|
15
|
+
class _Ndrange:
|
16
|
+
def __init__(self, *args):
|
17
|
+
args = list(args)
|
18
|
+
for i, arg in enumerate(args):
|
19
|
+
if not isinstance(arg, collections.abc.Sequence):
|
20
|
+
args[i] = (0, arg)
|
21
|
+
if len(args[i]) != 2:
|
22
|
+
raise GsTaichiSyntaxError(
|
23
|
+
"Every argument of ndrange should be a scalar or a tuple/list like (begin, end)"
|
24
|
+
)
|
25
|
+
args[i] = (args[i][0], ops.max(args[i][0], args[i][1]))
|
26
|
+
for arg in args:
|
27
|
+
for bound in arg:
|
28
|
+
if not isinstance(bound, (int, np.integer)) and not (
|
29
|
+
isinstance(bound, Expr) and is_integral(bound.ptr.get_rvalue_type())
|
30
|
+
):
|
31
|
+
raise GsTaichiTypeError(
|
32
|
+
"Every argument of ndrange should be an integer scalar or a tuple/list of (int, int)"
|
33
|
+
)
|
34
|
+
self.bounds = args
|
35
|
+
|
36
|
+
self.dimensions = [None] * len(args)
|
37
|
+
for i, bound in enumerate(self.bounds):
|
38
|
+
self.dimensions[i] = bound[1] - bound[0]
|
39
|
+
|
40
|
+
self.acc_dimensions = self.dimensions.copy()
|
41
|
+
for i in reversed(range(len(self.bounds) - 1)):
|
42
|
+
self.acc_dimensions[i] = self.acc_dimensions[i] * self.acc_dimensions[i + 1]
|
43
|
+
if len(self.acc_dimensions) == 0: # for the empty case, e.g. ti.ndrange()
|
44
|
+
self.acc_dimensions = [1]
|
45
|
+
|
46
|
+
def __iter__(self):
|
47
|
+
def gen(d, prefix):
|
48
|
+
if d == len(self.bounds):
|
49
|
+
yield prefix
|
50
|
+
else:
|
51
|
+
for t in range(self.bounds[d][0], self.bounds[d][1]):
|
52
|
+
yield from gen(d + 1, prefix + (t,))
|
53
|
+
|
54
|
+
yield from gen(0, ())
|
55
|
+
|
56
|
+
def grouped(self):
|
57
|
+
return GroupedNDRange(self)
|
58
|
+
|
59
|
+
|
60
|
+
def ndrange(*args) -> Iterable:
    """Return an iterable object for looping over multi-dimensional indices.

    The returned set of multi-dimensional indices is the direct product (in the set-theory sense)
    of n groups of integers, where n equals the number of arguments, and looks like

        range(x1, y1) x range(x2, y2) x ... x range(xn, yn)

    The k-th argument corresponds to the k-th `range()` factor in the above product, and each
    argument must be an integer or a pair of two integers. An integer argument n will be interpreted
    as `range(0, n)`, and a pair of two integers (start, end) will be interpreted as `range(start, end)`.
    If end is smaller than start, end is clamped to start and that factor is empty.

    You can loop over these multi-dimensional indices in different ways, see the examples below.

    Args:
        *args (int, tuple, list): Each argument must be either an integer, or a tuple/list of two
            integers (start, end). Python/NumPy integers and integral gstaichi expressions are accepted.

    Returns:
        An iterable over the multi-dimensional index tuples.

    Raises:
        GsTaichiSyntaxError: If a tuple/list argument does not contain exactly two elements.
        GsTaichiTypeError: If a bound is not an integer.

    Example::

        You can loop over 1-D integers in range [start, end), as in native Python

        >>> @ti.kernel
        >>> def loop_1d():
        >>>     start = 2
        >>>     end = 5
        >>>     for i in ti.ndrange((start, end)):
        >>>         print(i)  # will print 2 3 4

        Note the parentheses around `(start, end)` in the above code. Without them,
        the parameter `2` will be interpreted as `range(0, 2)`, `5` will be
        interpreted as `range(0, 5)`, and you will get a set of 2-D indices which
        contains 2x5=10 elements, and you need two indices i, j to loop over them:

        >>> @ti.kernel
        >>> def loop_2d():
        >>>     for i, j in ti.ndrange(2, 5):
        >>>         print(i, j)
        0 0
        ...
        0 4
        ...
        1 4

        You can also use a single index i to loop over these 2-D indices; in this case
        the indices are returned as a 1-D array `(0, 1, ..., 9)`:

        >>> @ti.kernel
        >>> def loop_2d_as_1d():
        >>>     for i in ti.ndrange(2, 5):
        >>>         print(i)
        will print 0 1 2 3 4 5 6 7 8 9

        In general, you can use any `1 <= k <= n` iterators to loop over a set of n-D
        indices. For `k=n` all the indices are n-dimensional, and they are returned in
        lexicographic order, but for `k<n` iterators the last n-k+1 dimensions will be collapsed into
        a 1-D array of consecutive integers `(0, 1, 2, ...)` whose length equals the
        total number of indices in the last n-k+1 dimensions:

        >>> @ti.kernel
        >>> def loop_3d_as_2d():
        >>>     # use two iterators to loop over a set of 3-D indices
        >>>     # the last two dimensions for 4, 5 will collapse into
        >>>     # the array [0, 1, 2, ..., 19]
        >>>     for i, j in ti.ndrange(3, 4, 5):
        >>>         print(i, j)
        will print 0 0, 0 1, ..., 0 19, ..., 2 19.

        A typical usage of `ndrange` is when you want to loop over a field and process
        its entries in parallel. You should avoid writing nested `for` loops here since
        only top level `for` loops are parallelized in gstaichi; instead you can use `ndrange`
        to hold all entries in one top level loop:

        >>> @ti.kernel
        >>> def loop_tensor():
        >>>     for row, col, channel in ti.ndrange(image_height, image_width, channels):
        >>>         image[row, col, channel] = ...
    """
    return _Ndrange(*args)
|
141
|
+
|
142
|
+
|
143
|
+
class GroupedNDRange:
    """Iterable view over an index range that yields each index tuple as a ``Matrix``.

    Args:
        r: The underlying iterable of index tuples (an ``_Ndrange`` when created via ``grouped()``).
    """

    def __init__(self, r):
        self.r = r

    def __iter__(self):
        for index_tuple in self.r:
            yield Matrix([*index_tuple])
|
150
|
+
|
151
|
+
|
152
|
+
# Names exported by `from ... import *`: only the public `ndrange` factory; `_Ndrange` and
# `GroupedNDRange` are not listed here.
__all__ = ["ndrange"]
|
@@ -0,0 +1,195 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import weakref
|
3
|
+
from typing import Any, Callable, Union
|
4
|
+
|
5
|
+
import gstaichi.lang
|
6
|
+
import gstaichi.lang._ndarray
|
7
|
+
import gstaichi.lang._texture
|
8
|
+
import gstaichi.lang.expr
|
9
|
+
import gstaichi.lang.snode
|
10
|
+
from gstaichi._lib import core as _ti_core
|
11
|
+
from gstaichi.lang import _dataclass_util
|
12
|
+
from gstaichi.lang.any_array import AnyArray
|
13
|
+
from gstaichi.lang.exception import (
|
14
|
+
GsTaichiRuntimeTypeError,
|
15
|
+
)
|
16
|
+
from gstaichi.lang.kernel_arguments import ArgMetadata
|
17
|
+
from gstaichi.lang.matrix import MatrixType
|
18
|
+
from gstaichi.lang.util import is_ti_template, to_gstaichi_type
|
19
|
+
from gstaichi.types import (
|
20
|
+
ndarray_type,
|
21
|
+
sparse_matrix_builder,
|
22
|
+
template,
|
23
|
+
texture_type,
|
24
|
+
)
|
25
|
+
from gstaichi.types.enums import AutodiffMode
|
26
|
+
|
27
|
+
# Key type for a compiled kernel: (callable, int, AutodiffMode).
# NOTE(review): the int is presumably the instantiation index returned by TemplateMapper.lookup —
# confirm at the use site.
CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]


# Kinds of parameter annotation that TemplateMapper.extract_arg dispatches on. Because `Any` is a
# member of the Union, this alias does not constrain static type checking; the other members only
# document intent. The texture types are given as string forward references.
AnnotationType = Union[
    template,
    "texture_type.TextureType",
    "texture_type.RWTextureType",
    ndarray_type.NdarrayType,
    sparse_matrix_builder,
    Any,
]
|
38
|
+
|
39
|
+
|
40
|
+
class TemplateMapper:
    """
    Maps the arguments of a kernel call to a hashable key and a dense instantiation index.

    This should probably be renamed to something like FeatureMapper, or
    FeatureExtractor, since:
    - it's not specific to templates
    - it extracts what are later called 'features', for example for ndarray this includes:
        - element type
        - number of dimensions
        - needs grad (or not)
    - these are returned as a heterogeneous tuple, whose contents depends on the type
    """

    def __init__(self, arguments: list[ArgMetadata], template_slot_locations: list[int]) -> None:
        # Metadata (annotation, name) of each kernel parameter, in declaration order.
        self.arguments: list[ArgMetadata] = arguments
        self.num_args: int = len(arguments)
        # Stored for callers; not read by the methods of this class.
        self.template_slot_locations: list[int] = template_slot_locations
        # Extracted key -> instantiation index, assigned 0, 1, 2, ... in first-seen order by lookup().
        self.mapping: dict[tuple[Any, ...], int] = {}

    @staticmethod
    def _extract_template_container(arg: Any, annotation: AnnotationType, arg_name: str) -> tuple[type, Any]:
        """Build a hashable key for a built-in list, dict or set passed as a template argument.

        These containers cannot be keyed with ``weakref.ref``. ``list`` and ``dict`` do not support weak
        references at all. A weakref to a ``set`` can be created, but hashing a weakref hashes its
        referent, so it fails as soon as it is used as a dict key in ``lookup``.

        They are therefore keyed by their recursively extracted contents, tagged with the container type.
        No strong reference to the container itself is kept, and equal containers share an instantiation.

        Args:
            arg: The list, dict or set (or subclass) to extract.
            annotation: The template annotation, forwarded to the element extraction.
            arg_name: Argument name, forwarded for error messages.

        Returns:
            ``(type(arg), contents)``. ``contents`` is a tuple of extracted items for lists, a tuple of
            ``(key, extracted value)`` pairs in iteration order for dicts, and a frozenset of extracted
            items for sets.
        """
        if isinstance(arg, dict):
            contents: Any = tuple(
                (key, TemplateMapper.extract_arg(value, annotation, arg_name)) for key, value in arg.items()
            )
        elif isinstance(arg, set):
            contents = frozenset(TemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
        else:
            contents = tuple(TemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
        return type(arg), contents

    @staticmethod
    def extract_arg(arg: Any, annotation: AnnotationType, arg_name: str) -> Any:
        """Return the hashable 'feature' of one argument that selects the kernel instantiation.

        Args:
            arg: The value passed for this parameter.
            annotation: The parameter's annotation.
            arg_name: The parameter's name, used in error messages.

        Returns:
            The key component depends on the annotation:

            - template: the SNode pointer, expression address, recursively extracted tuple, a weakref
              for data-oriented objects, a content key for list/dict/set, or the value itself
            - dataclass: a tuple of per-field keys
            - texture types: tuples built from ``num_dims`` (plus ``fmt`` and LOD 0 for RW textures)
            - ndarray: ``(element_type, ndim, needs_grad, boundary)``
            - sparse_matrix_builder: ``arg.dtype``
            - anything else: the placeholder ``"#"``

        Raises:
            GsTaichiRuntimeTypeError: Ndarray passed via ``ti.template()``, texture kind/dimension/format
                mismatch, or an array-annotated argument without a ``shape``.
            ValueError: External array dimension, dtype or element-shape mismatch.
            Any error raised by ``NdarrayType.check_matched``.
        """
        if is_ti_template(annotation):
            if isinstance(arg, gstaichi.lang.snode.SNode):
                return arg.ptr
            if isinstance(arg, gstaichi.lang.expr.Expr):
                return arg.ptr.get_underlying_ptr_address()
            if isinstance(arg, _ti_core.ExprCxx):
                return arg.get_underlying_ptr_address()
            if isinstance(arg, tuple):
                return tuple(TemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
            if isinstance(arg, gstaichi.lang._ndarray.Ndarray):
                raise GsTaichiRuntimeTypeError(
                    "Ndarray shouldn't be passed in via `ti.template()`, please annotate your kernel using `ti.types.ndarray(...)` instead"
                )

            if hasattr(arg, "_data_oriented"):
                # [Composite arguments] Return weak reference to the object
                # GsTaichi kernel will cache the extracted arguments, thus we can't simply return the original argument.
                # Instead, a weak reference to the original value is returned to avoid memory leak.

                # TODO(zhanlue): replacing "tuple(args)" with "hash of argument values"
                # This can resolve the following issues:
                # 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
                # 2. Different argument instances with same type and same value, will get templatized into separate kernels.
                return weakref.ref(arg)

            if isinstance(arg, (list, dict, set)):
                # weakref.ref() raises TypeError for list/dict, and a weakref to a set is unhashable,
                # so key these built-in containers by their extracted contents instead.
                return TemplateMapper._extract_template_container(arg, annotation, arg_name)

            # [Primitive arguments] Return the value
            return arg
        if dataclasses.is_dataclass(annotation):
            # One key component per dataclass field, extracted with the field's own annotation.
            _res_l = []
            for field in dataclasses.fields(annotation):
                field_value = getattr(arg, field.name)
                child_name = _dataclass_util.create_flat_name(arg_name, field.name)
                field_extracted = TemplateMapper.extract_arg(field_value, field.type, child_name)
                _res_l.append(field_extracted)
            return tuple(_res_l)
        if isinstance(annotation, texture_type.TextureType):
            if not isinstance(arg, gstaichi.lang._texture.Texture):
                raise GsTaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
            if arg.num_dims != annotation.num_dimensions:
                raise GsTaichiRuntimeTypeError(
                    f"TextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
                )
            return (arg.num_dims,)
        if isinstance(annotation, texture_type.RWTextureType):
            if not isinstance(arg, gstaichi.lang._texture.Texture):
                raise GsTaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
            if arg.num_dims != annotation.num_dimensions:
                raise GsTaichiRuntimeTypeError(
                    f"RWTextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
                )
            if arg.fmt != annotation.fmt:
                raise GsTaichiRuntimeTypeError(
                    f"RWTextureType format mismatch for argument {arg_name}: expected {annotation.fmt}, got {arg.fmt}"
                )
            # (penguinliong) '0' is the assumed LOD level. We currently don't
            # support mip-mapping.
            return arg.num_dims, arg.fmt, 0
        if isinstance(annotation, ndarray_type.NdarrayType):
            if isinstance(arg, gstaichi.lang._ndarray.Ndarray):
                annotation.check_matched(arg.get_type(), arg_name)
                # An explicit needs_grad on the annotation overrides whether the ndarray has a grad buffer.
                needs_grad = (arg.grad is not None) if annotation.needs_grad is None else annotation.needs_grad
                assert arg.shape is not None
                return arg.element_type, len(arg.shape), needs_grad, annotation.boundary
            if isinstance(arg, AnyArray):
                ty = arg.get_type()
                annotation.check_matched(ty, arg_name)
                return ty.element_type, len(arg.shape), ty.needs_grad, annotation.boundary
            # external arrays
            shape = getattr(arg, "shape", None)
            if shape is None:
                raise GsTaichiRuntimeTypeError(f"Invalid type for argument {arg_name}, got {arg}")
            shape = tuple(shape)
            element_shape: tuple[int, ...] = ()
            dtype = to_gstaichi_type(arg.dtype)
            if isinstance(annotation.dtype, MatrixType):
                # The trailing annotation.dtype.ndim axes of the array form the matrix/vector element.
                if annotation.ndim is not None:
                    if len(shape) != annotation.dtype.ndim + annotation.ndim:
                        raise ValueError(
                            f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim} element_dim={annotation.dtype.ndim}, "
                            f"array with {len(shape)} dimensions is provided"
                        )
                else:
                    if len(shape) < annotation.dtype.ndim:
                        raise ValueError(
                            f"Invalid value for argument {arg_name} - required element_dim={annotation.dtype.ndim}, "
                            f"array with {len(shape)} dimensions is provided"
                        )
                element_shape = shape[-annotation.dtype.ndim :]
                anno_element_shape = annotation.dtype.get_shape()
                if None not in anno_element_shape and element_shape != anno_element_shape:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required element_shape={anno_element_shape}, "
                        f"array with element shape of {element_shape} is provided"
                    )
            elif annotation.dtype is not None:
                # User specified scalar dtype
                if annotation.dtype != dtype:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required array has dtype={annotation.dtype.to_string()}, "
                        f"array with dtype={dtype.to_string()} is provided"
                    )

                if annotation.ndim is not None and len(shape) != annotation.ndim:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim}, "
                        f"array with {len(shape)} dimensions is provided"
                    )
            needs_grad = (
                getattr(arg, "requires_grad", False) if annotation.needs_grad is None else annotation.needs_grad
            )
            element_type = (
                _ti_core.get_type_factory_instance().get_tensor_type(element_shape, dtype)
                if len(element_shape) != 0
                else arg.dtype
            )
            return element_type, len(shape) - len(element_shape), needs_grad, annotation.boundary
        if isinstance(annotation, sparse_matrix_builder):
            return arg.dtype
        # Use '#' as a placeholder because other kinds of arguments are not involved in template instantiation
        return "#"

    def extract(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
        """Return the tuple of per-argument keys for one call, pairing args with parameter metadata."""
        extracted: list[Any] = []
        for arg, kernel_arg in zip(args, self.arguments):
            extracted.append(self.extract_arg(arg, kernel_arg.annotation, kernel_arg.name))
        return tuple(extracted)

    def lookup(self, args: tuple[Any, ...]) -> tuple[int, tuple[Any, ...]]:
        """Return ``(instantiation_index, key)`` for ``args``, registering a new index for unseen keys.

        Raises:
            TypeError: If the number of arguments differs from the number of kernel parameters.
        """
        if len(args) != self.num_args:
            raise TypeError(f"{self.num_args} argument(s) needed but {len(args)} provided.")

        key = self.extract(args)
        if key not in self.mapping:
            count = len(self.mapping)
            self.mapping[key] = count
        return self.mapping[key], key
|
@@ -0,0 +1,172 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from gstaichi._lib import core as _ti_core
|
6
|
+
from gstaichi.lang import impl
|
7
|
+
from gstaichi.lang.expr import Expr, make_expr_group
|
8
|
+
from gstaichi.lang.matrix import Matrix
|
9
|
+
from gstaichi.lang.util import gstaichi_scope
|
10
|
+
from gstaichi.types import vector
|
11
|
+
from gstaichi.types.primitive_types import f32
|
12
|
+
|
13
|
+
|
14
|
+
def _get_entries(mat):
    """Return the scalar entries of a ``Matrix``, or wrap any other value in a one-element list."""
    if not isinstance(mat, Matrix):
        return [mat]
    return mat.entries
|
18
|
+
|
19
|
+
|
20
|
+
class TextureSampler:
    """Kernel-scope handle for sampling and texel fetches on a texture pointer expression.

    Each operation emits a texture op expression. Its four channels are extracted in r, g, b, a
    order and returned as a 4-vector of f32.
    """

    def __init__(self, ptr_expr, num_dims) -> None:
        self.ptr_expr = ptr_expr
        self.num_dims = num_dims

    def _texel_op(self, op_type, coords, lod):
        """Emit a texture op of kind ``op_type`` at ``coords``/``lod`` and unpack it into a vec4 of f32."""
        runtime = impl.get_runtime()
        builder = runtime.compiling_callable.ast_builder()
        debug_info = _ti_core.DebugInfo(runtime.get_current_src_info())
        operands = make_expr_group(*_get_entries(coords), lod)
        packed = builder.make_texture_op_expr(op_type, self.ptr_expr, operands, debug_info)
        # Channel extraction calls are emitted in r, g, b, a order.
        channels = [
            impl.call_internal(extractor, packed, with_runtime_context=False)
            for extractor in (
                "composite_extract_0",
                "composite_extract_1",
                "composite_extract_2",
                "composite_extract_3",
            )
        ]
        return vector(4, f32)(channels)

    @gstaichi_scope
    def sample_lod(self, uv, lod):
        """Sample the texture at ``uv`` with explicit level ``lod`` (TextureOpType.kSampleLod)."""
        return self._texel_op(_ti_core.TextureOpType.kSampleLod, uv, lod)

    @gstaichi_scope
    def fetch(self, index, lod):
        """Fetch a texel at ``index`` with level ``lod`` (TextureOpType.kFetchTexel)."""
        return self._texel_op(_ti_core.TextureOpType.kFetchTexel, index, lod)
|
48
|
+
|
49
|
+
|
50
|
+
class RWTextureAccessor:
    """Kernel-scope handle for loads and stores on a read-write texture pointer expression."""

    def __init__(self, ptr_expr, num_dims) -> None:
        # gstaichi_python.TexturePtrExpression.
        self.ptr_expr = ptr_expr
        self.num_dims = num_dims

    @gstaichi_scope
    def load(self, index):
        """Load the texel at ``index`` and return its channels as a 4-vector of f32."""
        runtime = impl.get_runtime()
        builder = runtime.compiling_callable.ast_builder()
        debug_info = _ti_core.DebugInfo(runtime.get_current_src_info())
        operands = make_expr_group(*_get_entries(index))
        packed = builder.make_texture_op_expr(_ti_core.TextureOpType.kLoad, self.ptr_expr, operands, debug_info)
        channels = []
        # Channel extraction calls are emitted in r, g, b, a order.
        for extractor in ("composite_extract_0", "composite_extract_1", "composite_extract_2", "composite_extract_3"):
            channels.append(impl.call_internal(extractor, packed, with_runtime_context=False))
        return vector(4, f32)(channels)

    @gstaichi_scope
    def store(self, index, value):
        """Store ``value`` at ``index``; the coordinates and value entries are passed as one operand group."""
        runtime = impl.get_runtime()
        builder = runtime.compiling_callable.ast_builder()
        debug_info = _ti_core.DebugInfo(runtime.get_current_src_info())
        operands = make_expr_group(*_get_entries(index), *_get_entries(value))
        store_expr = builder.make_texture_op_expr(_ti_core.TextureOpType.kStore, self.ptr_expr, operands, debug_info)
        impl.expr_init(store_expr)

    @property
    @gstaichi_scope
    def shape(self):
        """A list containing sizes for each dimension. Note that element shape will be excluded.

        Returns:
            List[Int]: The result list.
        """
        axis_count = _ti_core.get_external_tensor_dim(self.ptr_expr)
        debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return [
            Expr(_ti_core.get_external_tensor_shape_along_axis(self.ptr_expr, axis, debug_info))
            for axis in range(axis_count)
        ]

    @gstaichi_scope
    def _loop_range(self):
        """Return the underlying pointer expression so it can serve as a loop range.

        Returns:
            gstaichi_python.Expr: The stored ``ptr_expr``.
        """
        return self.ptr_expr
|
98
|
+
|
99
|
+
|
100
|
+
class Texture:
    """GsTaichi Texture class.

    Args:
        fmt (ti.Format): Color format of the texture.
        arr_shape (Tuple[int]): Shape of the texture, one size per dimension. It is stored as given
            (a tuple or a list), and its length is the number of texture dimensions.
    """

    def __init__(self, fmt, arr_shape):
        # Backing texture object created by the current program.
        self.tex = impl.get_runtime().prog.create_texture(fmt, arr_shape)
        self.fmt = fmt
        self.num_dims = len(arr_shape)
        # Kept exactly as passed (may be a list); normalize with tuple() before tuple arithmetic.
        self.shape = arr_shape

    def from_ndarray(self, ndarray):
        """Loads an ndarray to texture.

        Args:
            ndarray (ti.Ndarray): Source ndarray to load from.
        """
        self.tex.from_ndarray(ndarray.arr)

    def from_field(self, field):
        """Loads a field to texture.

        Args:
            field (ti.Field): Source field to load from.
        """
        self.tex.from_snode(field.snode.ptr)

    def _device_allocation_ptr(self):
        """Return the device allocation pointer of the backing texture."""
        return self.tex.device_allocation_ptr()

    def from_image(self, image):
        """Loads a PIL image to texture. Intended for a 2D texture with `ti.Format.rgba8`.

        The image is converted to RGB if needed. Its size ``(width, height)`` must equal the texture
        shape. Only the image type, size and texture dimensionality are asserted here; the texture
        format is not checked.

        Args:
            image (PIL.Image.Image): Source PIL image to load from.
        """
        from PIL import Image  # pylint: disable=import-outside-toplevel

        assert isinstance(image, Image.Image)
        if image.mode != "RGB":
            image = image.convert("RGB")
        assert image.size == tuple(self.shape)

        assert self.num_dims == 2
        # Don't use transpose method since its enums are too new.
        # After rotating by 90 degrees with expand=True, np.asarray yields an array of shape
        # (width, height, 3), i.e. tuple(self.shape) + (3,).
        image = image.rotate(90, expand=True)
        arr = np.asarray(image)
        from gstaichi._kernels import (  # pylint: disable=import-outside-toplevel
            load_texture_from_numpy,
        )

        load_texture_from_numpy(self, arr)

    def to_image(self):
        """Saves a texture to a PIL image in RGB mode. Intended for a 2D texture with `ti.Format.rgba8`.

        Returns:
            img (PIL.Image.Image): a PIL image in RGB mode, with the same size as source texture.
        """
        assert self.num_dims == 2
        from PIL import Image  # pylint: disable=import-outside-toplevel

        # self.shape may be a list; tuple() makes the concatenation valid for both lists and tuples.
        res = np.zeros(tuple(self.shape) + (3,), np.uint8)
        from gstaichi._kernels import (  # pylint: disable=import-outside-toplevel
            save_texture_to_numpy,
        )

        save_texture_to_numpy(self, res)
        # Undo the 90-degree rotation applied in from_image.
        return Image.fromarray(res).rotate(270, expand=True)
|