gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/field.py
ADDED
@@ -0,0 +1,428 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import gstaichi.lang
|
4
|
+
from gstaichi._lib import core as _ti_core
|
5
|
+
from gstaichi._logging import warn
|
6
|
+
from gstaichi.lang import impl
|
7
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError
|
8
|
+
from gstaichi.lang.util import (
|
9
|
+
in_python_scope,
|
10
|
+
python_scope,
|
11
|
+
to_numpy_type,
|
12
|
+
to_pytorch_type,
|
13
|
+
)
|
14
|
+
|
15
|
+
|
16
|
+
class Field:
    """GsTaichi field class.

    A field is constructed by a list of field members.
    For example, a scalar field has 1 field member, while a 3x3 matrix field has 9 field members.
    A field member is a Python Expr wrapping a C++ FieldExpression.

    Args:
        vars (List[Expr]): Field members.
    """

    def __init__(self, _vars):
        # Every field member must be truthy (no ``None`` placeholders).
        assert all(_vars)
        self.vars = _vars
        # Built on first Python-scope element access by ``_initialize_host_accessors``.
        self.host_accessors = None
        self.grad = None
        self.dual = None

    @property
    def snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return self._snode

    @property
    def _snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return gstaichi.lang.snode.SNode(self.vars[0].ptr.snode())

    @property
    def shape(self):
        """Gets field shape.

        Returns:
            Tuple[Int]: Field shape.
        """
        return self._snode.shape

    @property
    def dtype(self):
        """Gets data type of each individual value.

        Returns:
            DataType: Data type of each individual value.
        """
        return self._snode._dtype

    @property
    def _name(self):
        """Gets field name.

        Returns:
            str: Field name.
        """
        return self._snode._name

    def parent(self, n=1):
        """Gets an ancestor of the representative SNode in the SNode tree.

        Args:
            n (int): the number of levels going up from the representative SNode.

        Returns:
            SNode: The n-th parent of the representative SNode.
        """
        return self.snode.parent(n)

    def _get_field_members(self):
        """Gets field members.

        Returns:
            List[Expr]: Field members.
        """
        return self.vars

    def _loop_range(self):
        """Gets SNode of representative field member for loop range info.

        Returns:
            gstaichi_python.SNode: SNode of representative (first) field member.
        """
        return self.vars[0].ptr.snode()

    def _set_grad(self, grad):
        """Sets corresponding grad field (reverse mode).

        Args:
            grad (Field): Corresponding grad field.
        """
        self.grad = grad

    def _set_dual(self, dual):
        """Sets corresponding dual field (forward mode).

        Args:
            dual (Field): Corresponding dual field.
        """
        self.dual = dual

    @python_scope
    def fill(self, val):
        """Fills `self` with a specific value.

        Args:
            val (Union[int, float]): Value to fill.
        """
        raise NotImplementedError()

    @python_scope
    def to_numpy(self, dtype=None):
        """Converts `self` to a numpy array.

        Args:
            dtype (DataType, optional): The desired data type of returned numpy array.

        Returns:
            numpy.ndarray: The result numpy array.
        """
        raise NotImplementedError()

    @python_scope
    def to_torch(self, device=None):
        """Converts `self` to a torch tensor.

        Args:
            device (torch.device, optional): The desired device of returned tensor.

        Returns:
            torch.tensor: The result torch tensor.
        """
        raise NotImplementedError()

    @python_scope
    def from_numpy(self, arr):
        """Loads all elements from a numpy array.

        The shape of the numpy array needs to be the same as `self`.

        Args:
            arr (numpy.ndarray): The source numpy array.
        """
        raise NotImplementedError()

    @python_scope
    def _from_external_arr(self, arr):
        """Copies an external (numpy/torch) array into `self`; implemented by subclasses."""
        raise NotImplementedError()

    @python_scope
    def from_torch(self, arr):
        """Loads all elements from a torch tensor.

        The shape of the torch tensor needs to be the same as `self`.

        Args:
            arr (torch.tensor): The source torch tensor.
        """
        self._from_external_arr(arr.contiguous())

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another field.

        The shape of the other field needs to be the same as `self`.

        Args:
            other (Field): The source field.

        Raises:
            TypeError: If `other` is not a Field.
            ValueError: If the shapes of the two fields differ.
        """
        if not isinstance(other, Field):
            raise TypeError("Cannot copy from a non-field object")
        if self.shape != other.shape:
            raise ValueError(f"ti.field shape {self.shape} does not match" f" the source field shape {other.shape}")
        from gstaichi._kernels import tensor_to_tensor  # pylint: disable=C0415

        tensor_to_tensor(self, other)

    @python_scope
    def __setitem__(self, key, value):
        """Sets field element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.
            value (element type): Value to set.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets field element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.

        Returns:
            element type: Value retrieved.
        """
        raise NotImplementedError()

    def __str__(self):
        """Returns the field contents (via `to_numpy`) as a string.

        Inside a kernel this returns `repr(self)` instead. An incomplete field
        definition yields a placeholder message.
        """
        if gstaichi.lang.impl.inside_kernel():
            return self.__repr__()  # make pybind11 happy, see Matrix.__str__
        if self._snode.ptr is None:
            return "<Field: Definition of this field is incomplete>"
        return str(self.to_numpy())

    def _pad_key(self, key):
        """Normalizes a Python-scope index into a fixed-length coordinate tuple.

        Args:
            key (Union[List[int], Tuple[int, ...], int, None]): Coordinates of the
                field element. ``None`` is treated as an empty index, which is only
                valid for a 0-D field. A non-sequence value is treated as a
                1-element index.

        Returns:
            Tuple: ``key`` as a tuple, padded with zeros up to
            ``_ti_core.get_max_num_indices()`` entries.

        Raises:
            AssertionError: If the number of indices differs from the number of
                field dimensions (e.g. a partial index used like a slice).
        """
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key,)
        # List keys are accepted above; convert them so the tuple concatenation
        # below works (``list + tuple`` raises TypeError).
        key = tuple(key)

        if len(key) != len(self.shape):
            raise AssertionError("Slicing is not supported on ti.field")

        return key + ((0,) * (_ti_core.get_max_num_indices() - len(key)))

    def _initialize_host_accessors(self):
        """Materializes the runtime and builds one host accessor per field member.

        Does nothing if the accessors have already been built.
        """
        if self.host_accessors:
            return
        gstaichi.lang.impl.get_runtime().materialize()
        self.host_accessors = [SNodeHostAccessor(e.ptr.snode()) for e in self.vars]

    def _host_access(self, key):
        """Pairs each host accessor with `key`.

        Assumes `_initialize_host_accessors` has already run; otherwise
        `self.host_accessors` is still ``None``.

        Returns:
            List[SNodeHostAccess]: One entry per field member.
        """
        return [SNodeHostAccess(e, key) for e in self.host_accessors]

    def __iter__(self):
        """Python-scope iteration is unsupported; struct-for only works in GsTaichi scope."""
        raise NotImplementedError("Struct for is only available in GsTaichi scope.")
|
248
|
+
|
249
|
+
|
250
|
+
class ScalarField(Field):
    """GsTaichi scalar field with SNode implementation.

    Args:
        var (Expr): Field member.
    """

    def __init__(self, var):
        super().__init__([var])

    def fill(self, val):
        """Fills this scalar field with a specified value.

        Works in both scopes: in Python scope it calls
        ``gstaichi._kernels.fill_field``, otherwise
        ``gstaichi._funcs.field_fill_gstaichi_scope``.

        Args:
            val (Union[int, float]): Value to fill.
        """
        if in_python_scope():
            from gstaichi._kernels import fill_field  # pylint: disable=C0415

            fill_field(self, val)
        else:
            from gstaichi._funcs import (  # pylint: disable=C0415
                field_fill_gstaichi_scope,  # pylint: disable=C0415
            )

            field_fill_gstaichi_scope(self, val)

    @python_scope
    def to_numpy(self, dtype=None):
        """Converts this field to a `numpy.ndarray`.

        Args:
            dtype (optional): dtype of the returned array. Defaults to the numpy
                equivalent of ``self.dtype``.

        Returns:
            numpy.ndarray: A new array of shape ``self.shape``.
        """
        # The destination array is zero-initialized, so inactive items of a
        # dynamic SNode end up as zeros. Warn the user about that.
        if self.parent()._snode.ptr.type == _ti_core.SNodeType.dynamic:
            warn(
                "You are trying to convert a dynamic snode to a numpy array, be aware that inactive items in the snode will be converted to zeros in the resulting array."
            )
        if dtype is None:
            dtype = to_numpy_type(self.dtype)
        import numpy as np  # pylint: disable=C0415

        arr = np.zeros(shape=self.shape, dtype=dtype)
        from gstaichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, arr)
        gstaichi.lang.runtime_ops.sync()
        return arr

    @python_scope
    def to_torch(self, device=None):
        """Converts this field to a `torch.tensor`.

        Args:
            device (torch.device, optional): Device on which the tensor is allocated.

        Returns:
            torch.tensor: A new tensor of shape ``self.shape`` with the torch
            equivalent of ``self.dtype``.
        """
        import torch  # pylint: disable=C0415

        # pylint: disable=E1101
        arr = torch.zeros(size=self.shape, dtype=to_pytorch_type(self.dtype), device=device)
        from gstaichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, arr)
        gstaichi.lang.runtime_ops.sync()
        return arr

    @python_scope
    def _from_external_arr(self, arr):
        """Copies an external array (numpy array or torch tensor) into this field.

        Args:
            arr: Source array. Its ``shape`` must equal ``self.shape``.

        Raises:
            ValueError: If the array shape does not match the field shape.
        """
        # A tuple comparison checks the number of dimensions and every extent at once.
        if tuple(arr.shape) != tuple(self.shape):
            raise ValueError(f"ti.field shape {self.shape} does not match" f" the numpy array shape {arr.shape}")
        from gstaichi._kernels import ext_arr_to_tensor  # pylint: disable=C0415

        ext_arr_to_tensor(arr, self)
        gstaichi.lang.runtime_ops.sync()

    @python_scope
    def from_numpy(self, arr):
        """Copies the data from a `numpy.ndarray` into this field.

        Non C-contiguous inputs are copied into a contiguous array first.

        Args:
            arr (numpy.ndarray): Source array with the same shape as this field.
        """
        if not arr.flags.c_contiguous:
            import numpy as np  # pylint: disable=C0415

            arr = np.ascontiguousarray(arr)
        self._from_external_arr(arr)

    @python_scope
    def __setitem__(self, key, value):
        """Writes one element of this field in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.
            value: Value to store.
        """
        self._initialize_host_accessors()
        self.host_accessors[0].setter(value, *self._pad_key(key))

    @python_scope
    def __getitem__(self, key):
        """Reads one element of this field in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.

        Returns:
            The value returned by the host accessor's getter.

        Raises:
            TypeError: If any index is not an integer, e.g. ``x[0, :]``
                (slicing a ti.field is not supported).
        """
        self._initialize_host_accessors()
        # Check for potential slicing behaviour
        # for instance: x[0, :]
        padded_key = self._pad_key(key)
        import numpy as np  # pylint: disable=C0415

        for index in padded_key:
            if not isinstance(index, (int, np.integer)):
                raise TypeError(
                    f"Detected illegal element of type: {type(index)}. "
                    "Please be aware that slicing a ti.field is not supported so far."
                )
        return self.host_accessors[0].getter(*padded_key)

    def __repr__(self):
        # make interactive shell happy, prevent materialization
        return "<ti.field>"
|
349
|
+
|
350
|
+
|
351
|
+
class SNodeHostAccessor:
    """Python-scope reader/writer for a single C++ SNode.

    Chooses float or int/uint read/write functions from the SNode's data type
    and exposes them as ``getter(*key)`` and ``setter(value, *key)``. Both take
    a fully padded key (length ``_ti_core.get_max_num_indices()``).

    Args:
        snode: The C++ SNode (``gstaichi_python.SNode``) to access.
    """

    def __init__(self, snode):
        if _ti_core.is_real(snode.data_type()):
            write_func = snode.write_float
            read_func = snode.read_float
        else:

            def write_func(key, value):
                # Non-negative values use the unsigned writer, negative ones the
                # signed writer. NOTE(review): presumably this lets unsigned
                # fields accept values above the signed range; confirm.
                if value >= 0:
                    snode.write_uint(key, value)
                else:
                    snode.write_int(key, value)

            if _ti_core.is_signed(snode.data_type()):
                read_func = snode.read_int
            else:
                read_func = snode.read_uint

        def getter(*key):
            assert len(key) == _ti_core.get_max_num_indices()
            return read_func(key)

        def setter(value, *key):
            assert len(key) == _ti_core.get_max_num_indices()
            write_func(key, value)
            runtime = impl.get_runtime()
            tape = runtime.target_tape
            # Under an active tape with a grad checker (and gradients not
            # replaced), refuse writes to any field being checked, and record
            # the write on the tape.
            if tape and tape.grad_checker and not runtime.grad_replaced:
                for checked in tape.grad_checker.to_check:
                    # Explicit raise instead of ``assert`` so the guard survives ``python -O``.
                    if snode == checked.snode.ptr:
                        raise AssertionError("Overwritten is prohibitive when doing grad check.")
                tape.insert(write_func, (key, value))

        self.getter = getter
        self.setter = setter
|
388
|
+
|
389
|
+
|
390
|
+
class SNodeHostAccess:
    """Bundles a host accessor together with the key it applies to.

    Args:
        accessor (SNodeHostAccessor): Accessor for one field member.
        key: Key as passed to ``Field._host_access``.
    """

    def __init__(self, accessor, key):
        self.accessor, self.key = accessor, key
|
394
|
+
|
395
|
+
|
396
|
+
class BitpackedFields:
    """GsTaichi bitpacked fields, where fields with quantized types are packed together.

    Args:
        max_num_bits (int): Maximum number of bits all fields inside can occupy in total. Only 32 or 64 is allowed.
    """

    def __init__(self, max_num_bits):
        # (member pointer, result of the builder's ``add_member``) pairs, in placement order.
        self.fields = []
        self.bit_struct_type_builder = _ti_core.BitStructTypeBuilder(max_num_bits)

    def place(self, *args, shared_exponent=False):
        """Places a list of fields with quantized types inside.

        All arguments are validated before anything is added. A call rejected by
        these checks therefore leaves ``self.fields`` and the builder untouched.

        Args:
            *args (List[Field]): A list of fields with quantized types to place.
            shared_exponent (bool): Whether the fields have a shared exponent.

        Raises:
            AssertionError: If any argument is not a ``Field``.
            GsTaichiSyntaxError: If ``shared_exponent`` is True and fewer than 2
                field members are given in total.
        """
        # Validate and collect all members up front, so errors cannot leave
        # a partially placed bit struct behind.
        members = []
        for arg in args:
            assert isinstance(arg, Field)
            members.extend(arg._get_field_members())
        if shared_exponent and len(members) <= 1:
            raise GsTaichiSyntaxError("At least 2 fields need to be placed when shared_exponent=True")

        if shared_exponent:
            self.bit_struct_type_builder.begin_placing_shared_exponent()
        for var in members:
            self.fields.append((var.ptr, self.bit_struct_type_builder.add_member(var.ptr.get_dt())))
        if shared_exponent:
            self.bit_struct_type_builder.end_placing_shared_exponent()
|
426
|
+
|
427
|
+
|
428
|
+
# Public names exported by this module (e.g. for ``import *``).
__all__ = ["BitpackedFields", "Field", "ScalarField"]
|