gstaichi 0.1.25.dev0__cp312-cp312-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. gstaichi/__init__.py +40 -0
  2. gstaichi/__main__.py +5 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-312-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2939 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +249 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_main.py +545 -0
  15. gstaichi/_snode/__init__.py +5 -0
  16. gstaichi/_snode/fields_builder.py +187 -0
  17. gstaichi/_snode/snode_tree.py +34 -0
  18. gstaichi/_test_tools/__init__.py +0 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_version.py +1 -0
  21. gstaichi/_version_check.py +103 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/minimal.py +28 -0
  30. gstaichi/experimental.py +16 -0
  31. gstaichi/lang/__init__.py +50 -0
  32. gstaichi/lang/_ndarray.py +352 -0
  33. gstaichi/lang/_ndrange.py +152 -0
  34. gstaichi/lang/_template_mapper.py +199 -0
  35. gstaichi/lang/_texture.py +172 -0
  36. gstaichi/lang/_wrap_inspect.py +189 -0
  37. gstaichi/lang/any_array.py +99 -0
  38. gstaichi/lang/argpack.py +411 -0
  39. gstaichi/lang/ast/__init__.py +5 -0
  40. gstaichi/lang/ast/ast_transformer.py +1318 -0
  41. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  42. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  43. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  44. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  45. gstaichi/lang/ast/checkers.py +106 -0
  46. gstaichi/lang/ast/symbol_resolver.py +57 -0
  47. gstaichi/lang/ast/transform.py +9 -0
  48. gstaichi/lang/common_ops.py +310 -0
  49. gstaichi/lang/exception.py +80 -0
  50. gstaichi/lang/expr.py +180 -0
  51. gstaichi/lang/field.py +466 -0
  52. gstaichi/lang/impl.py +1241 -0
  53. gstaichi/lang/kernel_arguments.py +157 -0
  54. gstaichi/lang/kernel_impl.py +1382 -0
  55. gstaichi/lang/matrix.py +1881 -0
  56. gstaichi/lang/matrix_ops.py +341 -0
  57. gstaichi/lang/matrix_ops_utils.py +190 -0
  58. gstaichi/lang/mesh.py +687 -0
  59. gstaichi/lang/misc.py +778 -0
  60. gstaichi/lang/ops.py +1494 -0
  61. gstaichi/lang/runtime_ops.py +13 -0
  62. gstaichi/lang/shell.py +35 -0
  63. gstaichi/lang/simt/__init__.py +5 -0
  64. gstaichi/lang/simt/block.py +94 -0
  65. gstaichi/lang/simt/grid.py +7 -0
  66. gstaichi/lang/simt/subgroup.py +191 -0
  67. gstaichi/lang/simt/warp.py +96 -0
  68. gstaichi/lang/snode.py +489 -0
  69. gstaichi/lang/source_builder.py +150 -0
  70. gstaichi/lang/struct.py +855 -0
  71. gstaichi/lang/util.py +381 -0
  72. gstaichi/linalg/__init__.py +8 -0
  73. gstaichi/linalg/matrixfree_cg.py +310 -0
  74. gstaichi/linalg/sparse_cg.py +59 -0
  75. gstaichi/linalg/sparse_matrix.py +303 -0
  76. gstaichi/linalg/sparse_solver.py +123 -0
  77. gstaichi/math/__init__.py +11 -0
  78. gstaichi/math/_complex.py +205 -0
  79. gstaichi/math/mathimpl.py +886 -0
  80. gstaichi/profiler/__init__.py +6 -0
  81. gstaichi/profiler/kernel_metrics.py +260 -0
  82. gstaichi/profiler/kernel_profiler.py +586 -0
  83. gstaichi/profiler/memory_profiler.py +15 -0
  84. gstaichi/profiler/scoped_profiler.py +36 -0
  85. gstaichi/sparse/__init__.py +3 -0
  86. gstaichi/sparse/_sparse_grid.py +77 -0
  87. gstaichi/tools/__init__.py +12 -0
  88. gstaichi/tools/diagnose.py +117 -0
  89. gstaichi/tools/np2ply.py +364 -0
  90. gstaichi/tools/vtk.py +38 -0
  91. gstaichi/types/__init__.py +19 -0
  92. gstaichi/types/annotations.py +47 -0
  93. gstaichi/types/compound_types.py +90 -0
  94. gstaichi/types/enums.py +49 -0
  95. gstaichi/types/ndarray_type.py +147 -0
  96. gstaichi/types/primitive_types.py +206 -0
  97. gstaichi/types/quant.py +88 -0
  98. gstaichi/types/texture_type.py +85 -0
  99. gstaichi/types/utils.py +13 -0
  100. gstaichi-0.1.25.dev0.data/data/include/GLFW/glfw3.h +6389 -0
  101. gstaichi-0.1.25.dev0.data/data/include/GLFW/glfw3native.h +594 -0
  102. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  103. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  104. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  105. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  106. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  107. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  108. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv.h +2568 -0
  109. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv.hpp +2579 -0
  110. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  111. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  112. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  113. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  114. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  115. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  116. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  117. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  118. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  119. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  120. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  124. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  125. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  133. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  134. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  135. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  136. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  137. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  138. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  139. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  140. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  141. gstaichi-0.1.25.dev0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  142. gstaichi-0.1.25.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  143. gstaichi-0.1.25.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  144. gstaichi-0.1.25.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  145. gstaichi-0.1.25.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  146. gstaichi-0.1.25.dev0.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  147. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  148. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  149. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  150. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  151. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  152. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  153. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  154. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  155. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  156. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  157. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  158. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  159. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  160. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  161. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  162. gstaichi-0.1.25.dev0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  163. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  164. gstaichi-0.1.25.dev0.dist-info/RECORD +168 -0
  165. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  166. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  167. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  168. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,855 @@
1
+ # type: ignore
2
+
3
+ import numbers
4
+ from types import MethodType
5
+
6
+ import numpy as np
7
+
8
+ from gstaichi._lib import core as _ti_core
9
+ from gstaichi.lang import expr, impl, ops
10
+ from gstaichi.lang.exception import (
11
+ GsTaichiRuntimeTypeError,
12
+ GsTaichiSyntaxError,
13
+ GsTaichiTypeError,
14
+ )
15
+ from gstaichi.lang.expr import Expr
16
+ from gstaichi.lang.field import Field, ScalarField, SNodeHostAccess
17
+ from gstaichi.lang.matrix import Matrix, MatrixType
18
+ from gstaichi.lang.util import cook_dtype, gstaichi_scope, in_python_scope, python_scope
19
+ from gstaichi.types import primitive_types
20
+ from gstaichi.types.compound_types import CompoundType
21
+ from gstaichi.types.enums import Layout
22
+ from gstaichi.types.utils import is_signed
23
+
24
+
25
class Struct:
    """The Struct type class.

    A struct is a dictionary-like data structure that stores members as
    (key, value) pairs. Valid data members of a struct can be scalars,
    matrices or other dictionary-like structures.

    Args:
        entries (Dict[str, Union[Dict, Expr, Matrix, Struct]]): keys and
            values for struct members, given either as a single dict or as
            keyword arguments. Entries can optionally include a dictionary
            of functions with the key '__struct_methods' which will be
            attached to the struct for executing on the struct data.

    Returns:
        An instance of this struct.

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)
        >>> a = ti.Struct(v=vec3([0, 0, 0]), t=1.0)
        >>> print(a.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0)])
        >>>
        >>> B = ti.Struct(v=vec3([0., 0., 0.]), t=1.0, A=a)
        >>> print(B.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0), ('A', {'v': [[0.], [0.], [0.]], 't': 1.0})])
    """

    _is_gstaichi_class = True
    # Counter used to give every per-instance subclass a distinct name
    # (see `_register_members`).
    _instance_count = 0

    def __init__(self, *args, **kwargs):
        # Accept exactly one positional dict, or keyword arguments only.
        # NOTE: the dict that is passed in is used directly (not copied); the
        # reserved keys are popped from it and its values are replaced below.
        if len(args) == 1 and not kwargs and isinstance(args[0], dict):
            self.__entries = args[0]
        elif len(args) == 0:
            self.__entries = kwargs
        else:
            raise GsTaichiSyntaxError(
                "Custom structs need to be initialized using either dictionary or keyword arguments"
            )
        self.__methods = self.__entries.pop("__struct_methods", {})
        # The ndim metadata written by `to_dict(include_ndim=True)` is not a
        # member: drop it so it does not become an entry.
        self.__entries.pop("__matrix_ndim", None)
        self._register_methods()

        # converts lists to matrices and dicts to structs
        for k, v in self.__entries.items():
            if isinstance(v, (list, tuple)):
                v = Matrix(v)
            if isinstance(v, dict):
                v = Struct(v)
            self.__entries[k] = v if in_python_scope() else impl.expr_init(v)
        self._register_members()
        self.__dtype = None

    @property
    def keys(self):
        """Returns the list of member names in string format.

        Example::

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = ti.Struct(center=vec3([0, 0, 0]), radius=1.0)
            >>> sphere.keys
            ['center', 'radius']
        """
        return list(self.__entries.keys())

    @property
    def _members(self):
        """Returns the member values as a list, in entry order."""
        return list(self.__entries.values())

    @property
    def entries(self):
        """Returns the underlying dict of members (not a copy)."""
        return self.__entries

    @property
    def methods(self):
        """Returns the dict of functions attached to this struct."""
        return self.__methods

    @property
    def items(self):
        """Returns the items in this struct.

        Example::

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = ti.Struct(center=vec3([0, 0, 0]), radius=1.0)
            >>> sphere.items
            dict_items([('center', [0. 0. 0.]), ('radius', 1.0)])
        """
        return self.__entries.items()

    def _register_members(self):
        """Exposes every entry as an attribute of this instance.

        Properties live on classes, so a fresh subclass carrying one property
        per key is created and this instance is re-classed to it.
        """
        # https://stackoverflow.com/questions/48448074/adding-a-property-to-an-existing-object-instance
        cls = self.__class__
        new_cls_name = cls.__name__ + str(cls._instance_count)
        cls._instance_count += 1
        properties = {k: property(cls._make_getter(k), cls._make_setter(k)) for k in self.keys}
        self.__class__ = type(new_cls_name, (cls,), properties)

    def _register_methods(self):
        """Binds every function in `methods` to this instance."""
        for name, method in self.__methods.items():
            # use MethodType to pass self (this object) to the method
            setattr(self, name, MethodType(method, self))

    def __getitem__(self, key):
        """Returns the member stored under `key`.

        A `SNodeHostAccess` entry is resolved through its accessor's getter
        instead of being returned as-is.
        """
        ret = self.__entries[key]
        if isinstance(ret, SNodeHostAccess):
            ret = ret.accessor.getter(*ret.key)
        return ret

    def __setitem__(self, key, value):
        """Assigns `value` to the member stored under `key`.

        Raises:
            TypeError: In Python scope, when the member is neither a Struct
                nor a Matrix and `value` is not a number.
        """
        current = self.__entries[key]
        if isinstance(current, SNodeHostAccess):
            # Field-backed member: write through the host accessor.
            current.accessor.setter(value, *current.key)
        elif not in_python_scope():
            self.__entries[key] = value
        elif isinstance(current, (Struct, Matrix)):
            # Compound members are updated in place.
            current._set_entries(value)
        elif isinstance(value, numbers.Number):
            self.__entries[key] = value
        else:
            raise TypeError("A number is expected when assigning struct members")

    def _set_entries(self, value):
        """Copies every member of `value` (a dict or Struct) into this struct."""
        if isinstance(value, dict):
            value = Struct(value)
        for k in self.keys:
            self[k] = value[k]
        self.__dtype = value.__dtype

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom struct by name."""
            return self[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self[key] = value

        return setter

    @gstaichi_scope
    def _assign(self, other):
        """Assigns the members of `other` to this struct, member by member.

        Args:
            other (Union[dict, Struct]): The source; a dict is converted to a
                Struct first.

        Returns:
            This struct.

        Raises:
            GsTaichiTypeError: If `other` is neither a dict nor a Struct, or
                if its member names differ from this struct's.
        """
        if not isinstance(other, (dict, Struct)):
            raise GsTaichiTypeError("Only dict or Struct can be assigned to a Struct")
        if isinstance(other, dict):
            other = Struct(other)
        if self.__entries.keys() != other.__entries.keys():
            raise GsTaichiTypeError(f"Member mismatch between structs {self.keys}, {other.keys}")
        for k, v in self.items:
            v._assign(other.__entries[k])
        self.__dtype = other.__dtype
        return self

    def __len__(self):
        """Get the number of entries in a custom struct"""
        return len(self.__entries)

    def __iter__(self):
        """Iterates over the member values, in entry order."""
        # `__iter__` must return an iterator; a bare `dict_values` view is only
        # an iterable and makes `iter(struct)` raise TypeError.
        return iter(self.__entries.values())

    def __str__(self):
        """Returns a printable form of this struct.

        Inside a kernel the members and the attached methods are formatted as
        ``<ti.Struct ...>``; otherwise the result of `to_dict()` is printed.
        """
        if impl.inside_kernel():
            item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.items])
            item_str += f", struct_methods={self.__methods}"
            return f"<ti.Struct {item_str}>"
        return str(self.to_dict())

    def __repr__(self):
        return str(self.to_dict())

    def to_dict(self, include_methods=False, include_ndim=False):
        """Converts the Struct to a dictionary.

        Nested structs become nested dicts and matrices become lists.

        Args:
            include_methods (bool): Whether any struct methods should be included
                in the result dictionary under the key '__struct_methods'.
            include_ndim (bool): Whether the `ndim` of every matrix member
                should be included in the result dictionary under the key
                '__matrix_ndim'.

        Returns:
            Dict: The result dictionary.
        """
        res_dict = {
            k: (
                v.to_dict(include_methods=include_methods, include_ndim=include_ndim)
                if isinstance(v, Struct)
                else v.to_list() if isinstance(v, Matrix) else v
            )
            for k, v in self.__entries.items()
        }
        if include_methods:
            res_dict["__struct_methods"] = self.__methods
        if include_ndim:
            res_dict["__matrix_ndim"] = {k: v.ndim for k, v in self.__entries.items() if isinstance(v, Matrix)}
        return res_dict

    @classmethod
    @python_scope
    def field(
        cls,
        members,
        methods=None,
        shape=None,
        name="<Struct>",
        offset=None,
        needs_grad=False,
        needs_dual=False,
        layout=Layout.AOS,
    ):
        """Creates a :class:`~gstaichi.StructField` with each element
        has this struct as its type.

        Args:
            members (dict): a dict, each item is like `name: type`.
            methods (dict): a dict of methods that should be included with
                the field. Each struct item of the field will have the
                methods as instance functions. Defaults to no methods.
            shape (Tuple[int]): shape of the field; a single number is
                treated as a one-dimensional shape.
            name (str): name of the field; member fields are named
                `<name>.<key>`.
            offset (Tuple[int]): offset of the indices of the created field.
                For example if `offset=(-10, -10)` the indices of the field
                will start at `(-10, -10)`, not `(0, 0)`.
            needs_grad (bool): enabling grad field (reverse mode autodiff) or not.
            needs_dual (bool): enabling dual field (forward mode autodiff) or not.
            layout: AOS or SOA.

        Raises:
            GsTaichiSyntaxError: If `offset` is given without `shape`, or if
                `shape` and `offset` have different dimensionality.

        Example:

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere = {"center": vec3, "radius": float}
            >>> F = ti.Struct.field(sphere, shape=(3, 3))
            >>> F
            {'center': array([[[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]],

               [[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]],

               [[0., 0., 0.],
                [0., 0., 0.],
                [0., 0., 0.]]], dtype=float32), 'radius': array([[0., 0., 0.],
               [0., 0., 0.],
               [0., 0., 0.]], dtype=float32)}
        """
        # A fresh dict per call: the value is stored on the returned
        # StructField, so a shared default would leak between fields.
        if methods is None:
            methods = {}

        if shape is None and offset is not None:
            raise GsTaichiSyntaxError("shape cannot be None when offset is being set")

        # Member fields are created unplaced (shape=None); placement happens
        # below once for the whole struct.
        field_dict = {}
        for key, dtype in members.items():
            field_name = name + "." + key
            if isinstance(dtype, CompoundType):
                if isinstance(dtype, StructType):
                    field_dict[key] = dtype.field(
                        shape=None,
                        name=field_name,
                        offset=offset,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                    )
                else:
                    field_dict[key] = dtype.field(
                        shape=None,
                        name=field_name,
                        offset=offset,
                        needs_grad=needs_grad,
                        needs_dual=needs_dual,
                        ndim=getattr(dtype, "ndim", 2),
                    )
            else:
                field_dict[key] = impl.field(
                    dtype,
                    shape=None,
                    name=field_name,
                    offset=offset,
                    needs_grad=needs_grad,
                    needs_dual=needs_dual,
                )

        if shape is not None:
            if isinstance(shape, numbers.Number):
                shape = (shape,)
            if isinstance(offset, numbers.Number):
                offset = (offset,)

            if offset is not None and len(shape) != len(offset):
                raise GsTaichiSyntaxError(
                    f"The dimensionality of shape and offset must be the same ({len(shape)} != {len(offset)})"
                )
            dim = len(shape)
            if layout == Layout.SOA:
                # SOA: every member gets a dense node of its own.
                for e in field_dict.values():
                    impl.root.dense(impl.index_nd(dim), shape).place(e, offset=offset)
                if needs_grad:
                    for e in field_dict.values():
                        impl.root.dense(impl.index_nd(dim), shape).place(e.grad, offset=offset)
                if needs_dual:
                    for e in field_dict.values():
                        impl.root.dense(impl.index_nd(dim), shape).place(e.dual, offset=offset)
            else:
                # AOS: all members share one dense node.
                impl.root.dense(impl.index_nd(dim), shape).place(*tuple(field_dict.values()), offset=offset)
                if needs_grad:
                    grads = tuple(e.grad for e in field_dict.values())
                    impl.root.dense(impl.index_nd(dim), shape).place(*grads, offset=offset)

                if needs_dual:
                    duals = tuple(e.dual for e in field_dict.values())
                    impl.root.dense(impl.index_nd(dim), shape).place(*duals, offset=offset)

        return StructField(field_dict, methods, name=name)
348
+
349
+
350
class _IntermediateStruct(Struct):
    """Intermediate struct class for compiler internal use only.

    Unlike `Struct.__init__`, the entries are stored as given: no list, tuple
    or dict conversion and no expression initialization is applied.

    Args:
        entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
            Any methods included under the key '__struct_methods' will be applied to each
            struct instance.
    """

    def __init__(self, entries):
        assert isinstance(entries, dict)
        # The private attributes of `Struct` are name-mangled, so they are
        # spelled out with the `_Struct` prefix here.
        self._Struct__methods = entries.pop("__struct_methods", {})
        self._register_methods()
        self._Struct__entries = entries
        self._register_members()
        # `Struct._set_entries` and `Struct._assign` read the source's dtype
        # unconditionally; without this attribute they raise AttributeError
        # when the source is an intermediate struct.
        self._Struct__dtype = None
365
+
366
+
367
class StructField(Field):
    """GsTaichi struct field with SNode implementation.

    Instead of directly constraining Expr entries, the StructField object
    directly hosts members as `Field` instances to support nested structs.

    Args:
        field_dict (Dict[str, Field]): Struct field members.
        struct_methods (Dict[str, callable]): Dictionary of functions to apply
            to each struct instance in the field.
        name (string, optional): The custom name of the field.
        is_primal (bool, optional): When True (the default), companion `grad`
            and `dual` struct fields are built from the members' `grad` and
            `dual` attributes.
    """

    def __init__(self, field_dict, struct_methods, name=None, is_primal=True):
        # will not call Field initializer
        self.field_dict = field_dict
        self.struct_methods = struct_methods
        self.name = name
        self.grad = None
        self.dual = None
        if is_primal:
            # `name` is optional: only derive companion names when one was
            # given, instead of failing on `None + ".grad"`.
            grad_name = None if name is None else name + ".grad"
            dual_name = None if name is None else name + ".dual"

            grad_field_dict = {k: v.grad for k, v in self.field_dict.items()}
            self.grad = StructField(grad_field_dict, struct_methods, grad_name, is_primal=False)

            dual_field_dict = {k: v.dual for k, v in self.field_dict.items()}
            self.dual = StructField(dual_field_dict, struct_methods, dual_name, is_primal=False)
        self._register_fields()

    @property
    def keys(self):
        """Returns the list of names of the field members.

        Example::

            >>> f1 = ti.Vector.field(3, ti.f32, shape=(3, 3))
            >>> f2 = ti.field(ti.f32, shape=(3, 3))
            >>> F = ti.StructField({"center": f1, "radius": f2}, {})
            >>> F.keys
            ['center', 'radius']
        """
        return list(self.field_dict.keys())

    @property
    def _members(self):
        """Returns the member fields as a list, in `field_dict` order."""
        return list(self.field_dict.values())

    @property
    def _items(self):
        """Returns the (name, member field) pairs of `field_dict`."""
        return self.field_dict.items()

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom struct by name."""
            return self.field_dict[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self.field_dict[key] = value

        return setter

    def _register_fields(self):
        """Exposes every member field as an attribute named after its key."""
        for k in self.keys:
            setattr(self, k, self.field_dict[k])

    def _get_field_members(self):
        """Gets a flattened list of all struct elements.

        Returns:
            A list of struct elements.
        """
        field_members = []
        for m in self._members:
            assert isinstance(m, Field)
            field_members += m._get_field_members()
        return field_members

    @property
    def _snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return self._members[0]._snode

    def _loop_range(self):
        """Gets SNode of representative field member for loop range info.

        Returns:
            gstaichi_python.SNode: SNode of representative (first) field member.
        """
        return self._members[0]._loop_range()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another field.

        The shape of the other field needs to be the same as `self`, and it
        must have the same member names.

        Args:
            other (Field): The source field.
        """
        assert isinstance(other, Field)
        assert set(self.keys) == set(other.keys)
        for k in self.keys:
            self.field_dict[k].copy_from(other.get_member_field(k))

    @python_scope
    def fill(self, val):
        """Fills this struct field with a specified value.

        Args:
            val (Union[int, float]): Value to fill.
        """
        for v in self._members:
            v.fill(val)

    def _initialize_host_accessors(self):
        """Initializes the host accessors of every member field."""
        for v in self._members:
            v._initialize_host_accessors()

    def get_member_field(self, key):
        """Returns the member field stored under a specific key.

        Args:
            key (str): Specified key of the field member.

        Returns:
            Field: The member field, exactly as stored in `field_dict`.
        """
        return self.field_dict[key]

    @python_scope
    def from_numpy(self, array_dict):
        """Copies the data from a set of `numpy.array` into this field.

        The argument `array_dict` must be a dictionary-like object, it
        contains all the keys in this field and the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_numpy(array_dict[k])

    @python_scope
    def from_torch(self, array_dict):
        """Copies the data from a set of `torch.tensor` into this field.

        The argument `array_dict` must be a dictionary-like object, it
        contains all the keys in this field and the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_torch(array_dict[k])

    @python_scope
    def from_paddle(self, array_dict):
        """Copies the data from a set of `paddle.Tensor` into this field.

        The argument `array_dict` must be a dictionary-like object, it
        contains all the keys in this field and the copying process
        between corresponding items can be performed.
        """
        for k, v in self._items:
            v.from_paddle(array_dict[k])

    @python_scope
    def to_numpy(self):
        """Converts the Struct field instance to a dictionary of NumPy arrays.

        The dictionary may be nested when converting nested structs.

        Returns:
            Dict[str, Union[numpy.ndarray, Dict]]: The result NumPy array.
        """
        return {k: v.to_numpy() for k, v in self._items}

    @python_scope
    def to_torch(self, device=None):
        """Converts the Struct field instance to a dictionary of PyTorch tensors.

        The dictionary may be nested when converting nested structs.

        Args:
            device (torch.device, optional): The
                desired device of returned tensor.

        Returns:
            Dict[str, Union[torch.Tensor, Dict]]: The result
                PyTorch tensor.
        """
        return {k: v.to_torch(device=device) for k, v in self._items}

    @python_scope
    def to_paddle(self, place=None):
        """Converts the Struct field instance to a dictionary of Paddle tensors.

        The dictionary may be nested when converting nested structs.

        Args:
            place (paddle.CPUPlace()/CUDAPlace(n), optional): The
                desired place of returned tensor.

        Returns:
            Dict[str, Union[paddle.Tensor, Dict]]: The result
                Paddle tensor.
        """
        return {k: v.to_paddle(place=place) for k, v in self._items}

    @python_scope
    def __setitem__(self, indices, element):
        """Writes `element` (a dict or Struct) into the struct at `indices`."""
        self._initialize_host_accessors()
        self[indices]._set_entries(element)

    @python_scope
    def __getitem__(self, indices):
        """Returns the element at `indices` as a `Struct` carrying the field's methods."""
        self._initialize_host_accessors()
        # scalar fields does not instantiate SNodeHostAccess by default
        entries = {
            k: v._host_access(self._pad_key(indices))[0] if isinstance(v, ScalarField) else v[indices]
            for k, v in self._items
        }
        entries["__struct_methods"] = self.struct_methods
        return Struct(entries)
600
+
601
+
602
+ class StructType(CompoundType):
603
def __init__(self, **kwargs):
    """Builds a struct type from `name=type` keyword arguments.

    The reserved key '__struct_methods' supplies the methods dict instead of
    a member. Every other key becomes a member: nested struct types and
    matrix types are stored as given, anything else goes through
    `cook_dtype` first.
    """
    self.members = {}
    self.methods = {}
    type_elements = []
    for member_name, member_type in kwargs.items():
        if member_name == "__struct_methods":
            self.methods = member_type
            continue
        if isinstance(member_type, StructType):
            element_type = member_type.dtype
        elif isinstance(member_type, MatrixType):
            element_type = member_type.tensor_type
        else:
            member_type = cook_dtype(member_type)
            element_type = member_type
        self.members[member_name] = member_type
        # The type factory takes [type, name] pairs.
        type_elements.append([element_type, member_name])
    type_factory = _ti_core.get_type_factory_instance()
    self.dtype = type_factory.get_struct_type(type_elements)
621
+
622
def __call__(self, *args, **kwargs):
    """Create an instance of this struct type.

    Members are filled from positional arguments first, in declaration
    order, then from keyword arguments by name; a member given neither way
    defaults to 0.
    """
    values = {}
    for position, (name, dtype) in enumerate(self.members.items()):
        data = args[position] if position < len(args) else kwargs.get(name, 0)

        # If dtype is CompoundType and data is a scalar, it cannot be
        # casted in the self.cast call later. We need an initialization here.
        if isinstance(dtype, CompoundType) and not isinstance(data, (dict, Struct)):
            data = dtype(data)

        values[name] = data

    raw = Struct(values)
    raw._Struct__dtype = self.dtype
    result = self.cast(raw)
    result._Struct__dtype = self.dtype
    return result
646
+
647
def __instancecheck__(self, instance):
    """Tells whether `instance` is a Struct matching this struct type.

    The instance must be a `Struct` with the same member names in the same
    order, must not carry a different recorded dtype, and every member must
    be compatible with the declared member type.
    """
    if not isinstance(instance, Struct):
        return False
    if list(self.members.keys()) != list(instance._Struct__entries.keys()):
        return False
    # Some instances never get the attribute; treat that like "no dtype".
    recorded_dtype = getattr(instance, "_Struct__dtype", None)
    if recorded_dtype is not None and recorded_dtype != self.dtype:
        return False

    # Key lists are equal, so declared types and values line up pairwise.
    for dtype, val in zip(self.members.values(), instance._members):
        val_is_expr = isinstance(val, Expr)
        if isinstance(dtype, StructType):
            if not isinstance(val, dtype):
                return False
        elif isinstance(dtype, MatrixType):
            # Only expression values are checked for matrix members.
            if val_is_expr:
                if not val.is_tensor():
                    return False
                if val.get_shape() != dtype.get_shape():
                    return False
        elif dtype in primitive_types.integer_types:
            if val_is_expr:
                if val.is_tensor() or val.is_struct() or val.element_type() not in primitive_types.integer_types:
                    return False
            elif not isinstance(val, (int, np.integer)):
                return False
        elif dtype in primitive_types.real_types:
            if val_is_expr:
                if val.is_tensor() or val.is_struct() or val.element_type() not in primitive_types.real_types:
                    return False
            elif not isinstance(val, (float, np.floating)):
                return False
    return True
682
+
683
+ def from_gstaichi_object(self, func_ret, ret_index=()):
684
+ d = {}
685
+ items = self.members.items()
686
+ for index, pair in enumerate(items):
687
+ name, dtype = pair
688
+ if isinstance(dtype, CompoundType):
689
+ d[name] = dtype.from_gstaichi_object(func_ret, ret_index + (index,))
690
+ else:
691
+ d[name] = expr.Expr(
692
+ _ti_core.make_get_element_expr(
693
+ func_ret.ptr,
694
+ ret_index + (index,),
695
+ _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
696
+ )
697
+ )
698
+ d["__struct_methods"] = self.methods
699
+
700
+ struct = Struct(d)
701
+ struct._Struct__dtype = self.dtype
702
+ return struct
703
+
704
+ def from_kernel_struct_ret(self, launch_ctx, ret_index=()):
705
+ d = {}
706
+ items = self.members.items()
707
+ for index, pair in enumerate(items):
708
+ name, dtype = pair
709
+ if isinstance(dtype, CompoundType):
710
+ d[name] = dtype.from_kernel_struct_ret(launch_ctx, ret_index + (index,))
711
+ else:
712
+ if dtype in primitive_types.integer_types:
713
+ if is_signed(cook_dtype(dtype)):
714
+ d[name] = launch_ctx.get_struct_ret_int(ret_index + (index,))
715
+ else:
716
+ d[name] = launch_ctx.get_struct_ret_uint(ret_index + (index,))
717
+ elif dtype in primitive_types.real_types:
718
+ d[name] = launch_ctx.get_struct_ret_float(ret_index + (index,))
719
+ else:
720
+ raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index + (index, )}")
721
+ d["__struct_methods"] = self.methods
722
+
723
+ struct = Struct(d)
724
+ struct._Struct__dtype = self.dtype
725
+ return struct
726
+
727
+ def set_kernel_struct_args(self, struct, launch_ctx, ret_index=()):
728
+ # TODO: move this to class Struct after we add dtype to Struct
729
+ items = self.members.items()
730
+ for index, pair in enumerate(items):
731
+ name, dtype = pair
732
+ if isinstance(dtype, CompoundType):
733
+ dtype.set_kernel_struct_args(struct[name], launch_ctx, ret_index + (index,))
734
+ else:
735
+ if dtype in primitive_types.integer_types:
736
+ if is_signed(cook_dtype(dtype)):
737
+ launch_ctx.set_struct_arg_int(ret_index + (index,), struct[name])
738
+ else:
739
+ launch_ctx.set_struct_arg_uint(ret_index + (index,), struct[name])
740
+ elif dtype in primitive_types.real_types:
741
+ launch_ctx.set_struct_arg_float(ret_index + (index,), struct[name])
742
+ else:
743
+ raise GsTaichiRuntimeTypeError(f"Invalid argument type on index={ret_index + (index, )}")
744
+
745
+ def set_argpack_struct_args(self, struct, argpack, ret_index=()):
746
+ # TODO: move this to class Struct after we add dtype to Struct
747
+ items = self.members.items()
748
+ for index, pair in enumerate(items):
749
+ name, dtype = pair
750
+ if isinstance(dtype, CompoundType):
751
+ dtype.set_kernel_struct_args(struct[name], argpack, ret_index + (index,))
752
+ else:
753
+ if dtype in primitive_types.integer_types:
754
+ if is_signed(cook_dtype(dtype)):
755
+ argpack.set_arg_int(ret_index + (index,), struct[name])
756
+ else:
757
+ argpack.set_arg_uint(ret_index + (index,), struct[name])
758
+ elif dtype in primitive_types.real_types:
759
+ argpack.set_arg_float(ret_index + (index,), struct[name])
760
+ else:
761
+ raise GsTaichiRuntimeTypeError(f"Invalid argument type on index={ret_index + (index, )}")
762
+
763
+ def cast(self, struct):
764
+ # sanity check members
765
+ if self.members.keys() != struct._Struct__entries.keys():
766
+ raise GsTaichiSyntaxError("Incompatible arguments for custom struct members!")
767
+ entries = {}
768
+ for k, dtype in self.members.items():
769
+ if isinstance(dtype, MatrixType):
770
+ entries[k] = dtype(struct._Struct__entries[k])
771
+ elif isinstance(dtype, CompoundType):
772
+ entries[k] = dtype.cast(struct._Struct__entries[k])
773
+ else:
774
+ if in_python_scope():
775
+ v = struct._Struct__entries[k]
776
+ entries[k] = int(v) if dtype in primitive_types.integer_types else float(v)
777
+ else:
778
+ entries[k] = ops.cast(struct._Struct__entries[k], dtype)
779
+ entries["__struct_methods"] = self.methods
780
+ struct = Struct(entries)
781
+ struct._Struct__dtype = self.dtype
782
+ return struct
783
+
784
+ def filled_with_scalar(self, value):
785
+ entries = {}
786
+ for k, dtype in self.members.items():
787
+ if isinstance(dtype, MatrixType):
788
+ entries[k] = dtype(value)
789
+ elif isinstance(dtype, CompoundType):
790
+ entries[k] = dtype.filled_with_scalar(value)
791
+ else:
792
+ entries[k] = value
793
+ entries["__struct_methods"] = self.methods
794
+ struct = Struct(entries)
795
+ struct._Struct__dtype = self.dtype
796
+ return struct
797
+
798
+ def field(self, **kwargs):
799
+ return Struct.field(self.members, self.methods, **kwargs)
800
+
801
+ def __str__(self):
802
+ """Python scope struct type print support."""
803
+ item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.members.items()])
804
+ item_str += f", struct_methods={self.methods}"
805
+ return f"<ti.StructType {item_str}>"
806
+
807
+
808
def dataclass(cls):
    """Converts a class with field annotations and methods into a gstaichi struct type.

    This will return a normal custom struct type, with the functions added to it.
    Struct fields can be generated in the normal way from the struct type.
    Functions in the class can be run on the struct instance.

    This class decorator inspects the class for annotations and methods and
        1. Sets the annotations as fields for the struct
        2. Attaches the methods to the struct type

    Example::

        >>> @ti.dataclass
        >>> class Sphere:
        >>>     center: vec3
        >>>     radius: ti.f32
        >>>
        >>>     @ti.func
        >>>     def area(self):
        >>>         return 4 * 3.14 * self.radius * self.radius
        >>>
        >>> my_spheres = Sphere.field(shape=(n, ))
        >>> my_spheres[2].area()

    Args:
        cls (Class): the class with annotations and methods to convert to a struct

    Returns:
        A gstaichi struct with the annotations as fields
            and methods from the class attached.

    Raises:
        GsTaichiSyntaxError: an annotated field also has a class attribute
            of the same name (i.e. a default value).
    """
    # Work on a copy of the annotations: the reserved "__struct_methods"
    # entry added below must not leak into ``cls.__annotations__``.
    fields = dict(getattr(cls, "__annotations__", {}))
    # raise error if there are default values
    for k in fields:
        if hasattr(cls, k):
            raise GsTaichiSyntaxError("Default value in @dataclass is not supported.")
    # get the class methods to be attached to the struct types; the name is
    # tested first so dunder attributes are never fetched.
    fields["__struct_methods"] = {
        attribute: getattr(cls, attribute)
        for attribute in dir(cls)
        if not attribute.startswith("__") and callable(getattr(cls, attribute))
    }
    return StructType(**fields)
853
+
854
+
855
# Names exported by a star-import of this module.
__all__ = [
    "Struct",
    "StructField",
    "dataclass",
]