gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/field.py ADDED
@@ -0,0 +1,428 @@
1
+ # type: ignore
2
+
3
+ import gstaichi.lang
4
+ from gstaichi._lib import core as _ti_core
5
+ from gstaichi._logging import warn
6
+ from gstaichi.lang import impl
7
+ from gstaichi.lang.exception import GsTaichiSyntaxError
8
+ from gstaichi.lang.util import (
9
+ in_python_scope,
10
+ python_scope,
11
+ to_numpy_type,
12
+ to_pytorch_type,
13
+ )
14
+
15
+
16
class Field:
    """GsTaichi field class.

    A field is constructed by a list of field members.
    For example, a scalar field has 1 field member, while a 3x3 matrix field has 9 field members.
    A field member is a Python Expr wrapping a C++ FieldExpression.

    Args:
        vars (List[Expr]): Field members.
    """

    def __init__(self, _vars):
        assert all(_vars)
        self.vars = _vars
        # Built lazily by `_initialize_host_accessors` on first host access.
        self.host_accessors = None
        # Companion fields; assigned later through `_set_grad` / `_set_dual`.
        self.grad = None
        self.dual = None

    @property
    def snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return self._snode

    @property
    def _snode(self):
        """Gets representative SNode for info purposes.

        A new `SNode` wrapper is constructed on every access.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return gstaichi.lang.snode.SNode(self.vars[0].ptr.snode())

    @property
    def shape(self):
        """Gets field shape.

        Returns:
            Tuple[Int]: Field shape.
        """
        return self._snode.shape

    @property
    def dtype(self):
        """Gets data type of each individual value.

        Returns:
            DataType: Data type of each individual value.
        """
        return self._snode._dtype

    @property
    def _name(self):
        """Gets field name.

        Returns:
            str: Field name.
        """
        return self._snode._name

    def parent(self, n=1):
        """Gets an ancestor of the representative SNode in the SNode tree.

        Args:
            n (int): the number of levels going up from the representative SNode.

        Returns:
            SNode: The n-th parent of the representative SNode.
        """
        return self.snode.parent(n)

    def _get_field_members(self):
        """Gets field members.

        Returns:
            List[Expr]: Field members.
        """
        return self.vars

    def _loop_range(self):
        """Gets SNode of representative field member for loop range info.

        Returns:
            gstaichi_python.SNode: SNode of representative (first) field member.
        """
        return self.vars[0].ptr.snode()

    def _set_grad(self, grad):
        """Sets corresponding grad field (reverse mode).

        Args:
            grad (Field): Corresponding grad field.
        """
        self.grad = grad

    def _set_dual(self, dual):
        """Sets corresponding dual field (forward mode).

        Args:
            dual (Field): Corresponding dual field.
        """
        self.dual = dual

    @python_scope
    def fill(self, val):
        """Fills `self` with a specific value.

        Args:
            val (Union[int, float]): Value to fill.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    @python_scope
    def to_numpy(self, dtype=None):
        """Converts `self` to a numpy array.

        Args:
            dtype (DataType, optional): The desired data type of returned numpy array.

        Returns:
            numpy.ndarray: The result numpy array.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    @python_scope
    def to_torch(self, device=None):
        """Converts `self` to a torch tensor.

        Args:
            device (torch.device, optional): The desired device of returned tensor.

        Returns:
            torch.tensor: The result torch tensor.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    @python_scope
    def from_numpy(self, arr):
        """Loads all elements from a numpy array.

        The shape of the numpy array needs to be the same as `self`.

        Args:
            arr (numpy.ndarray): The source numpy array.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    @python_scope
    def _from_external_arr(self, arr):
        """Loads all elements from an external array (subclass hook).

        Args:
            arr: The source external array.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    @python_scope
    def from_torch(self, arr):
        """Loads all elements from a torch tensor.

        The shape of the torch tensor needs to be the same as `self`.

        Args:
            arr (torch.tensor): The source torch tensor.
        """
        # The tensor is made contiguous before being handed to the loader.
        self._from_external_arr(arr.contiguous())

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another field.

        The shape of the other field needs to be the same as `self`.

        Args:
            other (Field): The source field.

        Raises:
            TypeError: If `other` is not a `Field`.
            ValueError: If the two field shapes differ.
        """
        if not isinstance(other, Field):
            raise TypeError("Cannot copy from a non-field object")
        if self.shape != other.shape:
            raise ValueError(f"ti.field shape {self.shape} does not match" f" the source field shape {other.shape}")
        from gstaichi._kernels import tensor_to_tensor  # pylint: disable=C0415

        tensor_to_tensor(self, other)

    @python_scope
    def __setitem__(self, key, value):
        """Sets field element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.
            value (element type): Value to set.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets field element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.

        Returns:
            element type: Value retrieved.

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()

    def __str__(self):
        """Returns a printable form of the field.

        Falls back to `__repr__` inside a kernel, reports an incomplete
        definition when the representative SNode has no `ptr`, and otherwise
        prints the field contents via `to_numpy()`.
        """
        if gstaichi.lang.impl.inside_kernel():
            return self.__repr__()  # make pybind11 happy, see Matrix.__str__
        if self._snode.ptr is None:
            return "<Field: Definition of this field is incomplete>"
        return str(self.to_numpy())

    def _pad_key(self, key):
        """Normalizes an element index into a fixed-length tuple.

        Args:
            key (Union[Tuple[int], List[int], int, None]): Coordinates of the
                field element. `None` is treated as the empty index and a
                non-sequence value as a one-element index.

        Returns:
            tuple: The coordinates followed by zeros, padded up to
            `_ti_core.get_max_num_indices()` entries.

        Raises:
            AssertionError: If the number of coordinates differs from the
                number of field dimensions.
        """
        if key is None:
            key = ()
        elif isinstance(key, (tuple, list)):
            # A list key cannot be concatenated with the zero-padding tuple
            # below (list + tuple raises TypeError), so normalize it first.
            key = tuple(key)
        else:
            key = (key,)

        if len(key) != len(self.shape):
            raise AssertionError("Slicing is not supported on ti.field")

        return key + ((0,) * (_ti_core.get_max_num_indices() - len(key)))

    def _initialize_host_accessors(self):
        """Creates one `SNodeHostAccessor` per field member, once.

        Does nothing when the accessors already exist; otherwise the runtime
        is materialized before the accessors are built.
        """
        if self.host_accessors:
            return
        gstaichi.lang.impl.get_runtime().materialize()
        self.host_accessors = [SNodeHostAccessor(e.ptr.snode()) for e in self.vars]

    def _host_access(self, key):
        """Pairs every host accessor with `key`.

        Reads `self.host_accessors`, so `_initialize_host_accessors` must
        have been called beforehand.

        Args:
            key: Index to bind to each accessor.

        Returns:
            List[SNodeHostAccess]: One access object per host accessor.
        """
        return [SNodeHostAccess(e, key) for e in self.host_accessors]

    def __iter__(self):
        """Rejects iteration in Python scope.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Struct for is only available in GsTaichi scope.")
248
+
249
+
250
class ScalarField(Field):
    """GsTaichi scalar field backed by a single SNode field member.

    Args:
        var (Expr): Field member.
    """

    def __init__(self, var):
        super().__init__([var])

    def fill(self, val):
        """Fills every element of this scalar field with `val`.

        Dispatches to a different helper depending on whether the call
        happens in Python scope or not.
        """
        if not in_python_scope():
            from gstaichi._funcs import (  # pylint: disable=C0415
                field_fill_gstaichi_scope,  # pylint: disable=C0415
            )

            field_fill_gstaichi_scope(self, val)
            return

        from gstaichi._kernels import fill_field  # pylint: disable=C0415

        fill_field(self, val)

    @python_scope
    def to_numpy(self, dtype=None):
        """Returns the field contents as a freshly allocated `numpy.ndarray`.

        Args:
            dtype (DataType, optional): Element type of the returned array;
                derived from the field's own dtype when omitted.
        """
        parent_type = self.parent()._snode.ptr.type
        if parent_type == _ti_core.SNodeType.dynamic:
            warn(
                "You are trying to convert a dynamic snode to a numpy array, "
                "be aware that inactive items in the snode will be converted "
                "to zeros in the resulting array."
            )
        target_dtype = to_numpy_type(self.dtype) if dtype is None else dtype
        import numpy as np  # pylint: disable=C0415

        result = np.zeros(shape=self.shape, dtype=target_dtype)
        from gstaichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, result)
        gstaichi.lang.runtime_ops.sync()
        return result

    @python_scope
    def to_torch(self, device=None):
        """Returns the field contents as a freshly allocated `torch.tensor`.

        Args:
            device (torch.device, optional): Device of the returned tensor.
        """
        import torch  # pylint: disable=C0415

        torch_dtype = to_pytorch_type(self.dtype)
        # pylint: disable=E1101
        result = torch.zeros(size=self.shape, dtype=torch_dtype, device=device)
        from gstaichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, result)
        gstaichi.lang.runtime_ops.sync()
        return result

    @python_scope
    def _from_external_arr(self, arr):
        """Copies `arr` into this field after checking the shapes agree.

        Raises:
            ValueError: If rank or any extent of `arr` differs from the field.
        """
        field_shape = self.shape
        source_shape = arr.shape
        same_rank = len(field_shape) == len(source_shape)
        # Extents are only compared once the ranks are known to be equal.
        if not same_rank or any(mine != theirs for mine, theirs in zip(field_shape, source_shape)):
            raise ValueError(f"ti.field shape {field_shape} does not match the numpy array shape {source_shape}")
        from gstaichi._kernels import ext_arr_to_tensor  # pylint: disable=C0415

        ext_arr_to_tensor(arr, self)
        gstaichi.lang.runtime_ops.sync()

    @python_scope
    def from_numpy(self, arr):
        """Copies the data from a `numpy.ndarray` into this field.

        Arrays that are not C-contiguous are converted to a contiguous
        array first.
        """
        if arr.flags.c_contiguous:
            self._from_external_arr(arr)
            return
        import numpy as np  # pylint: disable=C0415

        self._from_external_arr(np.ascontiguousarray(arr))

    @python_scope
    def __setitem__(self, key, value):
        """Writes `value` to the element addressed by `key` (Python scope)."""
        self._initialize_host_accessors()
        accessor = self.host_accessors[0]
        accessor.setter(value, *self._pad_key(key))

    @python_scope
    def __getitem__(self, key):
        """Reads the element addressed by `key` (Python scope).

        Raises:
            TypeError: If any coordinate is not an integer, e.g. `x[0, :]`.
        """
        self._initialize_host_accessors()
        padded_key = self._pad_key(key)
        import numpy as np  # pylint: disable=C0415

        # Non-integer coordinates indicate an attempted slice.
        for coordinate in padded_key:
            if isinstance(coordinate, (int, np.integer)):
                continue
            raise TypeError(
                f"Detected illegal element of type: {type(coordinate)}. "
                "Please be aware that slicing a ti.field is not supported so far."
            )
        return self.host_accessors[0].getter(*padded_key)

    def __repr__(self):
        # make interactive shell happy, prevent materialization
        return "<ti.field>"
349
+
350
+
351
class SNodeHostAccessor:
    """Exposes `getter` / `setter` callables for host-side access to one SNode.

    The read and write primitives are selected once, at construction time,
    from the SNode's data type.

    Args:
        snode: SNode handle providing `data_type()` and the
            `read_*` / `write_*` methods used below.
    """

    def __init__(self, snode):
        if not _ti_core.is_real(snode.data_type()):

            def write_func(key, value):
                # Non-negative values take the unsigned writer, negative
                # values the signed one.
                writer = snode.write_uint if value >= 0 else snode.write_int
                writer(key, value)

            read_func = snode.read_int if _ti_core.is_signed(snode.data_type()) else snode.read_uint
        else:
            write_func = snode.write_float
            read_func = snode.read_float

        def getter(*key):
            assert len(key) == _ti_core.get_max_num_indices()
            return read_func(key)

        def setter(value, *key):
            assert len(key) == _ti_core.get_max_num_indices()
            write_func(key, value)

            runtime = impl.get_runtime()
            tape = runtime.target_tape
            # Nothing more to do unless a tape with a grad checker is active
            # and the gradient has not been replaced.
            if not (tape and tape.grad_checker) or runtime.grad_replaced:
                return
            for checked in tape.grad_checker.to_check:
                assert snode != checked.snode.ptr, "Overwritten is prohibitive when doing grad check."
            # Record the write on the tape together with its arguments.
            tape.insert(write_func, (key, value))

        self.getter = getter
        self.setter = setter
388
+
389
+
390
class SNodeHostAccess:
    """Pairs a host accessor with the key it should be applied to.

    Args:
        accessor: The accessor object to use.
        key: The index bound to that accessor.
    """

    def __init__(self, accessor, key):
        self.accessor, self.key = accessor, key
394
+
395
+
396
class BitpackedFields:
    """GsTaichi bitpacked fields, where fields with quantized types are packed together.

    Args:
        max_num_bits (int): Maximum number of bits all fields inside can occupy in total. Only 32 or 64 is allowed.
    """

    def __init__(self, max_num_bits):
        # Collected (member pointer, builder result) pairs, in placement order.
        self.fields = []
        self.bit_struct_type_builder = _ti_core.BitStructTypeBuilder(max_num_bits)

    def place(self, *args, shared_exponent=False):
        """Places a list of fields with quantized types inside.

        Args:
            *args (List[Field]): A list of fields with quantized types to place.
            shared_exponent (bool): Whether the fields have a shared exponent.

        Raises:
            GsTaichiSyntaxError: If `shared_exponent` is set and fewer than
                two field members were placed.
        """
        builder = self.bit_struct_type_builder
        if shared_exponent:
            builder.begin_placing_shared_exponent()

        placed = 0
        for field in args:
            assert isinstance(field, Field)
            for member in field._get_field_members():
                member_ptr = member.ptr
                entry = (member_ptr, builder.add_member(member_ptr.get_dt()))
                self.fields.append(entry)
                placed += 1

        if not shared_exponent:
            return
        builder.end_placing_shared_exponent()
        if placed <= 1:
            raise GsTaichiSyntaxError("At least 2 fields need to be placed when shared_exponent=True")
426
+
427
+
428
# Names exported by `from gstaichi.lang.field import *`.
__all__ = [
    "BitpackedFields",
    "Field",
    "ScalarField",
]