gstaichi 0.1.18.dev1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  2. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  3. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  4. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  5. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  6. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  7. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  8. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  9. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  10. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  11. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  12. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  13. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  14. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  15. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  16. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  17. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  18. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  19. gstaichi-0.1.18.dev1.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  20. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  21. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  25. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  26. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-link.lib +0 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools.lib +0 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/glfw3.lib +0 -0
  39. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  40. gstaichi-0.1.18.dev1.dist-info/RECORD +198 -0
  41. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  42. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  43. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  44. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  45. taichi/CHANGELOG.md +15 -0
  46. taichi/__init__.py +44 -0
  47. taichi/__main__.py +5 -0
  48. taichi/_funcs.py +706 -0
  49. taichi/_kernels.py +420 -0
  50. taichi/_lib/__init__.py +3 -0
  51. taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
  52. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  53. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  54. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  55. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  56. taichi/_lib/c_api/include/taichi/taichi_cuda.h +36 -0
  57. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  58. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  59. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  60. taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
  61. taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
  62. taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
  63. taichi/_lib/c_api/runtime/slim_libdevice.10.bc +0 -0
  64. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  65. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  66. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  67. taichi/_lib/core/__init__.py +0 -0
  68. taichi/_lib/core/py.typed +0 -0
  69. taichi/_lib/core/taichi_python.cp310-win_amd64.pyd +0 -0
  70. taichi/_lib/core/taichi_python.pyi +3077 -0
  71. taichi/_lib/runtime/runtime_cuda.bc +0 -0
  72. taichi/_lib/runtime/runtime_x64.bc +0 -0
  73. taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  74. taichi/_lib/utils.py +249 -0
  75. taichi/_logging.py +131 -0
  76. taichi/_main.py +552 -0
  77. taichi/_snode/__init__.py +5 -0
  78. taichi/_snode/fields_builder.py +189 -0
  79. taichi/_snode/snode_tree.py +34 -0
  80. taichi/_ti_module/__init__.py +3 -0
  81. taichi/_ti_module/cppgen.py +309 -0
  82. taichi/_ti_module/module.py +145 -0
  83. taichi/_version.py +1 -0
  84. taichi/_version_check.py +100 -0
  85. taichi/ad/__init__.py +3 -0
  86. taichi/ad/_ad.py +530 -0
  87. taichi/algorithms/__init__.py +3 -0
  88. taichi/algorithms/_algorithms.py +117 -0
  89. taichi/aot/__init__.py +12 -0
  90. taichi/aot/_export.py +28 -0
  91. taichi/aot/conventions/__init__.py +3 -0
  92. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  93. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  94. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  95. taichi/aot/module.py +253 -0
  96. taichi/aot/utils.py +151 -0
  97. taichi/assets/.git +1 -0
  98. taichi/assets/Go-Regular.ttf +0 -0
  99. taichi/assets/static/imgs/ti_gallery.png +0 -0
  100. taichi/examples/minimal.py +28 -0
  101. taichi/experimental.py +16 -0
  102. taichi/graph/__init__.py +3 -0
  103. taichi/graph/_graph.py +292 -0
  104. taichi/lang/__init__.py +50 -0
  105. taichi/lang/_ndarray.py +348 -0
  106. taichi/lang/_ndrange.py +152 -0
  107. taichi/lang/_texture.py +172 -0
  108. taichi/lang/_wrap_inspect.py +189 -0
  109. taichi/lang/any_array.py +99 -0
  110. taichi/lang/argpack.py +411 -0
  111. taichi/lang/ast/__init__.py +5 -0
  112. taichi/lang/ast/ast_transformer.py +1806 -0
  113. taichi/lang/ast/ast_transformer_utils.py +328 -0
  114. taichi/lang/ast/checkers.py +106 -0
  115. taichi/lang/ast/symbol_resolver.py +57 -0
  116. taichi/lang/ast/transform.py +9 -0
  117. taichi/lang/common_ops.py +310 -0
  118. taichi/lang/exception.py +80 -0
  119. taichi/lang/expr.py +180 -0
  120. taichi/lang/field.py +464 -0
  121. taichi/lang/impl.py +1246 -0
  122. taichi/lang/kernel_arguments.py +157 -0
  123. taichi/lang/kernel_impl.py +1415 -0
  124. taichi/lang/matrix.py +1877 -0
  125. taichi/lang/matrix_ops.py +341 -0
  126. taichi/lang/matrix_ops_utils.py +190 -0
  127. taichi/lang/mesh.py +687 -0
  128. taichi/lang/misc.py +807 -0
  129. taichi/lang/ops.py +1489 -0
  130. taichi/lang/runtime_ops.py +13 -0
  131. taichi/lang/shell.py +35 -0
  132. taichi/lang/simt/__init__.py +5 -0
  133. taichi/lang/simt/block.py +94 -0
  134. taichi/lang/simt/grid.py +7 -0
  135. taichi/lang/simt/subgroup.py +191 -0
  136. taichi/lang/simt/warp.py +96 -0
  137. taichi/lang/snode.py +487 -0
  138. taichi/lang/source_builder.py +150 -0
  139. taichi/lang/struct.py +855 -0
  140. taichi/lang/util.py +381 -0
  141. taichi/linalg/__init__.py +8 -0
  142. taichi/linalg/matrixfree_cg.py +310 -0
  143. taichi/linalg/sparse_cg.py +59 -0
  144. taichi/linalg/sparse_matrix.py +303 -0
  145. taichi/linalg/sparse_solver.py +123 -0
  146. taichi/math/__init__.py +11 -0
  147. taichi/math/_complex.py +204 -0
  148. taichi/math/mathimpl.py +886 -0
  149. taichi/profiler/__init__.py +6 -0
  150. taichi/profiler/kernel_metrics.py +260 -0
  151. taichi/profiler/kernel_profiler.py +592 -0
  152. taichi/profiler/memory_profiler.py +15 -0
  153. taichi/profiler/scoped_profiler.py +36 -0
  154. taichi/shaders/Circles_vk.frag +29 -0
  155. taichi/shaders/Circles_vk.vert +45 -0
  156. taichi/shaders/Circles_vk_frag.spv +0 -0
  157. taichi/shaders/Circles_vk_vert.spv +0 -0
  158. taichi/shaders/Lines_vk.frag +9 -0
  159. taichi/shaders/Lines_vk.vert +11 -0
  160. taichi/shaders/Lines_vk_frag.spv +0 -0
  161. taichi/shaders/Lines_vk_vert.spv +0 -0
  162. taichi/shaders/Mesh_vk.frag +71 -0
  163. taichi/shaders/Mesh_vk.vert +68 -0
  164. taichi/shaders/Mesh_vk_frag.spv +0 -0
  165. taichi/shaders/Mesh_vk_vert.spv +0 -0
  166. taichi/shaders/Particles_vk.frag +95 -0
  167. taichi/shaders/Particles_vk.vert +73 -0
  168. taichi/shaders/Particles_vk_frag.spv +0 -0
  169. taichi/shaders/Particles_vk_vert.spv +0 -0
  170. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  171. taichi/shaders/SceneLines_vk.frag +9 -0
  172. taichi/shaders/SceneLines_vk.vert +12 -0
  173. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  174. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  175. taichi/shaders/SetImage_vk.frag +21 -0
  176. taichi/shaders/SetImage_vk.vert +15 -0
  177. taichi/shaders/SetImage_vk_frag.spv +0 -0
  178. taichi/shaders/SetImage_vk_vert.spv +0 -0
  179. taichi/shaders/Triangles_vk.frag +16 -0
  180. taichi/shaders/Triangles_vk.vert +29 -0
  181. taichi/shaders/Triangles_vk_frag.spv +0 -0
  182. taichi/shaders/Triangles_vk_vert.spv +0 -0
  183. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  184. taichi/sparse/__init__.py +3 -0
  185. taichi/sparse/_sparse_grid.py +77 -0
  186. taichi/tools/__init__.py +12 -0
  187. taichi/tools/diagnose.py +124 -0
  188. taichi/tools/np2ply.py +364 -0
  189. taichi/tools/vtk.py +38 -0
  190. taichi/types/__init__.py +19 -0
  191. taichi/types/annotations.py +47 -0
  192. taichi/types/compound_types.py +90 -0
  193. taichi/types/enums.py +49 -0
  194. taichi/types/ndarray_type.py +147 -0
  195. taichi/types/primitive_types.py +203 -0
  196. taichi/types/quant.py +88 -0
  197. taichi/types/texture_type.py +85 -0
  198. taichi/types/utils.py +13 -0
taichi/lang/field.py ADDED
@@ -0,0 +1,464 @@
1
+ # type: ignore
2
+
3
+ import taichi.lang
4
+ from taichi._lib import core as _ti_core
5
+ from taichi._logging import warn
6
+ from taichi.lang import impl
7
+ from taichi.lang.exception import TaichiSyntaxError
8
+ from taichi.lang.util import (
9
+ in_python_scope,
10
+ python_scope,
11
+ to_numpy_type,
12
+ to_paddle_type,
13
+ to_pytorch_type,
14
+ )
15
+
16
+
17
class Field:
    """Taichi field class.

    A field is constructed by a list of field members.
    For example, a scalar field has 1 field member, while a 3x3 matrix field has 9 field members.
    A field member is a Python Expr wrapping a C++ FieldExpression.

    Args:
        _vars (List[Expr]): Field members.
    """

    def __init__(self, _vars):
        assert all(_vars)
        self.vars = _vars
        # Lazily populated by `_initialize_host_accessors` on first Python-scope access.
        self.host_accessors = None
        self.grad = None
        self.dual = None

    @property
    def snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return self._snode

    @property
    def _snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return taichi.lang.snode.SNode(self.vars[0].ptr.snode())

    @property
    def shape(self):
        """Gets field shape.

        Returns:
            Tuple[Int]: Field shape.
        """
        return self._snode.shape

    @property
    def dtype(self):
        """Gets data type of each individual value.

        Returns:
            DataType: Data type of each individual value.
        """
        return self._snode._dtype

    @property
    def _name(self):
        """Gets field name.

        Returns:
            str: Field name.
        """
        return self._snode._name

    def parent(self, n=1):
        """Gets an ancestor of the representative SNode in the SNode tree.

        Args:
            n (int): the number of levels going up from the representative SNode.

        Returns:
            SNode: The n-th parent of the representative SNode.
        """
        return self.snode.parent(n)

    def _get_field_members(self):
        """Gets field members.

        Returns:
            List[Expr]: Field members.
        """
        return self.vars

    def _loop_range(self):
        """Gets SNode of representative field member for loop range info.

        Returns:
            taichi_python.SNode: SNode of representative (first) field member.
        """
        return self.vars[0].ptr.snode()

    def _set_grad(self, grad):
        """Sets corresponding grad field (reverse mode).

        Args:
            grad (Field): Corresponding grad field.
        """
        self.grad = grad

    def _set_dual(self, dual):
        """Sets corresponding dual field (forward mode).

        Args:
            dual (Field): Corresponding dual field.
        """
        self.dual = dual

    @python_scope
    def fill(self, val):
        """Fills `self` with a specific value.

        Args:
            val (Union[int, float]): Value to fill.
        """
        raise NotImplementedError()

    @python_scope
    def to_numpy(self, dtype=None):
        """Converts `self` to a numpy array.

        Args:
            dtype (DataType, optional): The desired data type of returned numpy array.

        Returns:
            numpy.ndarray: The result numpy array.
        """
        raise NotImplementedError()

    @python_scope
    def to_torch(self, device=None):
        """Converts `self` to a torch tensor.

        Args:
            device (torch.device, optional): The desired device of returned tensor.

        Returns:
            torch.tensor: The result torch tensor.
        """
        raise NotImplementedError()

    @python_scope
    def to_paddle(self, place=None):
        """Converts `self` to a paddle tensor.

        Args:
            place (paddle.CPUPlace()/CUDAPlace(n), optional): The desired place of returned tensor.

        Returns:
            paddle.Tensor: The result paddle tensor.
        """
        raise NotImplementedError()

    @python_scope
    def from_numpy(self, arr):
        """Loads all elements from a numpy array.

        The shape of the numpy array needs to be the same as `self`.

        Args:
            arr (numpy.ndarray): The source numpy array.
        """
        raise NotImplementedError()

    @python_scope
    def _from_external_arr(self, arr):
        """Loads all elements from an external array; implemented by subclasses."""
        raise NotImplementedError()

    @python_scope
    def from_torch(self, arr):
        """Loads all elements from a torch tensor.

        The shape of the torch tensor needs to be the same as `self`.

        Args:
            arr (torch.tensor): The source torch tensor.
        """
        # Make sure the data handed to the copy kernel is contiguous.
        self._from_external_arr(arr.contiguous())

    @python_scope
    def from_paddle(self, arr):
        """Loads all elements from a paddle tensor.

        The shape of the paddle tensor needs to be the same as `self`.

        Args:
            arr (paddle.Tensor): The source paddle tensor.
        """
        self.from_numpy(arr)

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another field.

        The shape of the other field needs to be the same as `self`.

        Args:
            other (Field): The source field.

        Raises:
            TypeError: If `other` is not a `Field`.
            ValueError: If the shapes of the two fields differ.
        """
        if not isinstance(other, Field):
            raise TypeError("Cannot copy from a non-field object")
        if self.shape != other.shape:
            raise ValueError(f"ti.field shape {self.shape} does not match" f" the source field shape {other.shape}")
        from taichi._kernels import tensor_to_tensor  # pylint: disable=C0415

        tensor_to_tensor(self, other)

    @python_scope
    def __setitem__(self, key, value):
        """Sets field element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.
            value (element type): Value to set.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets field element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.

        Returns:
            element type: Value retrieved.
        """
        raise NotImplementedError()

    def __str__(self):
        if taichi.lang.impl.inside_kernel():
            return self.__repr__()  # make pybind11 happy, see Matrix.__str__
        if self._snode.ptr is None:
            return "<Field: Definition of this field is incomplete>"
        return str(self.to_numpy())

    def _pad_key(self, key):
        """Normalizes a Python-scope index into a full-length coordinate tuple.

        Args:
            key (Union[List[int], Tuple[int, ...], int, None]): Coordinates of a
                field element. ``None`` means "no indices"; any value that is not
                a tuple or list is treated as a single index.

        Returns:
            tuple: ``key`` as a tuple, right-padded with ``0`` up to
            ``_ti_core.get_max_num_indices()`` entries.

        Raises:
            AssertionError: If the number of indices differs from the number of
                dimensions of the field.
        """
        if key is None:
            key = ()
        if isinstance(key, (tuple, list)):
            # Lists are accepted as keys; convert so the tuple padding below works
            # (``list + tuple`` would raise TypeError).
            key = tuple(key)
        else:
            key = (key,)

        if len(key) != len(self.shape):
            raise AssertionError("Slicing is not supported on ti.field")

        return key + (0,) * (_ti_core.get_max_num_indices() - len(key))

    def _initialize_host_accessors(self):
        """Materializes the runtime and builds one host accessor per field member.

        Does nothing once the accessors have been built.
        """
        if self.host_accessors:
            return
        taichi.lang.impl.get_runtime().materialize()
        self.host_accessors = [SNodeHostAccessor(e.ptr.snode()) for e in self.vars]

    def _host_access(self, key):
        """Pairs every host accessor with `key`.

        Returns:
            List[SNodeHostAccess]: One entry per host accessor.
        """
        return [SNodeHostAccess(e, key) for e in self.host_accessors]

    def __iter__(self):
        raise NotImplementedError("Struct for is only available in Taichi scope.")
272
+
273
+
274
class ScalarField(Field):
    """Taichi scalar field with SNode implementation.

    Args:
        var (Expr): Field member.
    """

    def __init__(self, var):
        super().__init__([var])

    def fill(self, val):
        """Fills this scalar field with a specified value.

        Works in both scopes: in Python scope it calls `taichi._kernels.fill_field`,
        otherwise `taichi._funcs.field_fill_taichi_scope`.
        """
        if in_python_scope():
            from taichi._kernels import fill_field  # pylint: disable=C0415

            fill_field(self, val)
        else:
            from taichi._funcs import field_fill_taichi_scope  # pylint: disable=C0415

            field_fill_taichi_scope(self, val)

    @python_scope
    def to_numpy(self, dtype=None):
        """Converts this field to a `numpy.ndarray`.

        Args:
            dtype (optional): numpy dtype of the result; defaults to the numpy
                equivalent of the field's dtype.
        """
        if self.parent()._snode.ptr.type == _ti_core.SNodeType.dynamic:
            warn(
                "You are trying to convert a dynamic snode to a numpy array, be aware that inactive items in the snode will be converted to zeros in the resulting array."
            )
        if dtype is None:
            dtype = to_numpy_type(self.dtype)
        import numpy as np  # pylint: disable=C0415

        arr = np.zeros(shape=self.shape, dtype=dtype)
        from taichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, arr)
        taichi.lang.runtime_ops.sync()
        return arr

    @python_scope
    def to_torch(self, device=None):
        """Converts this field to a `torch.tensor`."""
        import torch  # pylint: disable=C0415

        # pylint: disable=E1101
        arr = torch.zeros(size=self.shape, dtype=to_pytorch_type(self.dtype), device=device)
        from taichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, arr)
        taichi.lang.runtime_ops.sync()
        return arr

    @python_scope
    def to_paddle(self, place=None):
        """Converts this field to a `paddle.Tensor`."""
        import paddle  # pylint: disable=C0415

        # pylint: disable=E1101
        # paddle.empty() doesn't support argument `place`
        arr = paddle.to_tensor(paddle.zeros(self.shape, to_paddle_type(self.dtype)), place=place)
        from taichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, arr)
        taichi.lang.runtime_ops.sync()
        return arr

    @python_scope
    def _from_external_arr(self, arr):
        """Copies `arr` into this field after checking that the shapes match.

        Raises:
            ValueError: If `arr.shape` differs from the field shape.
        """
        # Compare as tuples so any sequence-typed shape is checked element-wise.
        if tuple(self.shape) != tuple(arr.shape):
            raise ValueError(f"ti.field shape {self.shape} does not match" f" the numpy array shape {arr.shape}")
        from taichi._kernels import ext_arr_to_tensor  # pylint: disable=C0415

        ext_arr_to_tensor(arr, self)
        taichi.lang.runtime_ops.sync()

    @python_scope
    def from_numpy(self, arr):
        """Copies the data from a `numpy.ndarray` into this field.

        `Field.from_paddle` also forwards paddle tensors here. The contiguity
        fix-up therefore applies only to objects that expose numpy-style
        `flags`; everything else is passed through unchanged.
        """
        flags = getattr(arr, "flags", None)
        if flags is not None and not flags.c_contiguous:
            import numpy as np  # pylint: disable=C0415

            arr = np.ascontiguousarray(arr)
        self._from_external_arr(arr)

    @python_scope
    def __setitem__(self, key, value):
        self._initialize_host_accessors()
        self.host_accessors[0].setter(value, *self._pad_key(key))

    @python_scope
    def __getitem__(self, key):
        self._initialize_host_accessors()
        # Check for potential slicing behaviour, for instance: x[0, :]
        padded_key = self._pad_key(key)
        import numpy as np  # pylint: disable=C0415

        for index in padded_key:
            if not isinstance(index, (int, np.integer)):
                raise TypeError(
                    f"Detected illegal element of type: {type(index)}. "
                    f"Please be aware that slicing a ti.field is not supported so far."
                )
        return self.host_accessors[0].getter(*padded_key)

    def __repr__(self):
        # make interactive shell happy, prevent materialization
        return "<ti.field>"
385
+
386
+
387
class SNodeHostAccessor:
    """Python-scope read/write entry points for one SNode.

    Chooses the read/write primitives of ``snode`` from its data type and
    exposes them as ``self.getter(*key)`` and ``self.setter(value, *key)``.
    ``key`` must contain exactly ``_ti_core.get_max_num_indices()`` entries.

    Args:
        snode: Core SNode object (``Field`` passes ``e.ptr.snode()``).
    """

    def __init__(self, snode):
        if _ti_core.is_real(snode.data_type()):
            read_func = snode.read_float
            write_func = snode.write_float
        else:

            def write_func(key, value):
                # Non-negative integers use the unsigned writer, the rest the signed one.
                if value >= 0:
                    snode.write_uint(key, value)
                else:
                    snode.write_int(key, value)

            read_func = snode.read_int if _ti_core.is_signed(snode.data_type()) else snode.read_uint

        def getter(*key):
            assert len(key) == _ti_core.get_max_num_indices()
            return read_func(key)

        def setter(value, *key):
            assert len(key) == _ti_core.get_max_num_indices()
            write_func(key, value)
            # While a tape with a grad checker is active (and grads are not
            # replaced), writes to checked fields are rejected and the write is
            # registered on the tape (presumably for replay; verify in Tape.insert).
            runtime = impl.get_runtime()
            tape = runtime.target_tape
            if tape and tape.grad_checker and not runtime.grad_replaced:
                for checked in tape.grad_checker.to_check:
                    assert snode != checked.snode.ptr, "Overwritten is prohibitive when doing grad check."
                tape.insert(write_func, (key, value))

        self.getter = getter
        self.setter = setter
424
+
425
+
426
class SNodeHostAccess:
    """Pairs a host accessor with the coordinates it should be applied at.

    Args:
        accessor: Accessor object (``Field._host_access`` passes ``SNodeHostAccessor`` instances).
        key: Coordinates of the element to access.
    """

    def __init__(self, accessor, key):
        self.key = key
        self.accessor = accessor
430
+
431
+
432
class BitpackedFields:
    """Taichi bitpacked fields, where fields with quantized types are packed together.

    Args:
        max_num_bits (int): Maximum number of bits all fields inside can occupy in total. Only 32 or 64 is allowed.
    """

    def __init__(self, max_num_bits):
        # (member pointer, id returned by BitStructTypeBuilder.add_member) pairs.
        self.fields = []
        self.bit_struct_type_builder = _ti_core.BitStructTypeBuilder(max_num_bits)

    def place(self, *args, shared_exponent=False):
        """Places a list of fields with quantized types inside.

        Args:
            *args (List[Field]): A list of fields with quantized types to place.
            shared_exponent (bool): Whether the fields have a shared exponent.

        Raises:
            TaichiSyntaxError: If ``shared_exponent`` is set and fewer than two
                field members were placed.
        """
        builder = self.bit_struct_type_builder
        if shared_exponent:
            builder.begin_placing_shared_exponent()

        num_placed = 0
        for field in args:
            assert isinstance(field, Field)
            for member in field._get_field_members():
                self.fields.append((member.ptr, builder.add_member(member.ptr.get_dt())))
                num_placed += 1

        if not shared_exponent:
            return
        builder.end_placing_shared_exponent()
        if num_placed < 2:
            raise TaichiSyntaxError("At least 2 fields need to be placed when shared_exponent=True")
462
+
463
+
464
# Public API of this module.
__all__ = [
    "BitpackedFields",
    "Field",
    "ScalarField",
]