gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/struct.py
ADDED
@@ -0,0 +1,810 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import numbers
|
4
|
+
from types import MethodType
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
from gstaichi._lib import core as _ti_core
|
9
|
+
from gstaichi.lang import expr, impl, ops
|
10
|
+
from gstaichi.lang.exception import (
|
11
|
+
GsTaichiRuntimeTypeError,
|
12
|
+
GsTaichiSyntaxError,
|
13
|
+
GsTaichiTypeError,
|
14
|
+
)
|
15
|
+
from gstaichi.lang.expr import Expr
|
16
|
+
from gstaichi.lang.field import Field, ScalarField, SNodeHostAccess
|
17
|
+
from gstaichi.lang.matrix import Matrix, MatrixType
|
18
|
+
from gstaichi.lang.util import cook_dtype, gstaichi_scope, in_python_scope, python_scope
|
19
|
+
from gstaichi.types import primitive_types
|
20
|
+
from gstaichi.types.compound_types import CompoundType
|
21
|
+
from gstaichi.types.enums import Layout
|
22
|
+
from gstaichi.types.utils import is_signed
|
23
|
+
|
24
|
+
|
25
|
+
class Struct:
    """The Struct type class.

    A struct is a dictionary-like data structure that stores members as
    (key, value) pairs. Valid data members of a struct can be scalars,
    matrices or other dictionary-like structures.

    Args:
        entries (Dict[str, Union[Dict, Expr, Matrix, Struct]]):
            keys and values for struct members. Entries can optionally
            include a dictionary of functions with the key '__struct_methods'
            which will be attached to the struct for executing on the struct data.
            The given dict is copied, so the caller's dict is left unchanged.

    Returns:
        An instance of this struct.

    Raises:
        GsTaichiSyntaxError: if positional arguments other than a single dict
            are given, or a dict is combined with keyword arguments.

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)
        >>> a = ti.Struct(v=vec3([0, 0, 0]), t=1.0)
        >>> print(a.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0)])
        >>>
        >>> B = ti.Struct(v=vec3([0., 0., 0.]), t=1.0, A=a)
        >>> print(B.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0), ('A', {'v': [[0.], [0.], [0.]], 't': 1.0})])
    """

    _is_gstaichi_class = True
    _instance_count = 0
    # Class-level default for the name-mangled ``_Struct__dtype`` attribute, so
    # that instances created without running ``Struct.__init__`` (subclasses that
    # assign ``_Struct__entries`` directly) still report a dtype of None instead
    # of raising AttributeError in ``_set_entries`` / ``_assign``.
    __dtype = None

    def __init__(self, *args, **kwargs):
        if len(args) == 1 and not kwargs and isinstance(args[0], dict):
            # Shallow copy: the pops and member conversions below must not
            # mutate the caller's dict (e.g. a reused dict would otherwise lose
            # its '__struct_methods' entry after the first construction).
            self.__entries = dict(args[0])
        elif not args:
            self.__entries = kwargs
        else:
            raise GsTaichiSyntaxError(
                "Custom structs need to be initialized using either dictionary or keyword arguments"
            )
        self.__methods = self.__entries.pop("__struct_methods", {})
        # '__matrix_ndim' is bookkeeping written by to_dict(include_ndim=True);
        # it is not a struct member, so discard it.
        self.__entries.pop("__matrix_ndim", None)
        self._register_methods()

        # converts lists to matrices and dicts to structs
        for k, v in self.__entries.items():
            if isinstance(v, (list, tuple)):
                v = Matrix(v)
            elif isinstance(v, dict):
                v = Struct(v)
            # Inside a kernel, members are materialized as initialized
            # expressions; in Python scope they are stored as-is.
            self.__entries[k] = v if in_python_scope() else impl.expr_init(v)
        self._register_members()
        self.__dtype = None

    @property
    def keys(self):
        """Returns the list of member names in string format.

        Example::

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = ti.Struct(center=vec3([0, 0, 0]), radius=1.0)
            >>> sphere.keys
            ['center', 'radius']
        """
        return list(self.__entries.keys())

    @property
    def _members(self):
        return list(self.__entries.values())

    @property
    def entries(self):
        return self.__entries

    @property
    def methods(self):
        return self.__methods

    @property
    def items(self):
        """Returns the items in this struct.

        Example::

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = ti.Struct(center=vec3([0, 0, 0]), radius=1.0)
            >>> sphere.items
            dict_items([('center', [0. 0. 0.]), ('radius', 1.0)])
        """
        return self.__entries.items()

    def _register_members(self):
        # Properties live on classes, not instances, so each instance is moved
        # onto its own freshly created subclass that carries one property per
        # member name.
        # https://stackoverflow.com/questions/48448074/adding-a-property-to-an-existing-object-instance
        cls = self.__class__
        new_cls_name = cls.__name__ + str(cls._instance_count)
        cls._instance_count += 1
        properties = {k: property(cls._make_getter(k), cls._make_setter(k)) for k in self.keys}
        self.__class__ = type(new_cls_name, (cls,), properties)

    def _register_methods(self):
        for name, method in self.__methods.items():
            # use MethodType to pass self (this object) to the method
            setattr(self, name, MethodType(method, self))

    def __getitem__(self, key):
        ret = self.__entries[key]
        # Members backed by a field element are read through their accessor.
        if isinstance(ret, SNodeHostAccess):
            ret = ret.accessor.getter(*ret.key)
        return ret

    def __setitem__(self, key, value):
        if isinstance(self.__entries[key], SNodeHostAccess):
            # Write through to the field element backing this member.
            self.__entries[key].accessor.setter(value, *self.__entries[key].key)
        else:
            if in_python_scope():
                if isinstance(self.__entries[key], (Struct, Matrix)):
                    # Compound members are updated in place, element-wise.
                    self.__entries[key]._set_entries(value)
                else:
                    if isinstance(value, numbers.Number):
                        self.__entries[key] = value
                    else:
                        raise TypeError("A number is expected when assigning struct members")
            else:
                self.__entries[key] = value

    def _set_entries(self, value):
        if isinstance(value, dict):
            value = Struct(value)
        for k in self.keys:
            self[k] = value[k]
        # getattr: tolerate sources that never initialised the dtype slot.
        self.__dtype = getattr(value, "_Struct__dtype", None)

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom struct by name."""
            return self[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self[key] = value

        return setter

    @gstaichi_scope
    def _assign(self, other):
        if not isinstance(other, (dict, Struct)):
            raise GsTaichiTypeError("Only dict or Struct can be assigned to a Struct")
        if isinstance(other, dict):
            other = Struct(other)
        if self.__entries.keys() != other.__entries.keys():
            raise GsTaichiTypeError(f"Member mismatch between structs {self.keys}, {other.keys}")
        for k, v in self.items:
            v._assign(other.__entries[k])
        # getattr: _IntermediateStruct instances bypass Struct.__init__.
        self.__dtype = getattr(other, "_Struct__dtype", None)
        return self

    def __len__(self):
        """Get the number of entries in a custom struct"""
        return len(self.__entries)

    def __iter__(self):
        """Iterate over the member values in insertion order."""
        # __iter__ must return an iterator; a bare dict_values view is only an
        # iterable and makes iter()/for-loops raise TypeError.
        return iter(self.__entries.values())

    def __str__(self):
        """Python scope struct array print support."""
        if impl.inside_kernel():
            item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.items])
            item_str += f", struct_methods={self.__methods}"
            return f"<ti.Struct {item_str}>"
        return str(self.to_dict())

    def __repr__(self):
        return str(self.to_dict())

    def to_dict(self, include_methods=False, include_ndim=False):
        """Converts the Struct to a dictionary.

        Nested structs are converted recursively and matrices via their
        ``to_list`` method. Members backed by a field element are read, so the
        result holds their current values.

        Args:
            include_methods (bool): Whether any struct methods should be included
                in the result dictionary under the key '__struct_methods'.
            include_ndim (bool): Whether to record the ``ndim`` of every direct
                Matrix member in a dict stored under the key '__matrix_ndim'.

        Returns:
            Dict: The result dictionary.
        """
        res_dict = {}
        for k in self.__entries:
            # Index via __getitem__ so SNodeHostAccess handles (present when the
            # struct was read from a StructField) resolve to values.
            v = self[k]
            if isinstance(v, Struct):
                res_dict[k] = v.to_dict(include_methods=include_methods, include_ndim=include_ndim)
            elif isinstance(v, Matrix):
                res_dict[k] = v.to_list()
            else:
                res_dict[k] = v
        if include_methods:
            res_dict["__struct_methods"] = self.__methods
        if include_ndim:
            res_dict["__matrix_ndim"] = dict()
            for k, v in self.__entries.items():
                if isinstance(v, Matrix):
                    res_dict["__matrix_ndim"][k] = v.ndim
        return res_dict

    @classmethod
    @python_scope
    def field(
        cls,
        members,
        methods=None,
        shape=None,
        name="<Struct>",
        offset=None,
        needs_grad=False,
        needs_dual=False,
        layout=Layout.AOS,
    ):
        """Creates a :class:`~gstaichi.StructField` in which each element
        has this struct as its type.

        Args:
            members (dict): a dict, each item is like `name: type`.
            methods (dict, optional): a dict of methods that should be included with
                the field. Each struct item of the field will have the
                methods as instance functions. Defaults to no methods.
            shape (Union[int, Tuple[int]], optional): shape of the field. When
                None, the member fields are created but not placed.
            name (str): name of the field; members are named ``name + "." + key``.
            offset (Union[int, Tuple[int]], optional): offset of the indices of the created field.
                For example if `offset=(-10, -10)` the indices of the field
                will start at `(-10, -10)`, not `(0, 0)`.
            needs_grad (bool): enabling grad field (reverse mode autodiff) or not.
            needs_dual (bool): enabling dual field (forward mode autodiff) or not.
            layout: AOS or SOA. With SOA each member is placed under its own
                dense SNode; with AOS all members share a single dense SNode.

        Raises:
            GsTaichiSyntaxError: if `offset` is given without `shape`, or if
                `shape` and `offset` have different dimensionalities.

        Example:

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = {"center": vec3, "radius": float}
            >>> F = ti.Struct.field(sphere, shape=(3, 3))
            >>> F
            {'center': array([[[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]],

                [[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]],

                [[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]]], dtype=float32), 'radius': array([[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]], dtype=float32)}
        """
        # None instead of a shared mutable {} default.
        if methods is None:
            methods = {}

        if shape is None and offset is not None:
            raise GsTaichiSyntaxError("shape cannot be None when offset is being set")

        field_dict = {}

        for key, dtype in members.items():
            field_name = name + "." + key
            if isinstance(dtype, CompoundType):
                if isinstance(dtype, StructType):
                    field_dict[key] = dtype.field(
                        shape=None,
                        name=field_name,
                        offset=offset,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                    )
                else:
                    field_dict[key] = dtype.field(
                        shape=None,
                        name=field_name,
                        offset=offset,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                        ndim=getattr(dtype, "ndim", 2),
                    )
            else:
                field_dict[key] = impl.field(
                    dtype,
                    shape=None,
                    name=field_name,
                    offset=offset,
                    needs_grad=needs_grad,
                    needs_dual=needs_dual,
                )

        if shape is not None:
            if isinstance(shape, numbers.Number):
                shape = (shape,)
            if isinstance(offset, numbers.Number):
                offset = (offset,)

            if offset is not None and len(shape) != len(offset):
                raise GsTaichiSyntaxError(
                    f"The dimensionality of shape and offset must be the same ({len(shape)} != {len(offset)})"
                )
            dim = len(shape)
            if layout == Layout.SOA:
                # One dense SNode per member (and per grad/dual member).
                for e in field_dict.values():
                    impl.root.dense(impl.index_nd(dim), shape).place(e, offset=offset)
                if needs_grad:
                    for e in field_dict.values():
                        impl.root.dense(impl.index_nd(dim), shape).place(e.grad, offset=offset)
                if needs_dual:
                    for e in field_dict.values():
                        impl.root.dense(impl.index_nd(dim), shape).place(e.dual, offset=offset)
            else:
                # All members share one dense SNode; grads and duals each get their own.
                impl.root.dense(impl.index_nd(dim), shape).place(*tuple(field_dict.values()), offset=offset)
                if needs_grad:
                    grads = tuple(e.grad for e in field_dict.values())
                    impl.root.dense(impl.index_nd(dim), shape).place(*grads, offset=offset)

                if needs_dual:
                    duals = tuple(e.dual for e in field_dict.values())
                    impl.root.dense(impl.index_nd(dim), shape).place(*duals, offset=offset)

        return StructField(field_dict, methods, name=name)
|
348
|
+
|
349
|
+
|
350
|
+
class _IntermediateStruct(Struct):
    """Intermediate struct class for compiler internal use only.

    Unlike :class:`Struct`, the constructor does not convert lists/dicts or
    call ``impl.expr_init``: the given entries are stored as they are.

    Args:
        entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
            Any methods included under the key '__struct_methods' will be applied to each
            struct instance. Note that this key is popped from the given dict.
    """

    def __init__(self, entries):
        assert isinstance(entries, dict)
        # Struct.__init__ is bypassed, so the name-mangled private slots of
        # Struct are populated directly.
        self._Struct__methods = entries.pop("__struct_methods", {})
        self._register_methods()
        self._Struct__entries = entries
        self._register_members()
        # Struct._set_entries / Struct._assign read the dtype slot of the value
        # they are given; without this they raise AttributeError for an
        # intermediate struct.
        self._Struct__dtype = None
|
365
|
+
|
366
|
+
|
367
|
+
class StructField(Field):
    """GsTaichi struct field with SNode implementation.

    Instead of directly constraining Expr entries, the StructField object
    directly hosts members as `Field` instances to support nested structs.

    Args:
        field_dict (Dict[str, Field]): Struct field members.
        struct_methods (Dict[str, callable], optional): Dictionary of functions to apply
            to each struct instance in the field. Defaults to no methods.
        name (string, optional): The custom name of the field. The companion
            gradient/dual struct fields are named ``name + ".grad"`` and
            ``name + ".dual"``, or left unnamed when ``name`` is None.
        is_primal (bool, optional): When True, companion ``grad`` and ``dual``
            struct fields are built from the members' ``grad``/``dual``
            attributes; otherwise both stay None.
    """

    def __init__(self, field_dict, struct_methods=None, name=None, is_primal=True):
        # will not call Field initializer
        if struct_methods is None:
            struct_methods = {}
        self.field_dict = field_dict
        self.struct_methods = struct_methods
        self.name = name
        self.grad = None
        self.dual = None
        if is_primal:
            # name defaults to None; avoid `None + ".grad"` raising TypeError.
            grad_name = None if name is None else name + ".grad"
            dual_name = None if name is None else name + ".dual"

            grad_field_dict = {k: v.grad for k, v in self.field_dict.items()}
            self.grad = StructField(grad_field_dict, struct_methods, grad_name, is_primal=False)

            dual_field_dict = {k: v.dual for k, v in self.field_dict.items()}
            self.dual = StructField(dual_field_dict, struct_methods, dual_name, is_primal=False)
        self._register_fields()

    @property
    def keys(self):
        """Returns the list of names of the field members.

        Example::

            >>> f1 = ti.Vector.field(3, ti.f32, shape=(3, 3))
            >>> f2 = ti.field(ti.f32, shape=(3, 3))
            >>> F = ti.StructField({"center": f1, "radius": f2})
            >>> F.keys
            ['center', 'radius']
        """
        return list(self.field_dict.keys())

    @property
    def _members(self):
        return list(self.field_dict.values())

    @property
    def _items(self):
        return self.field_dict.items()

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom struct by name."""
            return self.field_dict[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self.field_dict[key] = value

        return setter

    def _register_fields(self):
        # Expose each member field as an attribute, e.g. F.center.
        for k in self.keys:
            setattr(self, k, self.field_dict[k])

    def _get_field_members(self):
        """Gets a flattened list of all struct elements.

        Returns:
            A list of struct elements.
        """
        field_members = []
        for m in self._members:
            assert isinstance(m, Field)
            field_members += m._get_field_members()
        return field_members

    @property
    def _snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return self._members[0]._snode

    def _loop_range(self):
        """Gets SNode of representative field member for loop range info.

        Returns:
            gstaichi_python.SNode: SNode of representative (first) field member.
        """
        return self._members[0]._loop_range()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another field.

        The shape of the other field needs to be the same as `self`.

        Args:
            other (Field): The source field; it must have the same member keys.
        """
        assert isinstance(other, Field)
        assert set(self.keys) == set(other.keys)
        for k in self.keys:
            self.field_dict[k].copy_from(other.get_member_field(k))

    @python_scope
    def fill(self, val):
        """Fills this struct field with a specified value.

        Args:
            val (Union[int, float]): Value to fill.
        """
        for v in self._members:
            v.fill(val)

    def _initialize_host_accessors(self):
        for v in self._members:
            v._initialize_host_accessors()

    def get_member_field(self, key):
        """Returns the member field stored under a specific key.

        Args:
            key (str): Specified key of the field member.

        Returns:
            Field: The member field registered under `key`.
        """
        return self.field_dict[key]

    @python_scope
    def from_numpy(self, array_dict):
        """Copies the data from a set of `numpy.array` into this field.

        The argument `array_dict` must be a dictionary-like object that
        contains all the keys in this field, so that the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_numpy(array_dict[k])

    @python_scope
    def from_torch(self, array_dict):
        """Copies the data from a set of `torch.tensor` into this field.

        The argument `array_dict` must be a dictionary-like object that
        contains all the keys in this field, so that the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_torch(array_dict[k])

    @python_scope
    def to_numpy(self):
        """Converts the Struct field instance to a dictionary of NumPy arrays.

        The dictionary may be nested when converting nested structs.

        Returns:
            Dict[str, Union[numpy.ndarray, Dict]]: The result NumPy array.
        """
        return {k: v.to_numpy() for k, v in self._items}

    @python_scope
    def to_torch(self, device=None):
        """Converts the Struct field instance to a dictionary of PyTorch tensors.

        The dictionary may be nested when converting nested structs.

        Args:
            device (torch.device, optional): The
                desired device of returned tensor.

        Returns:
            Dict[str, Union[torch.Tensor, Dict]]: The result
                PyTorch tensor.
        """
        return {k: v.to_torch(device=device) for k, v in self._items}

    @python_scope
    def __setitem__(self, indices, element):
        self._initialize_host_accessors()
        self[indices]._set_entries(element)

    @python_scope
    def __getitem__(self, indices):
        self._initialize_host_accessors()
        # scalar fields do not instantiate SNodeHostAccess by default
        entries = {
            k: v._host_access(self._pad_key(indices))[0] if isinstance(v, ScalarField) else v[indices]
            for k, v in self._items
        }
        entries["__struct_methods"] = self.struct_methods
        return Struct(entries)
|
573
|
+
|
574
|
+
|
575
|
+
class StructType(CompoundType):
    """A user-defined struct type with named members.

    Each member is a primitive dtype (normalized with ``cook_dtype``), a
    ``MatrixType``, or a nested ``StructType``. The special constructor key
    ``"__struct_methods"`` does not declare a member; it supplies the methods
    attached to every ``Struct`` this type produces.
    """

    def __init__(self, **kwargs):
        """Register members in keyword order and build the backing core struct type."""
        self.members = {}
        self.methods = {}
        # [core element type, member name] pairs handed to the type factory.
        layout = []
        for name, member_type in kwargs.items():
            if name == "__struct_methods":
                self.methods = member_type
                continue
            if isinstance(member_type, StructType):
                core_type = member_type.dtype
            elif isinstance(member_type, MatrixType):
                core_type = member_type.tensor_type
            else:
                member_type = cook_dtype(member_type)
                core_type = member_type
            self.members[name] = member_type
            layout.append([core_type, name])
        self.dtype = _ti_core.get_type_factory_instance().get_struct_type(layout)

    def __call__(self, *args, **kwargs):
        """Create an instance of this struct type.

        Members are filled from ``args`` by position. Any member beyond the end
        of ``args`` is looked up by name in ``kwargs`` and defaults to 0.
        """
        values = {}
        for position, (name, member_type) in enumerate(self.members.items()):
            data = args[position] if position < len(args) else kwargs.get(name, 0)
            # A compound member given a scalar cannot be converted by
            # self.cast() below, so expand it through the member type here.
            if isinstance(member_type, CompoundType) and not isinstance(data, (dict, Struct)):
                data = member_type(data)
            values[name] = data

        entries = Struct(values)
        entries._Struct__dtype = self.dtype
        result = self.cast(entries)
        result._Struct__dtype = self.dtype
        return result

    def __instancecheck__(self, instance):
        """Return True when ``instance`` is a ``Struct`` compatible with this type.

        Compatibility requires all of the following:
        - identical member names in the same order;
        - a matching recorded dtype, if the instance carries one;
        - each member value passing a per-kind check (nested struct, matrix
          shape, integer or real scalar).
        """
        if not isinstance(instance, Struct):
            return False
        if list(self.members.keys()) != list(instance._Struct__entries.keys()):
            return False
        recorded_dtype = getattr(instance, "_Struct__dtype", None)
        if recorded_dtype is not None and recorded_dtype != self.dtype:
            return False

        member_values = instance._members
        for index, member_type in enumerate(self.members.values()):
            value = member_values[index]
            if isinstance(member_type, StructType):
                if not isinstance(value, member_type):
                    return False
            elif isinstance(member_type, MatrixType):
                # Only expressions are checked; other matrix values are accepted as-is.
                if isinstance(value, Expr) and (
                    not value.is_tensor() or value.get_shape() != member_type.get_shape()
                ):
                    return False
            else:
                if member_type in primitive_types.integer_types:
                    allowed, python_types = primitive_types.integer_types, (int, np.integer)
                elif member_type in primitive_types.real_types:
                    allowed, python_types = primitive_types.real_types, (float, np.floating)
                else:
                    continue
                if isinstance(value, Expr):
                    if value.is_tensor() or value.is_struct() or value.element_type() not in allowed:
                        return False
                elif not isinstance(value, python_types):
                    return False
        return True

    def from_gstaichi_object(self, func_ret, ret_index=()):
        """Rebuild a ``Struct`` from ``func_ret`` by extracting each member element.

        ``ret_index`` is the index path of this struct inside ``func_ret``.
        Nested compound members recurse with the path extended by their position.
        """
        values = {}
        for index, (name, member_type) in enumerate(self.members.items()):
            member_index = ret_index + (index,)
            if isinstance(member_type, CompoundType):
                values[name] = member_type.from_gstaichi_object(func_ret, member_index)
                continue
            values[name] = expr.Expr(
                _ti_core.make_get_element_expr(
                    func_ret.ptr,
                    member_index,
                    _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                )
            )
        values["__struct_methods"] = self.methods

        result = Struct(values)
        result._Struct__dtype = self.dtype
        return result

    def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
        """Read a struct return value out of ``launch_ctx`` member by member.

        Raises:
            GsTaichiRuntimeTypeError: If a member is neither compound, integer, nor real.
        """
        values = {}
        for index, (name, member_type) in enumerate(self.members.items()):
            member_index = ret_index + (index,)
            if isinstance(member_type, CompoundType):
                values[name] = member_type.from_kernel_struct_ret(launch_ctx, member_index)
            elif member_type in primitive_types.integer_types:
                getter = (
                    launch_ctx.get_struct_ret_int
                    if is_signed(cook_dtype(member_type))
                    else launch_ctx.get_struct_ret_uint
                )
                values[name] = getter(member_index)
            elif member_type in primitive_types.real_types:
                values[name] = launch_ctx.get_struct_ret_float(member_index)
            else:
                raise GsTaichiRuntimeTypeError(f"Invalid return type on index={member_index}")
        values["__struct_methods"] = self.methods

        result = Struct(values)
        result._Struct__dtype = self.dtype
        return result

    def set_kernel_struct_args(self, struct, launch_ctx, ret_index=()):
        """Write every member of ``struct`` into ``launch_ctx`` as a kernel argument.

        ``ret_index`` is the index path of this struct in the argument list; it
        keeps its historical name even though it indexes arguments.

        Raises:
            GsTaichiRuntimeTypeError: If a member is neither compound, integer, nor real.
        """
        # TODO: move this to class Struct after we add dtype to Struct
        for index, (name, member_type) in enumerate(self.members.items()):
            member_index = ret_index + (index,)
            if isinstance(member_type, CompoundType):
                member_type.set_kernel_struct_args(struct[name], launch_ctx, member_index)
            elif member_type in primitive_types.integer_types:
                if is_signed(cook_dtype(member_type)):
                    launch_ctx.set_struct_arg_int(member_index, struct[name])
                else:
                    launch_ctx.set_struct_arg_uint(member_index, struct[name])
            elif member_type in primitive_types.real_types:
                launch_ctx.set_struct_arg_float(member_index, struct[name])
            else:
                raise GsTaichiRuntimeTypeError(f"Invalid argument type on index={member_index}")

    def cast(self, struct):
        """Return a new ``Struct`` whose members are converted to this type's member types.

        In Python scope, scalars become ``int`` or ``float``. Otherwise they are
        converted with ``ops.cast``.

        Raises:
            GsTaichiSyntaxError: If ``struct`` does not have exactly this type's member names.
        """
        source = struct._Struct__entries
        # sanity check members
        if self.members.keys() != source.keys():
            raise GsTaichiSyntaxError("Incompatible arguments for custom struct members!")
        entries = {}
        for name, member_type in self.members.items():
            value = source[name]
            if isinstance(member_type, MatrixType):
                entries[name] = member_type(value)
            elif isinstance(member_type, CompoundType):
                entries[name] = member_type.cast(value)
            elif not in_python_scope():
                entries[name] = ops.cast(value, member_type)
            elif member_type in primitive_types.integer_types:
                entries[name] = int(value)
            else:
                entries[name] = float(value)
        entries["__struct_methods"] = self.methods

        result = Struct(entries)
        result._Struct__dtype = self.dtype
        return result

    def filled_with_scalar(self, value):
        """Return a ``Struct`` of this type with every leaf member set from ``value``."""
        entries = {}
        for name, member_type in self.members.items():
            if isinstance(member_type, MatrixType):
                entries[name] = member_type(value)
            elif isinstance(member_type, CompoundType):
                entries[name] = member_type.filled_with_scalar(value)
            else:
                entries[name] = value
        entries["__struct_methods"] = self.methods

        result = Struct(entries)
        result._Struct__dtype = self.dtype
        return result

    def field(self, **kwargs):
        """Delegate to ``Struct.field`` with this type's members and methods."""
        return Struct.field(self.members, self.methods, **kwargs)

    def __str__(self):
        """Python scope struct type print support."""
        member_text = ", ".join([f"{name!s}={member_type!s}" for name, member_type in self.members.items()])
        return f"<ti.StructType {member_text + f', struct_methods={self.methods}'}>"
|
761
|
+
|
762
|
+
|
763
|
+
def dataclass(cls):
    """Converts a class with field annotations and methods into a gstaichi struct type.

    This will return a normal custom struct type, with the functions added to it.
    Struct fields can be generated in the normal way from the struct type.
    Functions in the class can be run on the struct instance.

    This class decorator inspects the class for annotations and methods and
        1. Sets the annotations as fields for the struct
        2. Attaches the methods to the struct type

    The decorated class itself is not modified.

    Example::

        >>> @ti.dataclass
        ... class Sphere:
        ...     center: vec3
        ...     radius: ti.f32
        ...
        ...     @ti.func
        ...     def area(self):
        ...         return 4 * 3.14 * self.radius * self.radius
        ...
        >>> my_spheres = Sphere.field(shape=(n, ))
        >>> my_spheres[2].area()

    Args:
        cls (Class): the class with annotations and methods to convert to a struct

    Returns:
        A gstaichi struct with the annotations as fields
        and methods from the class attached.

    Raises:
        GsTaichiSyntaxError: If an annotated name also resolves as a class
            attribute, e.g. ``radius: ti.f32 = 1.0``.
    """
    # Copy the annotations. The object returned by getattr is the class's own
    # __annotations__ dict (or, on older Pythons, possibly a base class's), and
    # adding "__struct_methods" to it would corrupt the user's class.
    fields = dict(getattr(cls, "__annotations__", {}))
    # raise error if there are default values
    for name in fields:
        if hasattr(cls, name):
            raise GsTaichiSyntaxError("Default value in @dataclass is not supported.")
    # Collect the non-dunder callables to be attached to the struct type.
    methods = {}
    for attribute in dir(cls):
        if attribute.startswith("__"):
            continue
        member = getattr(cls, attribute)
        if callable(member):
            methods[attribute] = member
    fields["__struct_methods"] = methods
    return StructType(**fields)
|
808
|
+
|
809
|
+
|
810
|
+
# Public names exported by ``from <module> import *``.
__all__ = [
    "Struct",
    "StructField",
    "dataclass",
]
|