gstaichi 0.0.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +51 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +5 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cp313-win_amd64.pyd +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  11. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  12. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  13. gstaichi/_lib/utils.py +243 -0
  14. gstaichi/_logging.py +131 -0
  15. gstaichi/_snode/__init__.py +5 -0
  16. gstaichi/_snode/fields_builder.py +187 -0
  17. gstaichi/_snode/snode_tree.py +34 -0
  18. gstaichi/_test_tools/__init__.py +18 -0
  19. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  20. gstaichi/_test_tools/load_kernel_string.py +30 -0
  21. gstaichi/_test_tools/textwrap2.py +6 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +122 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +83 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +366 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +7 -0
  52. gstaichi/lang/ast/ast_transformer.py +1351 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1259 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1386 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +784 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +10 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +21 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  113. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  114. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  115. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  116. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  117. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  118. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  119. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  120. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  121. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  122. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  123. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  124. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  125. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  126. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  127. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  128. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  129. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  130. gstaichi-0.0.0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  131. gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
  132. gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
  133. gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
  134. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
  135. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  136. gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
  137. gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  138. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  139. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  140. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  141. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  142. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  143. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  144. gstaichi-0.0.0.data/data/lib/SPIRV-Tools.lib +0 -0
  145. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  146. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  147. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  148. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  149. gstaichi-0.0.0.data/data/lib/glfw3.lib +0 -0
  150. gstaichi-0.0.0.dist-info/METADATA +97 -0
  151. gstaichi-0.0.0.dist-info/RECORD +154 -0
  152. gstaichi-0.0.0.dist-info/WHEEL +5 -0
  153. gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
  154. gstaichi-0.0.0.dist-info/top_level.txt +1 -0
gstaichi/lang/mesh.py ADDED
@@ -0,0 +1,687 @@
1
+ # type: ignore
2
+
3
+ import json
4
+
5
+ import numpy as np
6
+
7
+ from gstaichi import lang
8
+ from gstaichi._lib import core as _ti_core
9
+ from gstaichi.lang import impl
10
+ from gstaichi.lang.exception import GsTaichiSyntaxError
11
+ from gstaichi.lang.field import Field, ScalarField
12
+ from gstaichi.lang.matrix import Matrix, MatrixField
13
+ from gstaichi.lang.struct import StructField
14
+ from gstaichi.lang.util import python_scope
15
+ from gstaichi.types import u16, u32
16
+ from gstaichi.types.compound_types import CompoundType
17
+ from gstaichi.types.enums import Layout
18
+
19
# Short module-level aliases for names provided by the compiled core module,
# so code in this file can use them without the `_ti_core.` prefix.

# Names used as types / constructors in this module.
MeshTopology = _ti_core.MeshTopology
MeshElementType = _ti_core.MeshElementType
MeshRelationType = _ti_core.MeshRelationType
ConvType = _ti_core.ConvType

# Helpers dealing with element orders and relations.
element_order = _ti_core.element_order
from_end_element_order = _ti_core.from_end_element_order
to_end_element_order = _ti_core.to_end_element_order
relation_by_orders = _ti_core.relation_by_orders
inverse_relation = _ti_core.inverse_relation
element_type_name = _ti_core.element_type_name
29
+
30
+
31
class MeshAttrType:
    """Plain record describing one declared mesh attribute.

    Stores the four constructor arguments unchanged; nothing is validated.
    """

    def __init__(self, name, dtype, reorder, needs_grad):
        # Attribute key and the dtype later used to create its field.
        self.name, self.dtype = name, dtype
        # Flags captured at declaration time.
        self.reorder, self.needs_grad = reorder, needs_grad
37
+
38
+
39
class MeshReorderedScalarFieldProxy(ScalarField):
    """Python-scope wrapper around a scalar field that remaps indices.

    Each read or write first translates ``key`` through ``g2r_field`` and then
    goes through the first host accessor of the wrapped field.
    """

    def __init__(
        self,
        field: ScalarField,
        mesh_ptr: _ti_core.MeshPtr,
        element_type: MeshElementType,
        g2r_field: ScalarField,
    ):
        # NOTE: ScalarField.__init__ is not invoked here; the state below is
        # borrowed from the wrapped `field` instead.
        self.vars = field.vars
        self.host_accessors = field.host_accessors
        self.grad = field.grad

        # Mesh-specific context kept alongside the borrowed field state.
        self.g2r_field = g2r_field
        self.element_type = element_type
        self.mesh_ptr = mesh_ptr

    @python_scope
    def __setitem__(self, key, value):
        """Store ``value`` at the index that ``g2r_field`` maps ``key`` to."""
        self._initialize_host_accessors()
        remapped = self.g2r_field[key]
        accessor = self.host_accessors[0]
        accessor.setter(value, *self._pad_key(remapped))

    @python_scope
    def __getitem__(self, key):
        """Return the value at the index that ``g2r_field`` maps ``key`` to."""
        self._initialize_host_accessors()
        remapped = self.g2r_field[key]
        accessor = self.host_accessors[0]
        return accessor.getter(*self._pad_key(remapped))
66
+
67
+
68
class MeshReorderedMatrixFieldProxy(MatrixField):
    """Python-scope wrapper around a matrix field that remaps indices.

    Reads translate ``key`` through ``g2r_field`` before accessing the host
    data; writes are routed through the read path and ``_set_entries``.
    """

    def __init__(
        self,
        field: MatrixField,
        mesh_ptr: _ti_core.MeshPtr,
        element_type: MeshElementType,
        g2r_field: ScalarField,
    ):
        # NOTE: MatrixField.__init__ is not invoked here; the state below is
        # borrowed from the wrapped `field` instead.
        self.vars = field.vars
        self.host_accessors = field.host_accessors
        self.grad = field.grad
        self.n = field.n
        self.m = field.m
        self.ndim = field.ndim
        self.ptr = field.ptr

        # Mesh-specific context kept alongside the borrowed field state.
        self.g2r_field = g2r_field
        self.element_type = element_type
        self.mesh_ptr = mesh_ptr

    @python_scope
    def __setitem__(self, key, value):
        """Write ``value`` into the entry that ``self[key]`` resolves to."""
        self._initialize_host_accessors()
        entry = self[key]
        entry._set_entries(value)

    @python_scope
    def __getitem__(self, key):
        """Return a ``Matrix`` for the index that ``g2r_field`` maps ``key`` to."""
        self._initialize_host_accessors()
        padded = self._pad_key(self.g2r_field[key])
        return Matrix(self._host_access(padded))
99
+
100
+
101
class MeshElementField:
    """Attribute container for one element type of a mesh instance.

    Holds the attribute declarations (``attr_dict``), the backing fields
    (``field_dict``) and the ``g2r_field`` mapping. Every attribute key is
    exposed as a read-only property named after the key.

    NOTE: the properties are installed on the ``MeshElementField`` class
    itself (via ``setattr``), so the property names are visible on every
    instance, not only on the one the attribute was placed on.
    """

    def __init__(self, mesh_instance, _type, attr_dict, field_dict, g2r_field):
        self.mesh = mesh_instance
        self._type = _type
        self.attr_dict = attr_dict
        self.field_dict = field_dict
        self.g2r_field = g2r_field

        self._register_fields()

    def place(self, members, reorder=False, needs_grad=False, layout=Layout.SOA):
        """Declares mesh attributes for the mesh element in current mesh.

        Args:
            members (Dict[str, Union[PrimitiveType, MatrixType]]): names and types for element attributes.
            reorder: True if reorders the internal memory for coalesced data access within mesh-for loop.
            needs_grad: True if needs to record grad.
            layout: ti.Layout.AoS/ti.Layout.SoA

        Raises:
            GsTaichiSyntaxError: if a name is one of the reserved keywords
                (``verts``/``edges``/``faces``/``cells``) or was already placed.
                All names are checked before any state is modified.

        Example::
            >>> import meshtaichi_patcher as Patcher
            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> mesh = Patcher.load_mesh("bunny.obj", relations=['FV'])
            >>> mesh.faces.place({'area' : ti.f32}) # declares a mesh attribute `area` for each face element.
            >>> mesh.verts.place({'pos' : vec3}, reorder=True) # declares a mesh attribute `pos` for each vertex element, and reorder it in memory.
        """
        # Validate every name up front so a rejected call does not leave
        # earlier members half-registered (declared but never placed).
        for key in members:
            if key in {"verts", "edges", "faces", "cells"}:
                raise GsTaichiSyntaxError(
                    f"'{key}' cannot use as attribute name. It has been reserved as MeshGsTaichi's keyword."
                )
            if key in self.attr_dict:
                raise GsTaichiSyntaxError(f"'{key}' has already use as attribute name.")

        for key, dtype in members.items():
            # init attr type
            self.attr_dict[key] = MeshAttrType(key, dtype, reorder, needs_grad)

            # init field
            if isinstance(dtype, CompoundType):
                self.field_dict[key] = dtype.field(shape=None, needs_grad=needs_grad)
            else:
                self.field_dict[key] = impl.field(dtype, shape=None, needs_grad=needs_grad)

        size = _ti_core.get_num_elements(self.mesh.mesh_ptr, self._type)
        if layout == Layout.SOA:
            # One dense node per field (and one more per gradient).
            for key in members:
                field = self.field_dict[key]
                impl.root.dense(impl.axes(0), size).place(field)
                if self.attr_dict[key].needs_grad:
                    impl.root.dense(impl.axes(0), size).place(field.grad)
        elif members:
            # All fields share one dense node; gradients share another.
            placed = [self.field_dict[key] for key in members]
            impl.root.dense(impl.axes(0), size).place(*placed)
            grads = [self.field_dict[key].grad for key in members if self.attr_dict[key].needs_grad]
            if grads:
                impl.root.dense(impl.axes(0), size).place(*grads)

        for key in members:
            # expose interface
            setattr(MeshElementField, key, property(fget=MeshElementField._make_getter(key)))

    @property
    def keys(self):
        """List of attribute names currently held in ``field_dict``."""
        return list(self.field_dict.keys())

    @property
    def _members(self):
        """List of field objects currently held in ``field_dict``."""
        return list(self.field_dict.values())

    @property
    def _items(self):
        """``(name, field)`` view of ``field_dict``."""
        return self.field_dict.items()

    @staticmethod
    def _make_getter(key):
        """Build the property getter used to expose attribute ``key``."""

        def getter(self):
            """Get an entry from custom struct by name."""
            if key not in self.getter_dict:
                if self.attr_dict[key].reorder:
                    # Reordered attributes are wrapped so that Python-scope
                    # indexing goes through the g2r mapping.
                    if isinstance(self.field_dict[key], ScalarField):
                        self.getter_dict[key] = MeshReorderedScalarFieldProxy(
                            self.field_dict[key],
                            self.mesh.mesh_ptr,
                            self._type,
                            self.g2r_field,
                        )
                    elif isinstance(self.field_dict[key], MatrixField):
                        self.getter_dict[key] = MeshReorderedMatrixFieldProxy(
                            self.field_dict[key],
                            self.mesh.mesh_ptr,
                            self._type,
                            self.g2r_field,
                        )
                    # NOTE(review): a reordered field of any other kind is not
                    # cached here, so the lookup below raises KeyError.
                else:
                    self.getter_dict[key] = self.field_dict[key]
            return self.getter_dict[key]

        return getter

    def _register_fields(self):
        """Reset the getter cache and install a property for every existing key."""
        self.getter_dict = {}
        for k in self.keys:
            setattr(MeshElementField, k, property(fget=MeshElementField._make_getter(k)))

    def _get_field_members(self):
        """Concatenate ``_get_field_members()`` of every held field."""
        field_members = []
        for m in self._members:
            assert isinstance(m, Field)
            field_members += m._get_field_members()
        return field_members

    def _initialize_host_accessors(self):
        """Initialize host accessors on every held field."""
        for v in self._members:
            v._initialize_host_accessors()

    def get_member_field(self, key):
        """Return the raw field stored under ``key`` (no reorder proxy)."""
        return self.field_dict[key]

    @python_scope
    def __len__(self):
        """Number of elements of this type, as reported by the core mesh."""
        return _ti_core.get_num_elements(self.mesh.mesh_ptr, self._type)
228
+
229
+
230
class MeshElement:
    """Builder-side declaration of the attributes of one element type.

    Attribute declarations are collected in ``attr_dict`` by :meth:`place`
    and turned into fields by :meth:`build`.
    """

    def __init__(self, _type, builder):
        self._type = _type
        self.builder = builder
        self.attr_dict = {}
        self.layout = Layout.SOA

    def _SOA(self, soa=True):  # AOS/SOA
        if soa:
            self.layout = Layout.SOA
        else:
            self.layout = Layout.AOS

    def _AOS(self, aos=True):
        if aos:
            self.layout = Layout.AOS
        else:
            self.layout = Layout.SOA

    # (Deprecated) Set `True` for SOA (structure of arrays) layout.
    SOA = property(fset=_SOA)
    # (Deprecated) Set `True` for AOS (array of structures) layout.
    AOS = property(fset=_AOS)

    def place(
        self,
        members,
        reorder=False,
        needs_grad=False,
    ):
        """(Deprecated) Declares mesh attributes for the mesh element in current mesh builder.

        Args:
            members (Dict[str, Union[PrimitiveType, MatrixType]]): names and types for element attributes.
            reorder: True if reorders the internal memory for coalesced data access within mesh-for loop.
            needs_grad: True if needs to record grad.

        Example::
            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> mesh = ti.TriMesh()
            >>> mesh.faces.place({'area' : ti.f32}) # declares a mesh attribute `area` for each face element.
            >>> mesh.verts.place({'pos' : vec3}, reorder=True) # declares a mesh attribute `pos` for each vertex element, and reorder it in memory.
        """
        reserved = {"verts", "edges", "faces", "cells"}
        for attr_name, attr_dtype in members.items():
            # Names are checked one at a time, in iteration order; members
            # seen before an offending name have already been recorded.
            if attr_name in reserved:
                raise GsTaichiSyntaxError(
                    f"'{attr_name}' cannot use as attribute name. It has been reserved as MeshGsTaichi's keyword."
                )
            self.attr_dict[attr_name] = MeshAttrType(attr_name, attr_dtype, reorder, needs_grad)

    def build(self, mesh_instance, size, g2r_field):
        """Create and place one field per declared attribute.

        Returns:
            A ``MeshElementField`` wrapping the created fields.
        """
        field_dict = {}
        for name, attr in self.attr_dict.items():
            if isinstance(attr.dtype, CompoundType):
                created = attr.dtype.field(shape=None, needs_grad=attr.needs_grad)
            else:
                created = impl.field(attr.dtype, shape=None, needs_grad=attr.needs_grad)
            field_dict[name] = created

        if self.layout == Layout.SOA:
            # One dense node per field (and one more per gradient).
            for name, created in field_dict.items():
                impl.root.dense(impl.axes(0), size).place(created)
                if self.attr_dict[name].needs_grad:
                    impl.root.dense(impl.axes(0), size).place(created.grad)
        elif field_dict:
            # All fields share one dense node; gradients share another.
            impl.root.dense(impl.axes(0), size).place(*field_dict.values())
            grads = [created.grad for name, created in field_dict.items() if self.attr_dict[name].needs_grad]
            if grads:
                impl.root.dense(impl.axes(0), size).place(*grads)

        return MeshElementField(mesh_instance, self._type, self.attr_dict, field_dict, g2r_field)
301
+
302
+
303
# Define the instance of the Mesh Type, stores the field (type and data) info
class MeshInstance:
    """Python-side handle to one core mesh object.

    Owns ``mesh_ptr`` (created through ``_ti_core.create_mesh()``) and forwards
    the configuration calls below to the core module.
    """

    def __init__(self):
        self.mesh_ptr = _ti_core.create_mesh()
        self.relation_set = set()
        # Start with empty element fields for all four element types.
        for attr_name, element_type in (
            ("verts", MeshElementType.Vertex),
            ("edges", MeshElementType.Edge),
            ("faces", MeshElementType.Face),
            ("cells", MeshElementType.Cell),
        ):
            setattr(self, attr_name, MeshElementField(self, element_type, {}, {}, {}))

    def get_position_as_numpy(self):
        """Get the vertex position of current mesh to numpy array.

        Returns:
            3d numpy array: [x, y, z] with float-format.
        """
        if not hasattr(self, "_vert_position"):
            raise GsTaichiSyntaxError("Position info is not in the file.")
        return self._vert_position

    def set_owned_offset(self, element_type: MeshElementType, owned_offset: ScalarField):
        """Register the SNode of ``owned_offset`` with the core mesh."""
        snode = owned_offset.vars[0].ptr.snode()
        _ti_core.set_owned_offset(self.mesh_ptr, element_type, snode)

    def set_total_offset(self, element_type: MeshElementType, total_offset: ScalarField):
        """Register the SNode of ``total_offset`` with the core mesh."""
        snode = total_offset.vars[0].ptr.snode()
        _ti_core.set_total_offset(self.mesh_ptr, element_type, snode)

    def set_index_mapping(self, element_type: MeshElementType, conv_type: ConvType, mapping: ScalarField):
        """Register the SNode of ``mapping`` for the given conversion type."""
        snode = mapping.vars[0].ptr.snode()
        _ti_core.set_index_mapping(self.mesh_ptr, element_type, conv_type, snode)

    def set_num_patches(self, num_patches: int):
        """Forward the patch count to the core mesh."""
        _ti_core.set_num_patches(self.mesh_ptr, num_patches)

    def set_patch_max_element_num(self, element_type: MeshElementType, max_element_num: int):
        """Forward the per-patch element maximum to the core mesh."""
        _ti_core.set_patch_max_element_num(self.mesh_ptr, element_type, max_element_num)

    def set_relation_fixed(self, rel_type: MeshRelationType, value: ScalarField):
        """Record ``rel_type`` and register its value field with the core mesh."""
        self.relation_set.add(rel_type)
        value_snode = value.vars[0].ptr.snode()
        _ti_core.set_relation_fixed(self.mesh_ptr, rel_type, value_snode)

    def set_relation_dynamic(
        self,
        rel_type: MeshRelationType,
        value: ScalarField,
        patch_offset: ScalarField,
        offset: ScalarField,
    ):
        """Record ``rel_type`` and register its value/offset fields with the core mesh."""
        self.relation_set.add(rel_type)
        value_snode = value.vars[0].ptr.snode()
        patch_offset_snode = patch_offset.vars[0].ptr.snode()
        offset_snode = offset.vars[0].ptr.snode()
        _ti_core.set_relation_dynamic(self.mesh_ptr, rel_type, value_snode, patch_offset_snode, offset_snode)

    def add_mesh_attribute(self, element_type, snode, reorder_type):
        """Forward an attribute registration to the core mesh."""
        _ti_core.add_mesh_attribute(self.mesh_ptr, element_type, snode, reorder_type)

    def get_relation_size(self, from_index, to_element_type):
        """Return the core relation-size result for ``from_index``."""
        from_ptr = from_index.ptr
        debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return _ti_core.get_relation_size(self.mesh_ptr, from_ptr, to_element_type, debug_info)

    def get_relation_access(self, from_index, to_element_type, neighbor_idx_ptr):
        """Return the core relation-access result for one neighbor of ``from_index``."""
        from_ptr = from_index.ptr
        debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return _ti_core.get_relation_access(self.mesh_ptr, from_ptr, to_element_type, neighbor_idx_ptr, debug_info)
377
+
378
+
379
class MeshMetadata:
    """Field-backed mesh metadata built from a plain ``data`` mapping.

    Reads ``num_patches``, ``elements``, ``relations``, ``attrs["x"]`` and the
    optional ``patcher`` entry from ``data``. The input mapping is only read;
    converted numpy arrays are kept locally instead of being written back
    into ``data``.

    Attributes set here: ``num_patches``, ``element_fields``,
    ``relation_fields``, ``num_elements``, ``max_num_per_patch``, ``attrs``
    and ``patcher``.
    """

    def __init__(self, data):
        self.num_patches = data["num_patches"]

        self.element_fields = {}
        self.relation_fields = {}
        self.num_elements = {}
        self.max_num_per_patch = {}

        # Converted index mappings, one dict per entry of data["elements"]
        # (same order). Kept here so the caller's `data` is left untouched.
        element_mappings = []

        # Pass 1: declare all fields. Filling happens in a second pass below;
        # presumably every field must be declared before the first host-side
        # write -- TODO confirm against the field materialization rules.
        for element in data["elements"]:
            element_type = MeshElementType(element["order"])
            self.num_elements[element_type] = element["num"]
            self.max_num_per_patch[element_type] = element["max_num_per_patch"]

            mappings = {
                "l2g": np.array(element["l2g_mapping"]),
                "l2r": np.array(element["l2r_mapping"]),
                "g2r": np.array(element["g2r_mapping"]),
            }
            element_mappings.append(mappings)

            fields = {}
            self.element_fields[element_type] = fields
            fields["owned"] = impl.field(dtype=u32, shape=self.num_patches + 1)
            fields["total"] = impl.field(dtype=u32, shape=self.num_patches + 1)
            for name, mapping in mappings.items():
                fields[name] = impl.field(dtype=u32, shape=mapping.shape[0])

        for relation in data["relations"]:
            from_order = relation["from_order"]
            to_order = relation["to_order"]
            rel_type = MeshRelationType(relation_by_orders(from_order, to_order))
            fields = {}
            self.relation_fields[rel_type] = fields
            fields["value"] = impl.field(dtype=u16, shape=len(relation["value"]))
            # Only relations with from_order <= to_order carry offset tables.
            if from_order <= to_order:
                fields["offset"] = impl.field(dtype=u16, shape=len(relation["offset"]))
                fields["patch_offset"] = impl.field(dtype=u32, shape=len(relation["patch_offset"]))
            fields["from_order"] = from_order
            fields["to_order"] = to_order

        # Pass 2: copy the data into the declared fields.
        for element, mappings in zip(data["elements"], element_mappings):
            fields = self.element_fields[MeshElementType(element["order"])]
            fields["owned"].from_numpy(np.array(element["owned_offsets"]))
            fields["total"].from_numpy(np.array(element["total_offsets"]))
            for name, mapping in mappings.items():
                fields[name].from_numpy(mapping)

        for relation in data["relations"]:
            from_order = relation["from_order"]
            to_order = relation["to_order"]
            rel_type = MeshRelationType(relation_by_orders(from_order, to_order))
            fields = self.relation_fields[rel_type]
            fields["value"].from_numpy(np.array(relation["value"]))
            if from_order <= to_order:
                fields["patch_offset"].from_numpy(np.array(relation["patch_offset"]))
                fields["offset"].from_numpy(np.array(relation["offset"]))

        # Vertex positions reshaped to rows of three values.
        self.attrs = {}
        self.attrs["x"] = np.array(data["attrs"]["x"]).reshape(-1, 3)
        self.patcher = data["patcher"] if "patcher" in data else None
440
+
441
+
442
# Define the Mesh Type, stores the field type info
class MeshBuilder:
    """Collects per-element attribute declarations and builds a ``MeshInstance``.

    Raises:
        Exception: on construction, if the current arch does not support the
            mesh extension.
    """

    def __init__(self):
        arch = impl.current_cfg().arch
        if not lang.misc.is_extension_supported(arch, lang.extension.mesh):
            raise Exception(f"Backend {arch!s} doesn't support MeshGsTaichi extension")

        self.verts = MeshElement(MeshElementType.Vertex, self)
        self.edges = MeshElement(MeshElementType.Edge, self)
        self.faces = MeshElement(MeshElementType.Face, self)
        self.cells = MeshElement(MeshElementType.Cell, self)

    def build(self, metadata: MeshMetadata) -> MeshInstance:
        """Create a ``MeshInstance`` from ``metadata`` and the declared attributes."""
        instance = MeshInstance()
        instance.fields = {}

        instance.set_num_patches(metadata.num_patches)

        for element, meta_fields in metadata.element_fields.items():
            num_elements = metadata.num_elements[element]
            _ti_core.set_num_elements(instance.mesh_ptr, element, num_elements)
            instance.set_patch_max_element_num(element, metadata.max_num_per_patch[element])

            # Build the declared attributes of this element type and expose
            # them both by name (verts/edges/...) and by element type.
            element_name = element_type_name(element)
            declared = getattr(self, element_name)
            built = declared.build(instance, num_elements, meta_fields["g2r"])
            setattr(instance, element_name, built)
            instance.fields[element] = built

            instance.set_owned_offset(element, meta_fields["owned"])
            instance.set_total_offset(element, meta_fields["total"])
            instance.set_index_mapping(element, ConvType.l2g, meta_fields["l2g"])
            instance.set_index_mapping(element, ConvType.l2r, meta_fields["l2r"])
            instance.set_index_mapping(element, ConvType.g2r, meta_fields["g2r"])

        for rel_type, rel in metadata.relation_fields.items():
            # Relations with from_order <= to_order carry offset tables.
            if rel["from_order"] <= rel["to_order"]:
                instance.set_relation_dynamic(rel_type, rel["value"], rel["patch_offset"], rel["offset"])
            else:
                instance.set_relation_fixed(rel_type, rel["value"])

        instance._vert_position = metadata.attrs["x"]
        instance.patcher = metadata.patcher

        return instance
498
+
499
+
500
# Mesh First Class
class Mesh:
    """The Mesh type class.

    MeshGsTaichi offers first-class support for triangular/tetrahedral meshes
    and allows efficient computation on these irregular data structures,
    only available for backends supporting `ti.extension.mesh`.

    See more details in https://github.com/taichi-dev/meshtaichi
    """

    def __init__(self):
        pass

    @staticmethod
    def _create_instance(metadata: MeshMetadata) -> MeshInstance:
        """Create a ``MeshInstance`` from ``metadata`` with empty attribute sets."""
        instance = MeshInstance()
        instance.fields = {}

        instance.set_num_patches(metadata.num_patches)

        for element, meta_fields in metadata.element_fields.items():
            _ti_core.set_num_elements(instance.mesh_ptr, element, metadata.num_elements[element])
            instance.set_patch_max_element_num(element, metadata.max_num_per_patch[element])

            # Expose the (initially empty) element field both by name
            # (verts/edges/...) and by element type.
            element_name = element_type_name(element)
            element_field = MeshElementField(instance, element, {}, {}, meta_fields["g2r"])
            setattr(instance, element_name, element_field)
            instance.fields[element] = element_field

            instance.set_owned_offset(element, meta_fields["owned"])
            instance.set_total_offset(element, meta_fields["total"])
            instance.set_index_mapping(element, ConvType.l2g, meta_fields["l2g"])
            instance.set_index_mapping(element, ConvType.l2r, meta_fields["l2r"])
            instance.set_index_mapping(element, ConvType.g2r, meta_fields["g2r"])

        for rel_type, rel in metadata.relation_fields.items():
            # Relations with from_order <= to_order carry offset tables.
            if rel["from_order"] <= rel["to_order"]:
                instance.set_relation_dynamic(rel_type, rel["value"], rel["patch_offset"], rel["offset"])
            else:
                instance.set_relation_fixed(rel_type, rel["value"])

        instance._vert_position = metadata.attrs["x"]
        instance.patcher = metadata.patcher

        return instance

    @staticmethod
    def load_meta(filename):
        """Read a JSON metadata file and wrap it in ``MeshMetadata``.

        Args:
            filename: path of the JSON file to read.

        Returns:
            MeshMetadata built from the decoded JSON document.
        """
        # Decode as UTF-8 explicitly rather than with the platform's locale
        # default encoding, so the same file loads identically everywhere.
        with open(filename, "r", encoding="utf-8") as fi:
            data = json.load(fi)
        return MeshMetadata(data)

    @staticmethod
    def generate_meta(data):
        """Wrap an already-decoded metadata mapping in ``MeshMetadata``."""
        return MeshMetadata(data)
566
+
567
+
568
def _TriMesh():
    """(Deprecated) Create a builder for a triangle mesh.

    A triangle mesh is a set of vert/edge/face elements, attributes, and
    connectivity.

    Returns:
        An instance of mesh builder.
    """
    builder = MeshBuilder()
    return builder
575
+
576
+
577
def _TetMesh():
    """(Deprecated) Create a builder for a tetrahedron mesh.

    A tetrahedron mesh is a set of vert/edge/face/cell elements, attributes,
    and connectivity.

    Returns:
        An instance of mesh builder.
    """
    builder = MeshBuilder()
    return builder
584
+
585
+
586
class MeshElementFieldProxy:
    """Handle to a single mesh element, built while a callable is compiling.

    On construction, every attribute of the element's type is set on the proxy
    as a subscript expression, and one relation accessor per element type is
    set under the name returned by ``element_type_name``.
    """

    def __init__(self, mesh: MeshInstance, element_type: MeshElementType, entry_expr: impl.Expr):
        ast_builder = impl.get_runtime().compiling_callable.ast_builder()

        self.mesh = mesh
        self.element_type = element_type
        self.entry_expr = entry_expr

        element_field = self.mesh.fields[self.element_type]
        for key, attr in element_field.field_dict.items():
            # Reordered attributes are addressed through l2r, others through l2g.
            if element_field.attr_dict[key].reorder:
                conv_type = ConvType.l2r
            else:
                conv_type = ConvType.l2g
            converted_index = impl.Expr(
                ast_builder.mesh_index_conversion(
                    self.mesh.mesh_ptr,
                    element_type,
                    entry_expr,
                    conv_type,
                    _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                )
            )  # transform index space
            index_group = impl.make_expr_group(converted_index)

            # Pick the pointer to subscript; struct fields are rejected.
            if isinstance(attr, MatrixField):
                target = attr.ptr
            elif isinstance(attr, StructField):
                raise RuntimeError("MeshGsTaichi has not support StructField yet")
            else:  # isinstance(attr, Field)
                target = attr._get_field_members()[0].ptr

            subscripted = impl.Expr(
                ast_builder.expr_subscript(
                    target,
                    index_group,
                    _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                )
            )
            setattr(self, key, subscripted)

        for neighbor_type in {
            MeshElementType.Vertex,
            MeshElementType.Edge,
            MeshElementType.Face,
            MeshElementType.Cell,
        }:
            accessor = impl.mesh_relation_access(self.mesh, self, neighbor_type)
            setattr(self, element_type_name(neighbor_type), accessor)

    @property
    def ptr(self):
        """The entry expression this proxy was created with."""
        return self.entry_expr

    @property
    def id(self):  # return the global non-reordered index
        """Expression converting the entry index with ``ConvType.l2g``."""
        ast_builder = impl.get_runtime().compiling_callable.ast_builder()
        debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return impl.Expr(
            ast_builder.mesh_index_conversion(
                self.mesh.mesh_ptr,
                self.element_type,
                self.entry_expr,
                ConvType.l2g,
                debug_info,
            )
        )
663
+
664
+
665
class MeshRelationAccessProxy:
    """Accessor for the elements of one type related to a given element.

    ``size`` wraps the mesh's relation-size result in an expression and
    ``subscript`` returns a ``MeshElementFieldProxy`` for one related element.
    """

    def __init__(
        self,
        mesh: MeshInstance,
        from_index: impl.Expr,
        to_element_type: MeshElementType,
    ):
        self.to_element_type = to_element_type
        self.from_index = from_index
        self.mesh = mesh

    @property
    def size(self):
        """Expression wrapping ``mesh.get_relation_size`` for this relation."""
        relation_size = self.mesh.get_relation_size(self.from_index, self.to_element_type)
        return impl.Expr(relation_size)

    def subscript(self, *indices):
        """Return a proxy for the related element selected by the single index given."""
        assert len(indices) == 1
        neighbor_idx_ptr = impl.Expr(indices[0]).ptr
        entry_expr = self.mesh.get_relation_access(self.from_index, self.to_element_type, neighbor_idx_ptr)
        entry_expr.type_check(impl.get_runtime().prog.config())
        return MeshElementFieldProxy(self.mesh, self.to_element_type, entry_expr)
685
+
686
+
687
+ __all__ = ["Mesh", "MeshInstance"]