gstaichi 0.1.25.dev0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +9 -0
- gstaichi/__init__.py +40 -0
- gstaichi/__main__.py +5 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
- gstaichi/_lib/runtime/runtime_x64.bc +0 -0
- gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- gstaichi/_lib/utils.py +249 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_main.py +545 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +103 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +199 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +189 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/argpack.py +411 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1318 -0
- gstaichi/lang/ast/ast_transformer_utils.py +341 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +466 -0
- gstaichi/lang/impl.py +1241 -0
- gstaichi/lang/kernel_arguments.py +157 -0
- gstaichi/lang/kernel_impl.py +1382 -0
- gstaichi/lang/matrix.py +1881 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +778 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +855 -0
- gstaichi/lang/util.py +381 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +47 -0
- gstaichi/types/compound_types.py +90 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +147 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +13 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
- gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
- gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
gstaichi/lang/mesh.py
ADDED
@@ -0,0 +1,687 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import json
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from gstaichi import lang
|
8
|
+
from gstaichi._lib import core as _ti_core
|
9
|
+
from gstaichi.lang import impl
|
10
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError
|
11
|
+
from gstaichi.lang.field import Field, ScalarField
|
12
|
+
from gstaichi.lang.matrix import Matrix, MatrixField
|
13
|
+
from gstaichi.lang.struct import StructField
|
14
|
+
from gstaichi.lang.util import python_scope
|
15
|
+
from gstaichi.types import u16, u32
|
16
|
+
from gstaichi.types.compound_types import CompoundType
|
17
|
+
from gstaichi.types.enums import Layout
|
18
|
+
|
19
|
+
# Native mesh enum types, re-exported from the C++ core module.
(
    MeshTopology,
    MeshElementType,
    MeshRelationType,
    ConvType,
) = (
    _ti_core.MeshTopology,
    _ti_core.MeshElementType,
    _ti_core.MeshRelationType,
    _ti_core.ConvType,
)

# Native helper functions, re-exported from the C++ core module.
(
    element_order,
    from_end_element_order,
    to_end_element_order,
    relation_by_orders,
    inverse_relation,
    element_type_name,
) = (
    _ti_core.element_order,
    _ti_core.from_end_element_order,
    _ti_core.to_end_element_order,
    _ti_core.relation_by_orders,
    _ti_core.inverse_relation,
    _ti_core.element_type_name,
)
|
29
|
+
|
30
|
+
|
31
|
+
class MeshAttrType:
    """Plain record describing one mesh attribute declared through ``place``.

    Attributes:
        name: the attribute name (a key of the ``members`` dict given to ``place``).
        dtype: the attribute's element type (primitive or compound type).
        reorder: whether the attribute's storage is reordered for mesh-for access.
        needs_grad: whether a gradient field is allocated for the attribute.
    """

    def __init__(self, name, dtype, reorder, needs_grad):
        # Values are stored verbatim; no validation happens here.
        self.name, self.dtype = name, dtype
        self.reorder, self.needs_grad = reorder, needs_grad
|
37
|
+
|
38
|
+
|
39
|
+
class MeshReorderedScalarFieldProxy(ScalarField):
    """Python-scope view of a reordered scalar mesh attribute.

    Indices given by the caller are translated through ``g2r_field`` (the
    global-to-reordered index mapping) before they reach the underlying storage.
    The proxy borrows ``vars``, ``host_accessors`` and ``grad`` from the wrapped
    field instead of running ``ScalarField.__init__``.
    """

    def __init__(
        self,
        field: ScalarField,
        mesh_ptr: _ti_core.MeshPtr,
        element_type: MeshElementType,
        g2r_field: ScalarField,
    ):
        # Share the wrapped field's storage handles.
        self.vars = field.vars
        self.host_accessors = field.host_accessors
        self.grad = field.grad

        # Mesh context used for index translation.
        self.mesh_ptr = mesh_ptr
        self.element_type = element_type
        self.g2r_field = g2r_field

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessors()
        reordered = self._pad_key(self.g2r_field[key])
        self.host_accessors[0].setter(value, *reordered)

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessors()
        reordered = self._pad_key(self.g2r_field[key])
        return self.host_accessors[0].getter(*reordered)
|
66
|
+
|
67
|
+
|
68
|
+
class MeshReorderedMatrixFieldProxy(MatrixField):
    """Python-scope view of a reordered vector/matrix mesh attribute.

    Indices given by the caller are translated through ``g2r_field`` (the
    global-to-reordered index mapping) before they reach the underlying storage.
    The proxy borrows the wrapped field's storage handles and shape metadata
    instead of running ``MatrixField.__init__``.
    """

    def __init__(
        self,
        field: MatrixField,
        mesh_ptr: _ti_core.MeshPtr,
        element_type: MeshElementType,
        g2r_field: ScalarField,
    ):
        # Share the wrapped field's storage handles and shape metadata.
        self.vars = field.vars
        self.host_accessors = field.host_accessors
        self.grad = field.grad
        self.n = field.n
        self.m = field.m
        self.ndim = field.ndim
        self.ptr = field.ptr

        # Mesh context used for index translation.
        self.mesh_ptr = mesh_ptr
        self.element_type = element_type
        self.g2r_field = g2r_field

    @python_scope
    def __setitem__(self, key, value):
        """Write ``value`` into the element addressed by ``key``."""
        self._initialize_host_accessors()
        # ``self[key]`` performs the g2r translation; the returned Matrix is built
        # from host accessors, so ``_set_entries`` presumably writes through to
        # the field storage.
        self[key]._set_entries(value)

    @python_scope
    def __getitem__(self, key):
        """Return the entry addressed by ``key`` as a ``Matrix`` of host accessors.

        Vector attributes (``ndim == 1``) yield an n-vector. Matrix attributes
        yield an n x m matrix. Previously they were returned as a flat
        (n*m)-vector, which also broke ``__setitem__`` for matrix values.
        """
        self._initialize_host_accessors()
        key = self._pad_key(self.g2r_field[key])
        entries = self._host_access(key)
        if self.ndim == 1:
            return Matrix(entries)
        # ``_host_access`` returns one accessor per component as a flat list.
        # Regroup it row by row; assumes row-major (i * m + j) order, as
        # MatrixField uses — TODO confirm against MatrixField.__getitem__.
        cols = self.m
        return Matrix([entries[row * cols : (row + 1) * cols] for row in range(self.n)])
|
99
|
+
|
100
|
+
|
101
|
+
class MeshElementField:
    """Attribute storage for one element type (verts, edges, faces or cells) of a mesh.

    Attributes declared through :meth:`place` are stored as gstaichi fields in
    ``field_dict``, with their descriptors in ``attr_dict``. Each one is exposed
    as a read-only property, e.g. ``mesh.verts.pos``. Reordered attributes are
    returned wrapped in proxies that translate indices through ``g2r_field``.

    NOTE: the attribute properties are installed on the ``MeshElementField``
    class itself, so a name declared on one element is visible on every other
    instance too. Reading it on an instance that did not declare it raises
    ``KeyError``.
    """

    # Names reserved for mesh relation access (``v.verts``, ``f.edges``, ...).
    _RESERVED_ATTR_NAMES = frozenset({"verts", "edges", "faces", "cells"})
    # Instance attributes assigned by this class. An attribute property with one
    # of these names would break those assignments (the properties have no setter).
    _INSTANCE_ATTR_NAMES = frozenset({"mesh", "_type", "attr_dict", "field_dict", "g2r_field", "getter_dict"})
    # Names of attribute properties installed on this class so far. It is shared
    # by all instances because the properties themselves are class-level.
    _attr_property_names = set()

    def __init__(self, mesh_instance, _type, attr_dict, field_dict, g2r_field):
        self.mesh = mesh_instance
        self._type = _type
        self.attr_dict = attr_dict
        self.field_dict = field_dict
        self.g2r_field = g2r_field

        self._register_fields()

    def place(self, members, reorder=False, needs_grad=False, layout=Layout.SOA):
        """Declares mesh attributes for the mesh element in current mesh.

        Args:
            members (Dict[str, Union[PrimitiveType, MatrixType]]): \
                names and types for element attributes.
            reorder: True if reorders the internal memory for coalesced data access within mesh-for loop.
            needs_grad: True if needs to record grad.
            layout: ti.Layout.AoS/ti.Layout.SoA

        Raises:
            GsTaichiSyntaxError: if a name is reserved, is already declared on this
                element, or would shadow a member of ``MeshElementField``. Every
                name is validated before any state changes.

        Example::
            >>> import meshgstaichi_patcher as Patcher
            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> mesh = Patcher.load_mesh("bunny.obj", relations=['FV'])
            >>> mesh.faces.place({'area' : ti.f32}) # declares a mesh attribute `area` for each face element.
            >>> mesh.verts.place({'pos' : vec3}, reorder=True) # declares a mesh attribute `pos` for each vertex element, and reorder it in memory.
        """
        # Validate all names first, so a bad name cannot leave earlier members
        # half-declared (created in the dicts but never placed).
        for key in members:
            self._check_attr_name(key)

        for key, dtype in members.items():
            self.attr_dict[key] = MeshAttrType(key, dtype, reorder, needs_grad)
            if isinstance(dtype, CompoundType):
                self.field_dict[key] = dtype.field(shape=None, needs_grad=needs_grad)
            else:
                self.field_dict[key] = impl.field(dtype, shape=None, needs_grad=needs_grad)

        size = _ti_core.get_num_elements(self.mesh.mesh_ptr, self._type)
        if layout == Layout.SOA:
            # Structure of arrays: one dense SNode per attribute and per gradient.
            for key in members:
                impl.root.dense(impl.axes(0), size).place(self.field_dict[key])
                if self.attr_dict[key].needs_grad:
                    impl.root.dense(impl.axes(0), size).place(self.field_dict[key].grad)
        elif members:
            # Array of structures: all attributes share one dense SNode, and all
            # gradients share a second one.
            impl.root.dense(impl.axes(0), size).place(*[self.field_dict[key] for key in members])
            grads = [self.field_dict[key].grad for key in members if self.attr_dict[key].needs_grad]
            if grads:
                impl.root.dense(impl.axes(0), size).place(*grads)

        # Expose each new attribute as a property, e.g. ``mesh.verts.pos``.
        for key in members:
            self._install_attr_property(key)

    def _check_attr_name(self, key):
        """Raise GsTaichiSyntaxError if ``key`` cannot be used as an attribute name here."""
        if key in self._RESERVED_ATTR_NAMES:
            raise GsTaichiSyntaxError(
                f"'{key}' cannot use as attribute name. It has been reserved as MeshGsTaichi's keyword."
            )
        if key in self.attr_dict:
            raise GsTaichiSyntaxError(f"'{key}' has already use as attribute name.")
        # Installing a class-level property over one of this class's own members
        # (e.g. ``mesh``, ``keys``, ``place``) would break every instance.
        # Properties previously installed for other attributes may be reused.
        shadows_member = key in self._INSTANCE_ATTR_NAMES or (
            hasattr(MeshElementField, key) and key not in MeshElementField._attr_property_names
        )
        if shadows_member:
            raise GsTaichiSyntaxError(
                f"'{key}' cannot use as attribute name. It conflicts with a member of MeshElementField."
            )

    @property
    def keys(self):
        """Names of all attributes declared on this element (a new list)."""
        return list(self.field_dict.keys())

    @property
    def _members(self):
        """Fields of all attributes declared on this element (a new list)."""
        return list(self.field_dict.values())

    @property
    def _items(self):
        """Live ``(name, field)`` view over ``field_dict``."""
        return self.field_dict.items()

    @staticmethod
    def _make_getter(key):
        """Build the property getter that exposes attribute ``key``."""

        def getter(self):
            """Get an entry from custom struct by name."""
            # The wrapped object is built lazily on first access and cached.
            if key not in self.getter_dict:
                field = self.field_dict[key]
                if not self.attr_dict[key].reorder:
                    self.getter_dict[key] = field
                elif isinstance(field, ScalarField):
                    self.getter_dict[key] = MeshReorderedScalarFieldProxy(
                        field,
                        self.mesh.mesh_ptr,
                        self._type,
                        self.g2r_field,
                    )
                elif isinstance(field, MatrixField):
                    self.getter_dict[key] = MeshReorderedMatrixFieldProxy(
                        field,
                        self.mesh.mesh_ptr,
                        self._type,
                        self.g2r_field,
                    )
                else:
                    # Previously this surfaced as an opaque KeyError on getter_dict.
                    raise RuntimeError(
                        f"MeshGsTaichi does not support reordering attribute '{key}' "
                        f"of type {type(field).__name__}"
                    )
            return self.getter_dict[key]

        return getter

    @staticmethod
    def _install_attr_property(key):
        """Install (or replace) the class-level property exposing attribute ``key``."""
        setattr(MeshElementField, key, property(fget=MeshElementField._make_getter(key)))
        MeshElementField._attr_property_names.add(key)

    def _register_fields(self):
        """Reset the accessor cache and expose every attribute in ``field_dict`` as a property."""
        self.getter_dict = {}
        for k in self.keys:
            self._install_attr_property(k)

    def _get_field_members(self):
        """Concatenate ``_get_field_members()`` of every attribute field."""
        field_members = []
        for m in self._members:
            assert isinstance(m, Field)
            field_members += m._get_field_members()
        return field_members

    def _initialize_host_accessors(self):
        """Initialize host accessors of every attribute field."""
        for v in self._members:
            v._initialize_host_accessors()

    def get_member_field(self, key):
        """Return the raw (never proxied) field of attribute ``key``."""
        return self.field_dict[key]

    @python_scope
    def __len__(self):
        """Number of elements of this type in the mesh, as reported by the core."""
        return _ti_core.get_num_elements(self.mesh.mesh_ptr, self._type)
|
228
|
+
|
229
|
+
|
230
|
+
class MeshElement:
    """(Deprecated) Builder-side description of one mesh element type.

    Attribute declarations are collected by :meth:`place`. :meth:`build` turns
    them into placed fields wrapped in a ``MeshElementField``.
    """

    def __init__(self, _type, builder):
        self.builder = builder
        self._type = _type
        self.layout = Layout.SOA
        self.attr_dict = {}

    def _SOA(self, soa=True):  # AOS/SOA
        if soa:
            self.layout = Layout.SOA
        else:
            self.layout = Layout.AOS

    def _AOS(self, aos=True):
        if aos:
            self.layout = Layout.AOS
        else:
            self.layout = Layout.SOA

    # (Deprecated) Write-only flag: assigning True selects SOA (structure of arrays) layout.
    SOA = property(fset=_SOA)
    # (Deprecated) Write-only flag: assigning True selects AOS (array of structures) layout.
    AOS = property(fset=_AOS)

    def place(
        self,
        members,
        reorder=False,
        needs_grad=False,
    ):
        """(Deprecated) Declares mesh attributes for the mesh element in current mesh builder.

        Args:
            members (Dict[str, Union[PrimitiveType, MatrixType]]): \
                names and types for element attributes.
            reorder: True if reorders the internal memory for coalesced data access within mesh-for loop.
            needs_grad: True if needs to record grad.

        Example::
            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> mesh = ti.TriMesh()
            >>> mesh.faces.place({'area' : ti.f32}) # declares a mesh attribute `area` for each face element.
            >>> mesh.verts.place({'pos' : vec3}, reorder=True) # declares a mesh attribute `pos` for each vertex element, and reorder it in memory.
        """
        reserved = {"verts", "edges", "faces", "cells"}
        for name, dtype in members.items():
            if name in reserved:
                raise GsTaichiSyntaxError(
                    f"'{name}' cannot use as attribute name. It has been reserved as MeshGsTaichi's keyword."
                )
            # A repeated name simply replaces the earlier declaration.
            self.attr_dict[name] = MeshAttrType(name, dtype, reorder, needs_grad)

    def build(self, mesh_instance, size, g2r_field):
        """Create and place a field for every declared attribute, then wrap them.

        Args:
            mesh_instance: the ``MeshInstance`` the element belongs to.
            size: number of elements; the extent of each dense SNode.
            g2r_field: index-mapping field handed to the resulting ``MeshElementField``.
        """

        def make_field(attr):
            if isinstance(attr.dtype, CompoundType):
                return attr.dtype.field(shape=None, needs_grad=attr.needs_grad)
            return impl.field(attr.dtype, shape=None, needs_grad=attr.needs_grad)

        def new_dense():
            return impl.root.dense(impl.axes(0), size)

        fields_by_name = {name: make_field(attr) for name, attr in self.attr_dict.items()}

        if self.layout == Layout.SOA:
            # One dense SNode per attribute and per gradient.
            for name, fld in fields_by_name.items():
                new_dense().place(fld)
                if self.attr_dict[name].needs_grad:
                    new_dense().place(fld.grad)
        elif fields_by_name:
            # Shared dense SNode for all attributes, and another for all gradients.
            new_dense().place(*fields_by_name.values())
            grad_fields = [fld.grad for name, fld in fields_by_name.items() if self.attr_dict[name].needs_grad]
            if grad_fields:
                new_dense().place(*grad_fields)

        return MeshElementField(mesh_instance, self._type, self.attr_dict, fields_by_name, g2r_field)
|
301
|
+
|
302
|
+
|
303
|
+
# Define the instance of the Mesh Type, stores the field (type and data) info
|
304
|
+
class MeshInstance:
    """A concrete mesh: a native mesh handle plus per-element attribute containers.

    ``verts``/``edges``/``faces``/``cells`` start out as empty
    ``MeshElementField`` objects. The setters below forward patch layout,
    index mappings and relation tables to the native mesh behind ``mesh_ptr``.
    """

    def __init__(self):
        self.mesh_ptr = _ti_core.create_mesh()
        self.relation_set = set()

        def empty_element(element_type):
            return MeshElementField(self, element_type, {}, {}, {})

        self.verts = empty_element(MeshElementType.Vertex)
        self.edges = empty_element(MeshElementType.Edge)
        self.faces = empty_element(MeshElementType.Face)
        self.cells = empty_element(MeshElementType.Cell)

    @staticmethod
    def _snode_of(field):
        # The native setters take the SNode behind the field's first variable.
        return field.vars[0].ptr.snode()

    @staticmethod
    def _debug_info():
        # Source location of the code currently being compiled, for error reporting.
        return _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())

    def get_position_as_numpy(self):
        """Return the vertex positions attached to this mesh by its builder.

        Returns:
            numpy array of vertex coordinates.

        Raises:
            GsTaichiSyntaxError: if no position data was attached.
        """
        if not hasattr(self, "_vert_position"):
            raise GsTaichiSyntaxError("Position info is not in the file.")
        return self._vert_position

    def set_owned_offset(self, element_type: MeshElementType, owned_offset: ScalarField):
        _ti_core.set_owned_offset(self.mesh_ptr, element_type, self._snode_of(owned_offset))

    def set_total_offset(self, element_type: MeshElementType, total_offset: ScalarField):
        _ti_core.set_total_offset(self.mesh_ptr, element_type, self._snode_of(total_offset))

    def set_index_mapping(self, element_type: MeshElementType, conv_type: ConvType, mapping: ScalarField):
        _ti_core.set_index_mapping(self.mesh_ptr, element_type, conv_type, self._snode_of(mapping))

    def set_num_patches(self, num_patches: int):
        _ti_core.set_num_patches(self.mesh_ptr, num_patches)

    def set_patch_max_element_num(self, element_type: MeshElementType, max_element_num: int):
        _ti_core.set_patch_max_element_num(self.mesh_ptr, element_type, max_element_num)

    def set_relation_fixed(self, rel_type: MeshRelationType, value: ScalarField):
        # Record the relation so callers can tell which relations were loaded.
        self.relation_set.add(rel_type)
        _ti_core.set_relation_fixed(self.mesh_ptr, rel_type, self._snode_of(value))

    def set_relation_dynamic(
        self,
        rel_type: MeshRelationType,
        value: ScalarField,
        patch_offset: ScalarField,
        offset: ScalarField,
    ):
        self.relation_set.add(rel_type)
        value_snode = self._snode_of(value)
        patch_offset_snode = self._snode_of(patch_offset)
        offset_snode = self._snode_of(offset)
        _ti_core.set_relation_dynamic(self.mesh_ptr, rel_type, value_snode, patch_offset_snode, offset_snode)

    def add_mesh_attribute(self, element_type, snode, reorder_type):
        _ti_core.add_mesh_attribute(self.mesh_ptr, element_type, snode, reorder_type)

    def get_relation_size(self, from_index, to_element_type):
        return _ti_core.get_relation_size(self.mesh_ptr, from_index.ptr, to_element_type, self._debug_info())

    def get_relation_access(self, from_index, to_element_type, neighbor_idx_ptr):
        return _ti_core.get_relation_access(
            self.mesh_ptr, from_index.ptr, to_element_type, neighbor_idx_ptr, self._debug_info()
        )
|
377
|
+
|
378
|
+
|
379
|
+
class MeshMetadata:
    """Parsed mesh metadata: patch layout, index mappings, relations and positions.

    ``data`` is a dict like the one ``Mesh.load_meta`` reads from JSON. Keys read:
    ``num_patches``; ``elements`` (each with ``order``, ``num``,
    ``max_num_per_patch``, ``l2g_mapping``, ``l2r_mapping``, ``g2r_mapping``,
    ``owned_offsets``, ``total_offsets``); ``relations`` (each with
    ``from_order``, ``to_order``, ``value`` and, when ``from_order <= to_order``,
    ``offset`` and ``patch_offset``); ``attrs["x"]``; optional ``patcher``.

    The constructor allocates gstaichi fields for every mapping and relation
    table and fills them from ``data``. ``data`` itself is not modified;
    earlier versions replaced the mapping lists in place with numpy arrays.
    """

    def __init__(self, data):
        self.num_patches = data["num_patches"]

        self.element_fields = {}
        self.relation_fields = {}
        self.num_elements = {}
        self.max_num_per_patch = {}

        # Pass 1 allocates every field; pass 2 fills them. NOTE(review): the split
        # presumably keeps all allocations ahead of the first ``from_numpy``
        # (which may materialize the SNode tree). Keep the ordering.
        element_arrays = []  # per entry of data["elements"], in order
        for element in data["elements"]:
            element_type = MeshElementType(element["order"])
            self.num_elements[element_type] = element["num"]
            self.max_num_per_patch[element_type] = element["max_num_per_patch"]

            arrays = {name: np.array(element[f"{name}_mapping"]) for name in ("l2g", "l2r", "g2r")}
            element_arrays.append(arrays)

            fields = {
                "owned": impl.field(dtype=u32, shape=self.num_patches + 1),
                "total": impl.field(dtype=u32, shape=self.num_patches + 1),
            }
            for name, array in arrays.items():
                fields[name] = impl.field(dtype=u32, shape=array.shape[0])
            self.element_fields[element_type] = fields

        for relation in data["relations"]:
            from_order = relation["from_order"]
            to_order = relation["to_order"]
            rel_type = MeshRelationType(relation_by_orders(from_order, to_order))
            rel_fields = {"value": impl.field(dtype=u16, shape=len(relation["value"]))}
            if from_order <= to_order:
                # These relations are registered as dynamic by the builders and
                # carry per-entry and per-patch offset tables.
                rel_fields["offset"] = impl.field(dtype=u16, shape=len(relation["offset"]))
                rel_fields["patch_offset"] = impl.field(dtype=u32, shape=len(relation["patch_offset"]))
            rel_fields["from_order"] = from_order
            rel_fields["to_order"] = to_order
            self.relation_fields[rel_type] = rel_fields

        # Pass 2: upload the data into the allocated fields.
        for element, arrays in zip(data["elements"], element_arrays):
            fields = self.element_fields[MeshElementType(element["order"])]
            fields["owned"].from_numpy(np.array(element["owned_offsets"]))
            fields["total"].from_numpy(np.array(element["total_offsets"]))
            for name, array in arrays.items():
                fields[name].from_numpy(array)

        for relation in data["relations"]:
            from_order = relation["from_order"]
            to_order = relation["to_order"]
            rel_type = MeshRelationType(relation_by_orders(from_order, to_order))
            rel_fields = self.relation_fields[rel_type]
            rel_fields["value"].from_numpy(np.array(relation["value"]))
            if from_order <= to_order:
                rel_fields["patch_offset"].from_numpy(np.array(relation["patch_offset"]))
                rel_fields["offset"].from_numpy(np.array(relation["offset"]))

        # Vertex positions, one (x, y, z) row per vertex.
        self.attrs = {"x": np.array(data["attrs"]["x"]).reshape(-1, 3)}
        self.patcher = data.get("patcher")
|
440
|
+
|
441
|
+
|
442
|
+
# Define the Mesh Type, stores the field type info
|
443
|
+
class MeshBuilder:
    """(Deprecated) Builder that turns ``MeshMetadata`` into a ``MeshInstance``.

    Attributes are declared on ``verts``/``edges``/``faces``/``cells`` (each a
    ``MeshElement``) before :meth:`build` is called.

    Raises:
        Exception: on construction, if the current backend lacks the mesh extension.
    """

    def __init__(self):
        if not lang.misc.is_extension_supported(impl.current_cfg().arch, lang.extension.mesh):
            raise Exception("Backend " + str(impl.current_cfg().arch) + " doesn't support MeshGsTaichi extension")

        self.verts = MeshElement(MeshElementType.Vertex, self)
        self.edges = MeshElement(MeshElementType.Edge, self)
        self.faces = MeshElement(MeshElementType.Face, self)
        self.cells = MeshElement(MeshElementType.Cell, self)

    def build(self, metadata: MeshMetadata) -> MeshInstance:
        """Create a ``MeshInstance`` from ``metadata`` with the declared attributes placed."""
        instance = MeshInstance()
        instance.fields = {}
        instance.set_num_patches(metadata.num_patches)

        conversions = (
            (ConvType.l2g, "l2g"),
            (ConvType.l2r, "l2r"),
            (ConvType.g2r, "g2r"),
        )
        for element_type, mapping_fields in metadata.element_fields.items():
            count = metadata.num_elements[element_type]
            _ti_core.set_num_elements(instance.mesh_ptr, element_type, count)
            instance.set_patch_max_element_num(element_type, metadata.max_num_per_patch[element_type])

            # Materialize the declared attributes and attach them as e.g. ``instance.verts``.
            attr_name = element_type_name(element_type)
            element_field = getattr(self, attr_name).build(instance, count, mapping_fields["g2r"])
            setattr(instance, attr_name, element_field)
            instance.fields[element_type] = element_field

            instance.set_owned_offset(element_type, mapping_fields["owned"])
            instance.set_total_offset(element_type, mapping_fields["total"])
            for conv_type, mapping_key in conversions:
                instance.set_index_mapping(element_type, conv_type, mapping_fields[mapping_key])

        for rel_type, rel in metadata.relation_fields.items():
            if rel["from_order"] > rel["to_order"]:
                instance.set_relation_fixed(rel_type, rel["value"])
            else:
                instance.set_relation_dynamic(rel_type, rel["value"], rel["patch_offset"], rel["offset"])

        instance._vert_position = metadata.attrs["x"]
        instance.patcher = metadata.patcher
        return instance
|
498
|
+
|
499
|
+
|
500
|
+
# Mesh First Class
|
501
|
+
class Mesh:
    """The Mesh type class.

    MeshGsTaichi offers first-class support for triangular/tetrahedral meshes
    and allows efficient computation on these irregular data structures,
    only available for backends supporting `ti.extension.mesh`.

    See more details in https://github.com/taichi-dev/meshtaichi
    """

    def __init__(self):
        pass

    @staticmethod
    def _create_instance(metadata: MeshMetadata) -> MeshInstance:
        """Create a ``MeshInstance`` whose elements start with no attributes.

        Patch layout, index mappings and relation tables come from ``metadata``.
        Attributes can be added afterwards with ``instance.<element>.place``.
        """
        instance = MeshInstance()
        instance.fields = {}
        instance.set_num_patches(metadata.num_patches)

        for element, mapping_fields in metadata.element_fields.items():
            _ti_core.set_num_elements(instance.mesh_ptr, element, metadata.num_elements[element])
            instance.set_patch_max_element_num(element, metadata.max_num_per_patch[element])

            # Attach an empty attribute container, e.g. ``instance.verts``.
            element_field = MeshElementField(instance, element, {}, {}, mapping_fields["g2r"])
            setattr(instance, element_type_name(element), element_field)
            instance.fields[element] = element_field

            instance.set_owned_offset(element, mapping_fields["owned"])
            instance.set_total_offset(element, mapping_fields["total"])
            instance.set_index_mapping(element, ConvType.l2g, mapping_fields["l2g"])
            instance.set_index_mapping(element, ConvType.l2r, mapping_fields["l2r"])
            instance.set_index_mapping(element, ConvType.g2r, mapping_fields["g2r"])

        for rel_type, rel in metadata.relation_fields.items():
            if rel["from_order"] <= rel["to_order"]:
                instance.set_relation_dynamic(rel_type, rel["value"], rel["patch_offset"], rel["offset"])
            else:
                instance.set_relation_fixed(rel_type, rel["value"])

        instance._vert_position = metadata.attrs["x"]
        instance.patcher = metadata.patcher

        return instance

    @staticmethod
    def load_meta(filename):
        """Load mesh metadata from a JSON file and parse it into ``MeshMetadata``.

        The file is read as bytes so ``json`` detects its encoding (UTF-8/16/32,
        with or without a BOM). Text mode would fall back to the platform locale
        encoding, which is not UTF-8 on Windows.
        """
        with open(filename, "rb") as fi:
            data = json.load(fi)
        return MeshMetadata(data)

    @staticmethod
    def generate_meta(data):
        """Parse an in-memory metadata dict into ``MeshMetadata``."""
        return MeshMetadata(data)
|
566
|
+
|
567
|
+
|
568
|
+
def _TriMesh():
    """(Deprecated) Create a triangle mesh builder.

    The builder collects vert/edge/face attribute declarations and connectivity.
    It is currently the same generic ``MeshBuilder`` that ``_TetMesh`` returns.

    Returns:
        MeshBuilder: a new mesh builder.
    """
    builder = MeshBuilder()
    return builder
|
575
|
+
|
576
|
+
|
577
|
+
def _TetMesh():
    """(Deprecated) Create a tetrahedron mesh builder.

    The builder collects vert/edge/face/cell attribute declarations and
    connectivity. It is currently the same generic ``MeshBuilder`` that
    ``_TriMesh`` returns.

    Returns:
        MeshBuilder: a new mesh builder.
    """
    builder = MeshBuilder()
    return builder
|
584
|
+
|
585
|
+
|
586
|
+
class MeshElementFieldProxy:
    """Kernel-scope handle for one mesh element, addressed by ``entry_expr``.

    For every attribute of the element type it exposes a subscript expression
    (e.g. ``v.pos``). Each attribute's index is first converted with the l2r
    mapping when the attribute is reordered, otherwise with l2g. For every
    element type it exposes a relation accessor (e.g. ``v.faces``).
    """

    def __init__(self, mesh: MeshInstance, element_type: MeshElementType, entry_expr: impl.Expr):
        ast_builder = impl.get_runtime().compiling_callable.ast_builder()

        def debug_info():
            return _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())

        self.mesh = mesh
        self.element_type = element_type
        self.entry_expr = entry_expr

        element_field = self.mesh.fields[self.element_type]
        for key, attr in element_field.field_dict.items():
            conv = ConvType.l2r if element_field.attr_dict[key].reorder else ConvType.l2g
            # Transform the index into the index space the attribute is stored in.
            converted = impl.Expr(
                ast_builder.mesh_index_conversion(self.mesh.mesh_ptr, element_type, entry_expr, conv, debug_info())
            )
            index_group = impl.make_expr_group(converted)

            if isinstance(attr, MatrixField):
                storage_ptr = attr.ptr
            elif isinstance(attr, StructField):
                raise RuntimeError("MeshGsTaichi has not support StructField yet")
            else:  # plain Field: subscript its single member
                storage_ptr = attr._get_field_members()[0].ptr
            setattr(self, key, impl.Expr(ast_builder.expr_subscript(storage_ptr, index_group, debug_info())))

        # Relation accessors such as ``v.faces``.
        for to_type in (
            MeshElementType.Vertex,
            MeshElementType.Edge,
            MeshElementType.Face,
            MeshElementType.Cell,
        ):
            setattr(self, element_type_name(to_type), impl.mesh_relation_access(self.mesh, self, to_type))

    @property
    def ptr(self):
        """The (local) index expression this proxy was created with."""
        return self.entry_expr

    @property
    def id(self):  # return the global non-reordered index
        """Global, non-reordered index of this element (l2g conversion of ``entry_expr``)."""
        ast_builder = impl.get_runtime().compiling_callable.ast_builder()
        debug = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return impl.Expr(
            ast_builder.mesh_index_conversion(
                self.mesh.mesh_ptr, self.element_type, self.entry_expr, ConvType.l2g, debug
            )
        )
|
663
|
+
|
664
|
+
|
665
|
+
class MeshRelationAccessProxy:
    """Kernel-scope proxy for a relation access such as ``v.faces``.

    It is created for element ``from_index`` and a target element type.
    ``size`` gives the relation size expression, and ``subscript`` returns the
    i-th related element as a ``MeshElementFieldProxy``.
    """

    def __init__(
        self,
        mesh: MeshInstance,
        from_index: impl.Expr,
        to_element_type: MeshElementType,
    ):
        self.mesh = mesh
        self.from_index = from_index
        self.to_element_type = to_element_type

    @property
    def size(self):
        """Relation size for ``from_index``, as a kernel expression."""
        return impl.Expr(self.mesh.get_relation_size(self.from_index, self.to_element_type))

    def subscript(self, *indices):
        """Return the related element at ``indices[0]`` as a ``MeshElementFieldProxy``.

        Raises:
            GsTaichiSyntaxError: if not exactly one index is given. An explicit
                check replaces a bare ``assert``, which ``python -O`` strips.
        """
        if len(indices) != 1:
            raise GsTaichiSyntaxError(f"Mesh relation access expects exactly one index, got {len(indices)}.")
        entry_expr = self.mesh.get_relation_access(self.from_index, self.to_element_type, impl.Expr(indices[0]).ptr)
        entry_expr.type_check(impl.get_runtime().prog.config())
        return MeshElementFieldProxy(self.mesh, self.to_element_type, entry_expr)
|
685
|
+
|
686
|
+
|
687
|
+
# Public names exported by ``from ... import *``.
__all__ = [
    "Mesh",
    "MeshInstance",
]
|