gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,810 @@
1
+ # type: ignore
2
+
3
+ import numbers
4
+ from types import MethodType
5
+
6
+ import numpy as np
7
+
8
+ from gstaichi._lib import core as _ti_core
9
+ from gstaichi.lang import expr, impl, ops
10
+ from gstaichi.lang.exception import (
11
+ GsTaichiRuntimeTypeError,
12
+ GsTaichiSyntaxError,
13
+ GsTaichiTypeError,
14
+ )
15
+ from gstaichi.lang.expr import Expr
16
+ from gstaichi.lang.field import Field, ScalarField, SNodeHostAccess
17
+ from gstaichi.lang.matrix import Matrix, MatrixType
18
+ from gstaichi.lang.util import cook_dtype, gstaichi_scope, in_python_scope, python_scope
19
+ from gstaichi.types import primitive_types
20
+ from gstaichi.types.compound_types import CompoundType
21
+ from gstaichi.types.enums import Layout
22
+ from gstaichi.types.utils import is_signed
23
+
24
+
25
class Struct:
    """The Struct type class.

    A struct is a dictionary-like data structure that stores members as
    (key, value) pairs. Valid data members of a struct can be scalars,
    matrices or other dictionary-like structures.

    Args:
        entries (Dict[str, Union[Dict, Expr, Matrix, Struct]]): \
            keys and values for struct members. Entries can optionally
            include a dictionary of functions with the key '__struct_methods'
            which will be attached to the struct for executing on the struct data.

    Returns:
        An instance of this struct.

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)
        >>> a = ti.Struct(v=vec3([0, 0, 0]), t=1.0)
        >>> print(a.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0)])
        >>>
        >>> B = ti.Struct(v=vec3([0., 0., 0.]), t=1.0, A=a)
        >>> print(B.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0), ('A', {'v': [[0.], [0.], [0.]], 't': 1.0})])
    """

    _is_gstaichi_class = True
    # Counter used to give every per-instance subclass a unique name
    # (see `_register_members`).
    _instance_count = 0

    def __init__(self, *args, **kwargs):
        """Build a struct from a single dict or from keyword arguments.

        Note that when a dict is passed it is used as the backing storage
        directly: the reserved keys '__struct_methods' and '__matrix_ndim'
        are popped from it and its values are converted in place.

        Raises:
            GsTaichiSyntaxError: If positional arguments other than a single
                dict are supplied.
        """
        # converts lists to matrices and dicts to structs
        if len(args) == 1 and kwargs == {} and isinstance(args[0], dict):
            self.__entries = args[0]
        elif len(args) == 0:
            self.__entries = kwargs
        else:
            raise GsTaichiSyntaxError(
                "Custom structs need to be initialized using either dictionary or keyword arguments"
            )
        self.__methods = self.__entries.pop("__struct_methods", {})
        # '__matrix_ndim' is metadata emitted by `to_dict(include_ndim=True)`;
        # it must not become a struct member, so it is dropped here.
        self.__entries.pop("__matrix_ndim", None)
        self._register_methods()

        # Only values of existing keys are replaced, so mutating the dict
        # while iterating over it is safe.
        for k, v in self.__entries.items():
            if isinstance(v, (list, tuple)):
                v = Matrix(v)
            if isinstance(v, dict):
                v = Struct(v)
            self.__entries[k] = v if in_python_scope() else impl.expr_init(v)
        self._register_members()
        self.__dtype = None

    @property
    def keys(self):
        """Returns the list of member names in string format.

        Example::

           >>> vec3 = ti.types.vector(3, ti.f32)
           >>> sphere = ti.Struct(center=vec3([0, 0, 0]), radius=1.0)
           >>> sphere.keys
           ['center', 'radius']
        """
        return list(self.__entries.keys())

    @property
    def _members(self):
        """Returns the member values as a list, in insertion order."""
        return list(self.__entries.values())

    @property
    def entries(self):
        """Returns the underlying name -> value dictionary (not a copy)."""
        return self.__entries

    @property
    def methods(self):
        """Returns the dictionary of methods attached to this struct."""
        return self.__methods

    @property
    def items(self):
        """Returns the items in this struct.

        Example::

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = ti.Struct(center=vec3([0, 0, 0]), radius=1.0)
            >>> sphere.items
            dict_items([('center', 2), ('radius', 1.0)])
        """
        return self.__entries.items()

    def _register_members(self):
        """Expose every member as a property by swapping in a per-instance subclass."""
        # https://stackoverflow.com/questions/48448074/adding-a-property-to-an-existing-object-instance
        cls = self.__class__
        new_cls_name = cls.__name__ + str(cls._instance_count)
        cls._instance_count += 1
        properties = {k: property(cls._make_getter(k), cls._make_setter(k)) for k in self.keys}
        self.__class__ = type(new_cls_name, (cls,), properties)

    def _register_methods(self):
        """Bind every function in the methods dictionary to this instance."""
        for name, method in self.__methods.items():
            # use MethodType to pass self (this object) to the method
            setattr(self, name, MethodType(method, self))

    def __getitem__(self, key):
        """Return the member stored under `key`, reading through host accessors."""
        ret = self.__entries[key]
        if isinstance(ret, SNodeHostAccess):
            ret = ret.accessor.getter(*ret.key)
        return ret

    def __setitem__(self, key, value):
        """Assign `value` to the existing member `key`.

        Raises:
            KeyError: If `key` is not a member of this struct.
            TypeError: In Python scope, if a scalar member is assigned a
                value that is not a number.
        """
        entry = self.__entries[key]
        if isinstance(entry, SNodeHostAccess):
            entry.accessor.setter(value, *entry.key)
        elif in_python_scope():
            if isinstance(entry, (Struct, Matrix)):
                # Compound members are updated element-wise in place.
                entry._set_entries(value)
            elif isinstance(value, numbers.Number):
                self.__entries[key] = value
            else:
                raise TypeError("A number is expected when assigning struct members")
        else:
            self.__entries[key] = value

    def _set_entries(self, value):
        """Copy every member (and the dtype) from `value`, a dict or Struct."""
        if isinstance(value, dict):
            value = Struct(value)
        for k in self.keys:
            self[k] = value[k]
        self.__dtype = value.__dtype

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom struct by name."""
            return self[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self[key] = value

        return setter

    @gstaichi_scope
    def _assign(self, other):
        """Assign the members of `other` to this struct, member by member.

        Raises:
            GsTaichiTypeError: If `other` is neither a dict nor a Struct, or
                if its member names differ from this struct's.
        """
        if not isinstance(other, (dict, Struct)):
            raise GsTaichiTypeError("Only dict or Struct can be assigned to a Struct")
        if isinstance(other, dict):
            other = Struct(other)
        if self.__entries.keys() != other.__entries.keys():
            raise GsTaichiTypeError(f"Member mismatch between structs {self.keys}, {other.keys}")
        for k, v in self.items:
            v._assign(other.__entries[k])
        self.__dtype = other.__dtype
        return self

    def __len__(self):
        """Get the number of entries in a custom struct"""
        return len(self.__entries)

    def __iter__(self):
        """Iterate over the member values in insertion order."""
        # `__iter__` must return an iterator; a bare `dict_values` view is
        # only an iterable and makes `iter(struct)` raise TypeError.
        return iter(self.__entries.values())

    def __str__(self):
        """Python scope struct array print support."""
        if impl.inside_kernel():
            item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.items])
            item_str += f", struct_methods={self.__methods}"
            return f"<ti.Struct {item_str}>"
        return str(self.to_dict())

    def __repr__(self):
        return str(self.to_dict())

    def to_dict(self, include_methods=False, include_ndim=False):
        """Converts the Struct to a dictionary.

        Args:
            include_methods (bool): Whether any struct methods should be included
                in the result dictionary under the key '__struct_methods'.
            include_ndim (bool): Whether the `ndim` of every matrix member
                should be included in the result dictionary under the key
                '__matrix_ndim'.

        Returns:
            Dict: The result dictionary.
        """
        res_dict = {
            k: (
                v.to_dict(include_methods=include_methods, include_ndim=include_ndim)
                if isinstance(v, Struct)
                else v.to_list() if isinstance(v, Matrix) else v
            )
            for k, v in self.__entries.items()
        }
        if include_methods:
            res_dict["__struct_methods"] = self.__methods
        if include_ndim:
            res_dict["__matrix_ndim"] = {k: v.ndim for k, v in self.__entries.items() if isinstance(v, Matrix)}
        return res_dict

    @classmethod
    @python_scope
    def field(
        cls,
        members,
        methods=None,
        shape=None,
        name="<Struct>",
        offset=None,
        needs_grad=False,
        needs_dual=False,
        layout=Layout.AOS,
    ):
        """Creates a :class:`~gstaichi.StructField` with each element
        has this struct as its type.

        Args:
            members (dict): a dict, each item is like `name: type`.
            methods (dict): a dict of methods that should be included with
                the field.  Each struct item of the field will have the
                methods as instance functions. Defaults to no methods.
            shape (Tuple[int]): width and height of the field.
            offset (Tuple[int]): offset of the indices of the created field.
                For example if `offset=(-10, -10)` the indices of the field
                will start at `(-10, -10)`, not `(0, 0)`.
            needs_grad (bool): enabling grad field (reverse mode autodiff) or not.
            needs_dual (bool): enabling dual field (forward mode autodiff) or not.
            layout: AOS or SOA.

        Raises:
            GsTaichiSyntaxError: If `offset` is given without `shape`, or if
                `shape` and `offset` have different dimensionality.

        Example:

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = {"center": vec3, "radius": float}
            >>> F = ti.Struct.field(sphere, shape=(3, 3))
            >>> F
            {'center': array([[[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]],

               [[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]],

               [[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]]], dtype=float32), 'radius': array([[0., 0., 0.],
               [0., 0., 0.],
               [0., 0., 0.]], dtype=float32)}
        """
        # A fresh dict per call avoids sharing one mutable default between
        # every StructField created without explicit methods.
        if methods is None:
            methods = {}

        if shape is None and offset is not None:
            raise GsTaichiSyntaxError("shape cannot be None when offset is being set")

        field_dict = {}

        for key, dtype in members.items():
            # Member fields are created unplaced (shape=None); they are
            # placed together below once the layout is known.
            common = dict(
                shape=None,
                name=name + "." + key,
                offset=offset,
                needs_grad=needs_grad,
                needs_dual=needs_dual,
            )
            if isinstance(dtype, StructType):
                field_dict[key] = dtype.field(**common)
            elif isinstance(dtype, CompoundType):
                field_dict[key] = dtype.field(ndim=getattr(dtype, "ndim", 2), **common)
            else:
                field_dict[key] = impl.field(dtype, **common)

        if shape is not None:
            if isinstance(shape, numbers.Number):
                shape = (shape,)
            if isinstance(offset, numbers.Number):
                offset = (offset,)

            if offset is not None and len(shape) != len(offset):
                raise GsTaichiSyntaxError(
                    f"The dimensionality of shape and offset must be the same ({len(shape)} != {len(offset)})"
                )
            dim = len(shape)
            if layout == Layout.SOA:
                # SOA: every member gets its own dense SNode.
                for e in field_dict.values():
                    impl.root.dense(impl.index_nd(dim), shape).place(e, offset=offset)
                if needs_grad:
                    for e in field_dict.values():
                        impl.root.dense(impl.index_nd(dim), shape).place(e.grad, offset=offset)
                if needs_dual:
                    for e in field_dict.values():
                        impl.root.dense(impl.index_nd(dim), shape).place(e.dual, offset=offset)
            else:
                # AOS: all members share a single dense SNode.
                impl.root.dense(impl.index_nd(dim), shape).place(*tuple(field_dict.values()), offset=offset)
                if needs_grad:
                    grads = tuple(e.grad for e in field_dict.values())
                    impl.root.dense(impl.index_nd(dim), shape).place(*grads, offset=offset)

                if needs_dual:
                    duals = tuple(e.dual for e in field_dict.values())
                    impl.root.dense(impl.index_nd(dim), shape).place(*duals, offset=offset)

        return StructField(field_dict, methods, name=name)
348
+
349
+
350
class _IntermediateStruct(Struct):
    """Intermediate struct class for compiler internal use only.

    Unlike :class:`Struct`, the entries are stored as given: lists and dicts
    are not converted and no expressions are initialized.

    Args:
        entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
            Any methods included under the key '__struct_methods' will be applied to each
            struct instance. The key is popped from the passed dictionary.
    """

    def __init__(self, entries):
        assert isinstance(entries, dict)
        # The private attributes of `Struct` are name-mangled, so they are
        # written here through their mangled `_Struct__*` names.
        self._Struct__methods = entries.pop("__struct_methods", {})
        self._register_methods()
        self._Struct__entries = entries
        self._register_members()
        # `Struct._assign` and `Struct._set_entries` read the source's
        # `__dtype`; initialise it so an intermediate struct can be used as
        # the source without raising AttributeError.
        self._Struct__dtype = None
365
+
366
+
367
class StructField(Field):
    """GsTaichi struct field with SNode implementation.

    Instead of directly constraining Expr entries, the StructField object
    directly hosts members as `Field` instances to support nested structs.

    Args:
        field_dict (Dict[str, Field]): Struct field members.
        struct_methods (Dict[str, callable]): Dictionary of functions to apply
            to each struct instance in the field.
        name (string, optional): The custom name of the field.
        is_primal (bool, optional): Whether this is a primal field. Primal
            fields also build their `grad` and `dual` companion struct fields.
    """

    def __init__(self, field_dict, struct_methods, name=None, is_primal=True):
        # will not call Field initializer
        self.field_dict = field_dict
        self.struct_methods = struct_methods
        self.name = name
        self.grad = None
        self.dual = None
        if is_primal:
            # `name` is optional; only derive companion names when one was
            # given, otherwise `None + ".grad"` would raise TypeError.
            grad_name = None if name is None else name + ".grad"
            dual_name = None if name is None else name + ".dual"

            grad_field_dict = {k: v.grad for k, v in self.field_dict.items()}
            self.grad = StructField(grad_field_dict, struct_methods, grad_name, is_primal=False)

            dual_field_dict = {k: v.dual for k, v in self.field_dict.items()}
            self.dual = StructField(dual_field_dict, struct_methods, dual_name, is_primal=False)
        self._register_fields()

    @property
    def keys(self):
        """Returns the list of names of the field members.

        Example::

            >>> f1 = ti.Vector.field(3, ti.f32, shape=(3, 3))
            >>> f2 = ti.field(ti.f32, shape=(3, 3))
            >>> F = ti.StructField({"center": f1, "radius": f2}, {})
            >>> F.keys
            ['center', 'radius']
        """
        return list(self.field_dict.keys())

    @property
    def _members(self):
        """Returns the member fields as a list, in insertion order."""
        return list(self.field_dict.values())

    @property
    def _items(self):
        """Returns the (name, member field) pairs of this struct field."""
        return self.field_dict.items()

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom struct by name."""
            return self.field_dict[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self.field_dict[key] = value

        return setter

    def _register_fields(self):
        """Expose every member field as an attribute named after its key."""
        for k in self.keys:
            setattr(self, k, self.field_dict[k])

    def _get_field_members(self):
        """Gets A flattened list of all struct elements.

        Returns:
            A list of struct elements.
        """
        field_members = []
        for m in self._members:
            assert isinstance(m, Field)
            field_members += m._get_field_members()
        return field_members

    @property
    def _snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return self._members[0]._snode

    def _loop_range(self):
        """Gets SNode of representative field member for loop range info.

        Returns:
            gstaichi_python.SNode: SNode of representative (first) field member.
        """
        return self._members[0]._loop_range()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another field.

        The shape of the other field needs to be the same as `self`.

        Args:
            other (Field): The source field. It must expose the same member
                keys as this field.
        """
        assert isinstance(other, Field)
        assert set(self.keys) == set(other.keys)
        for k in self.keys:
            self.field_dict[k].copy_from(other.get_member_field(k))

    @python_scope
    def fill(self, val):
        """Fills this struct field with a specified value.

        Args:
            val (Union[int, float]): Value to fill.
        """
        for v in self._members:
            v.fill(val)

    def _initialize_host_accessors(self):
        for v in self._members:
            v._initialize_host_accessors()

    def get_member_field(self, key):
        """Returns the member field stored under a specific key.

        Args:
            key (str): Specified key of the field member.

        Returns:
            Field: The member field registered under `key`.
        """
        return self.field_dict[key]

    @python_scope
    def from_numpy(self, array_dict):
        """Copies the data from a set of `numpy.array` into this field.

        The argument `array_dict` must be a dictionary-like object, it
        contains all the keys in this field and the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_numpy(array_dict[k])

    @python_scope
    def from_torch(self, array_dict):
        """Copies the data from a set of `torch.tensor` into this field.

        The argument `array_dict` must be a dictionary-like object, it
        contains all the keys in this field and the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_torch(array_dict[k])

    @python_scope
    def to_numpy(self):
        """Converts the Struct field instance to a dictionary of NumPy arrays.

        The dictionary may be nested when converting nested structs.

        Returns:
            Dict[str, Union[numpy.ndarray, Dict]]: The result NumPy array.
        """
        return {k: v.to_numpy() for k, v in self._items}

    @python_scope
    def to_torch(self, device=None):
        """Converts the Struct field instance to a dictionary of PyTorch tensors.

        The dictionary may be nested when converting nested structs.

        Args:
            device (torch.device, optional): The
                desired device of returned tensor.

        Returns:
            Dict[str, Union[torch.Tensor, Dict]]: The result
                PyTorch tensor.
        """
        return {k: v.to_torch(device=device) for k, v in self._items}

    @python_scope
    def __setitem__(self, indices, element):
        self._initialize_host_accessors()
        self[indices]._set_entries(element)

    @python_scope
    def __getitem__(self, indices):
        self._initialize_host_accessors()
        # scalar fields does not instantiate SNodeHostAccess by default
        entries = {
            k: v._host_access(self._pad_key(indices))[0] if isinstance(v, ScalarField) else v[indices]
            for k, v in self._items
        }
        entries["__struct_methods"] = self.struct_methods
        return Struct(entries)
573
+
574
+
575
+ class StructType(CompoundType):
576
+ def __init__(self, **kwargs):
577
+ self.members = {}
578
+ self.methods = {}
579
+ elements = []
580
+ for k, dtype in kwargs.items():
581
+ if k == "__struct_methods":
582
+ self.methods = dtype
583
+ elif isinstance(dtype, StructType):
584
+ self.members[k] = dtype
585
+ elements.append([dtype.dtype, k])
586
+ elif isinstance(dtype, MatrixType):
587
+ self.members[k] = dtype
588
+ elements.append([dtype.tensor_type, k])
589
+ else:
590
+ dtype = cook_dtype(dtype)
591
+ self.members[k] = dtype
592
+ elements.append([dtype, k])
593
+ self.dtype = _ti_core.get_type_factory_instance().get_struct_type(elements)
594
+
595
+ def __call__(self, *args, **kwargs):
596
+ """Create an instance of this struct type."""
597
+ d = {}
598
+ items = self.members.items()
599
+ # iterate over the members of this struct
600
+ for index, pair in enumerate(items):
601
+ name, dtype = pair # (member name, member type)
602
+ if index < len(args): # set from args
603
+ data = args[index]
604
+ else: # set from kwargs
605
+ data = kwargs.get(name, 0)
606
+
607
+ # If dtype is CompoundType and data is a scalar, it cannot be
608
+ # casted in the self.cast call later. We need an initialization here.
609
+ if isinstance(dtype, CompoundType) and not isinstance(data, (dict, Struct)):
610
+ data = dtype(data)
611
+
612
+ d[name] = data
613
+
614
+ entries = Struct(d)
615
+ entries._Struct__dtype = self.dtype
616
+ struct = self.cast(entries)
617
+ struct._Struct__dtype = self.dtype
618
+ return struct
619
+
620
def __instancecheck__(self, instance):
    """Return whether ``instance`` is a ``Struct`` compatible with this type.

    The check fails when ``instance`` is not a ``Struct``, when its entry
    names (in order) differ from this type's member names, when it carries a
    non-``None`` dtype different from ``self.dtype``, or when any member
    value is incompatible with the declared member type.
    """
    if not isinstance(instance, Struct):
        return False

    expected_names = list(self.members.keys())
    actual_names = list(instance._Struct__entries.keys())
    if expected_names != actual_names:
        return False

    # An instance without a recorded dtype (missing or None) is not rejected here.
    recorded_dtype = getattr(instance, "_Struct__dtype", None)
    if recorded_dtype is not None and recorded_dtype != self.dtype:
        return False

    for position, member_type in enumerate(self.members.values()):
        value = instance._members[position]

        if isinstance(member_type, StructType):
            # Nested struct: recurse through the nested type's own check.
            if not isinstance(value, member_type):
                return False
            continue

        if isinstance(member_type, MatrixType):
            # Only Expr values are inspected for matrix members.
            if isinstance(value, Expr):
                if not value.is_tensor() or value.get_shape() != member_type.get_shape():
                    return False
            continue

        if member_type in primitive_types.integer_types:
            type_family = primitive_types.integer_types
            python_kinds = (int, np.integer)
        elif member_type in primitive_types.real_types:
            type_family = primitive_types.real_types
            python_kinds = (float, np.floating)
        else:
            # Any other member type is accepted without inspection.
            continue

        if isinstance(value, Expr):
            # A scalar Expr must be neither tensor nor struct, and its
            # element type must belong to the same primitive family.
            if value.is_tensor() or value.is_struct() or value.element_type() not in type_family:
                return False
        elif not isinstance(value, python_kinds):
            return False

    return True
655
+
656
def from_gstaichi_object(self, func_ret, ret_index=()):
    """Build a struct of this type out of the elements of ``func_ret``.

    Args:
        func_ret: object whose ``ptr`` is handed to ``make_get_element_expr``.
        ret_index: tuple of indices locating this struct inside ``func_ret``;
            each member appends its own position to it.

    Returns:
        A ``Struct`` holding one entry per member, with this type's methods
        attached and its dtype set to ``self.dtype``.
    """
    members = {}
    for position, (member_name, member_type) in enumerate(self.members.items()):
        element_path = ret_index + (position,)
        if isinstance(member_type, CompoundType):
            # Compound members unpack themselves from the same return object.
            members[member_name] = member_type.from_gstaichi_object(func_ret, element_path)
        else:
            element = _ti_core.make_get_element_expr(
                func_ret.ptr,
                element_path,
                _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
            )
            members[member_name] = expr.Expr(element)
    members["__struct_methods"] = self.methods

    result = Struct(members)
    result._Struct__dtype = self.dtype
    return result
676
+
677
def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
    """Read a struct of this type back from a kernel launch context.

    Args:
        launch_ctx: context exposing ``get_struct_ret_int``,
            ``get_struct_ret_uint`` and ``get_struct_ret_float``.
        ret_index: tuple of indices locating this struct in the return value;
            each member appends its own position to it.

    Returns:
        A ``Struct`` with this type's methods attached and its dtype set to
        ``self.dtype``.

    Raises:
        GsTaichiRuntimeTypeError: if a non-compound member is neither an
            integer nor a real primitive type.
    """
    members = {}
    for position, (member_name, member_type) in enumerate(self.members.items()):
        element_path = ret_index + (position,)
        if isinstance(member_type, CompoundType):
            members[member_name] = member_type.from_kernel_struct_ret(launch_ctx, element_path)
        elif member_type in primitive_types.integer_types:
            # Signed and unsigned integers are fetched through different getters.
            if is_signed(cook_dtype(member_type)):
                members[member_name] = launch_ctx.get_struct_ret_int(element_path)
            else:
                members[member_name] = launch_ctx.get_struct_ret_uint(element_path)
        elif member_type in primitive_types.real_types:
            members[member_name] = launch_ctx.get_struct_ret_float(element_path)
        else:
            raise GsTaichiRuntimeTypeError(f"Invalid return type on index={element_path}")
    members["__struct_methods"] = self.methods

    result = Struct(members)
    result._Struct__dtype = self.dtype
    return result
699
+
700
def set_kernel_struct_args(self, struct, launch_ctx, ret_index=()):
    """Write the members of ``struct`` into a kernel launch context.

    Args:
        struct: value indexable by member name.
        launch_ctx: context exposing ``set_struct_arg_int``,
            ``set_struct_arg_uint`` and ``set_struct_arg_float``.
        ret_index: tuple of indices locating this struct in the argument;
            each member appends its own position to it.

    Raises:
        GsTaichiRuntimeTypeError: if a non-compound member is neither an
            integer nor a real primitive type.
    """
    # TODO: move this to class Struct after we add dtype to Struct
    for position, (member_name, member_type) in enumerate(self.members.items()):
        element_path = ret_index + (position,)

        if isinstance(member_type, CompoundType):
            # Compound members write their own sub-members recursively.
            member_type.set_kernel_struct_args(struct[member_name], launch_ctx, element_path)
            continue

        if member_type in primitive_types.integer_types:
            if is_signed(cook_dtype(member_type)):
                store = launch_ctx.set_struct_arg_int
            else:
                store = launch_ctx.set_struct_arg_uint
        elif member_type in primitive_types.real_types:
            store = launch_ctx.set_struct_arg_float
        else:
            raise GsTaichiRuntimeTypeError(f"Invalid argument type on index={element_path}")

        store(element_path, struct[member_name])
717
+
718
def cast(self, struct):
    """Convert ``struct`` into a new struct whose members match this type.

    Matrix members are rebuilt through their matrix type, other compound
    members are cast recursively, and scalar members are converted with
    ``int``/``float`` in Python scope or ``ops.cast`` otherwise.

    Raises:
        GsTaichiSyntaxError: if the entry names of ``struct`` differ from
            this type's member names.
    """
    source_entries = struct._Struct__entries
    # sanity check members
    if self.members.keys() != source_entries.keys():
        raise GsTaichiSyntaxError("Incompatible arguments for custom struct members!")

    converted = {}
    for member_name, member_type in self.members.items():
        raw_value = source_entries[member_name]
        if isinstance(member_type, MatrixType):
            converted[member_name] = member_type(raw_value)
        elif isinstance(member_type, CompoundType):
            converted[member_name] = member_type.cast(raw_value)
        elif in_python_scope():
            if member_type in primitive_types.integer_types:
                converted[member_name] = int(raw_value)
            else:
                converted[member_name] = float(raw_value)
        else:
            converted[member_name] = ops.cast(raw_value, member_type)
    converted["__struct_methods"] = self.methods

    result = Struct(converted)
    result._Struct__dtype = self.dtype
    return result
738
+
739
def filled_with_scalar(self, value):
    """Create a struct of this type with every member derived from ``value``.

    Matrix members are constructed from ``value``, other compound members are
    filled recursively, and scalar members receive ``value`` unchanged.
    """

    def _fill(member_type):
        # Pick the construction strategy matching the member's type.
        if isinstance(member_type, MatrixType):
            return member_type(value)
        if isinstance(member_type, CompoundType):
            return member_type.filled_with_scalar(value)
        return value

    filled = {member_name: _fill(member_type) for member_name, member_type in self.members.items()}
    filled["__struct_methods"] = self.methods

    result = Struct(filled)
    result._Struct__dtype = self.dtype
    return result
752
+
753
def field(self, **kwargs):
    """Create a field via ``Struct.field`` from this type's members and methods.

    All keyword arguments are forwarded to ``Struct.field`` unchanged.
    """
    members, methods = self.members, self.methods
    return Struct.field(members, methods, **kwargs)
755
+
756
def __str__(self):
    """Python scope struct type print support.

    Renders every member as ``name=type``, followed by the attached methods.
    """
    rendered_members = [f"{member_name!s}={member_type!s}" for member_name, member_type in self.members.items()]
    description = ", ".join(rendered_members) + f", struct_methods={self.methods}"
    return f"<ti.StructType {description}>"
761
+
762
+
763
def dataclass(cls):
    """Converts a class with field annotations and methods into a gstaichi struct type.

    This will return a normal custom struct type, with the functions added to it.
    Struct fields can be generated in the normal way from the struct type.
    Functions in the class can be run on the struct instance.

    This class decorator inspects the class for annotations and methods and
        1. Sets the annotations as fields for the struct
        2. Attaches the methods to the struct type

    The decorated class itself is left untouched; in particular its
    ``__annotations__`` mapping is not modified.

    Example::

        >>> @ti.dataclass
        >>> class Sphere:
        >>>     center: vec3
        >>>     radius: ti.f32
        >>>
        >>>     @ti.func
        >>>     def area(self):
        >>>         return 4 * 3.14 * self.radius * self.radius
        >>>
        >>> my_spheres = Sphere.field(shape=(n, ))
        >>> my_spheres[2].area()

    Args:
        cls (Class): the class with annotations and methods to convert to a struct

    Returns:
        A gstaichi struct with the annotations as fields
            and methods from the class attached.

    Raises:
        GsTaichiSyntaxError: if an annotated field also has a class attribute
            of the same name (i.e. a default value).
    """
    # Copy the annotations: the "__struct_methods" entry added below must not
    # leak into cls.__annotations__, which this function does not own.
    fields = dict(getattr(cls, "__annotations__", {}))
    # raise error if there are default values
    for k in fields:
        if hasattr(cls, k):
            raise GsTaichiSyntaxError("Default value in @dataclass is not supported.")
    # get the class methods to be attached to the struct types
    methods = {}
    for attribute in dir(cls):
        # Dunder attributes are never attached, so skip them before lookup.
        if attribute.startswith("__"):
            continue
        member = getattr(cls, attribute)
        if callable(member):
            methods[attribute] = member
    fields["__struct_methods"] = methods
    return StructType(**fields)
808
+
809
+
810
+ __all__ = ["Struct", "StructField", "dataclass"]