gstaichi 0.1.25.dev0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp312-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
gstaichi/lang/field.py ADDED
@@ -0,0 +1,466 @@
1
+ # type: ignore
2
+
3
+ import gstaichi.lang
4
+ from gstaichi._lib import core as _ti_core
5
+ from gstaichi._logging import warn
6
+ from gstaichi.lang import impl
7
+ from gstaichi.lang.exception import GsTaichiSyntaxError
8
+ from gstaichi.lang.util import (
9
+ in_python_scope,
10
+ python_scope,
11
+ to_numpy_type,
12
+ to_paddle_type,
13
+ to_pytorch_type,
14
+ )
15
+
16
+
17
class Field:
    """GsTaichi field class.

    A field is constructed by a list of field members.
    For example, a scalar field has 1 field member, while a 3x3 matrix field has 9 field members.
    A field member is a Python Expr wrapping a C++ FieldExpression.

    Args:
        vars (List[Expr]): Field members.
    """

    def __init__(self, _vars):
        # Every member must be truthy (e.g. no ``None`` placeholders).
        assert all(_vars)
        self.vars = _vars
        # List of SNodeHostAccessor (one per member); built lazily by
        # ``_initialize_host_accessors`` on first Python-scope element access.
        self.host_accessors = None
        # Companion fields for reverse-mode (grad) and forward-mode (dual) autodiff,
        # attached later via ``_set_grad`` / ``_set_dual``.
        self.grad = None
        self.dual = None

    @property
    def snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return self._snode

    @property
    def _snode(self):
        """Gets representative SNode for info purposes.

        Returns:
            SNode: Representative SNode (SNode of first field member).
        """
        return gstaichi.lang.snode.SNode(self.vars[0].ptr.snode())

    @property
    def shape(self):
        """Gets field shape.

        Returns:
            Tuple[Int]: Field shape.
        """
        return self._snode.shape

    @property
    def dtype(self):
        """Gets data type of each individual value.

        Returns:
            DataType: Data type of each individual value.
        """
        return self._snode._dtype

    @property
    def _name(self):
        """Gets field name.

        Returns:
            str: Field name.
        """
        return self._snode._name

    def parent(self, n=1):
        """Gets an ancestor of the representative SNode in the SNode tree.

        Args:
            n (int): the number of levels going up from the representative SNode.

        Returns:
            SNode: The n-th parent of the representative SNode.
        """
        return self.snode.parent(n)

    def _get_field_members(self):
        """Gets field members.

        Returns:
            List[Expr]: Field members.
        """
        return self.vars

    def _loop_range(self):
        """Gets SNode of representative field member for loop range info.

        Returns:
            gstaichi_python.SNode: SNode of representative (first) field member.
        """
        return self.vars[0].ptr.snode()

    def _set_grad(self, grad):
        """Sets corresponding grad field (reverse mode).

        Args:
            grad (Field): Corresponding grad field.
        """
        self.grad = grad

    def _set_dual(self, dual):
        """Sets corresponding dual field (forward mode).

        Args:
            dual (Field): Corresponding dual field.
        """
        self.dual = dual

    @python_scope
    def fill(self, val):
        """Fills `self` with a specific value.

        Args:
            val (Union[int, float]): Value to fill.
        """
        raise NotImplementedError()

    @python_scope
    def to_numpy(self, dtype=None):
        """Converts `self` to a numpy array.

        Args:
            dtype (DataType, optional): The desired data type of returned numpy array.

        Returns:
            numpy.ndarray: The result numpy array.
        """
        raise NotImplementedError()

    @python_scope
    def to_torch(self, device=None):
        """Converts `self` to a torch tensor.

        Args:
            device (torch.device, optional): The desired device of returned tensor.

        Returns:
            torch.tensor: The result torch tensor.
        """
        raise NotImplementedError()

    @python_scope
    def to_paddle(self, place=None):
        """Converts `self` to a paddle tensor.

        Args:
            place (paddle.CPUPlace()/CUDAPlace(n), optional): The desired place of returned tensor.

        Returns:
            paddle.Tensor: The result paddle tensor.
        """
        raise NotImplementedError()

    @python_scope
    def from_numpy(self, arr):
        """Loads all elements from a numpy array.

        The shape of the numpy array needs to be the same as `self`.

        Args:
            arr (numpy.ndarray): The source numpy array.
        """
        raise NotImplementedError()

    @python_scope
    def _from_external_arr(self, arr):
        """Loads all elements from an external array; implemented by subclasses."""
        raise NotImplementedError()

    @python_scope
    def from_torch(self, arr):
        """Loads all elements from a torch tensor.

        The shape of the torch tensor needs to be the same as `self`.

        Args:
            arr (torch.tensor): The source torch tensor.
        """
        self._from_external_arr(arr.contiguous())

    @python_scope
    def from_paddle(self, arr):
        """Loads all elements from a paddle tensor.

        The shape of the paddle tensor needs to be the same as `self`.

        Args:
            arr (paddle.Tensor): The source paddle tensor.
        """
        # Go straight to the external-array path (as ``from_torch`` does).
        # ``from_numpy`` implementations may rely on numpy-only attributes such
        # as ``arr.flags``, which a paddle tensor is not expected to provide.
        self._from_external_arr(arr)

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another field.

        The shape of the other field needs to be the same as `self`.

        Args:
            other (Field): The source field.

        Raises:
            TypeError: If ``other`` is not a Field.
            ValueError: If the shapes differ.
        """
        if not isinstance(other, Field):
            raise TypeError("Cannot copy from a non-field object")
        if self.shape != other.shape:
            raise ValueError(f"ti.field shape {self.shape} does not match" f" the source field shape {other.shape}")
        from gstaichi._kernels import tensor_to_tensor  # pylint: disable=C0415

        tensor_to_tensor(self, other)

    @python_scope
    def __setitem__(self, key, value):
        """Sets field element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.
            value (element type): Value to set.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets field element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the field element.

        Returns:
            element type: Value retrieved.
        """
        raise NotImplementedError()

    def __str__(self):
        """Renders the field's contents (via ``to_numpy``) in Python scope."""
        if gstaichi.lang.impl.inside_kernel():
            return self.__repr__()  # make pybind11 happy, see Matrix.__str__
        if self._snode.ptr is None:
            return "<Field: Definition of this field is incomplete>"
        return str(self.to_numpy())

    def _pad_key(self, key):
        """Normalizes a Python-scope index into a full-length index tuple.

        Args:
            key (Union[tuple, list, int, None]): User-supplied index.
                ``None`` means "no indices" (0-D field).

        Returns:
            tuple: The indices followed by zeros, padded to
            ``_ti_core.get_max_num_indices()`` entries.

        Raises:
            AssertionError: If the number of indices differs from ``len(self.shape)``.
        """
        if key is None:
            key = ()
        elif isinstance(key, (tuple, list)):
            # Lists are accepted as keys; convert so the tuple padding below works
            # (``list + tuple`` would raise TypeError).
            key = tuple(key)
        else:
            key = (key,)

        if len(key) != len(self.shape):
            raise AssertionError("Slicing is not supported on ti.field")

        return key + ((0,) * (_ti_core.get_max_num_indices() - len(key)))

    def _initialize_host_accessors(self):
        """Materializes the runtime and builds one host accessor per member (once)."""
        if self.host_accessors:
            return
        gstaichi.lang.impl.get_runtime().materialize()
        self.host_accessors = [SNodeHostAccessor(e.ptr.snode()) for e in self.vars]

    def _host_access(self, key):
        """Pairs every host accessor with ``key``."""
        return [SNodeHostAccess(e, key) for e in self.host_accessors]

    def __iter__(self):
        raise NotImplementedError("Struct for is only available in GsTaichi scope.")
272
+
273
+
274
class ScalarField(Field):
    """GsTaichi scalar field with SNode implementation.

    Args:
        var (Expr): Field member.
    """

    def __init__(self, var):
        # A scalar field has exactly one field member.
        super().__init__([var])

    def fill(self, val):
        """Fills this scalar field with a specified value.

        In Python scope this calls ``gstaichi._kernels.fill_field``; otherwise it
        calls ``gstaichi._funcs.field_fill_gstaichi_scope``.

        Args:
            val (Union[int, float]): Value to fill.
        """
        if in_python_scope():
            from gstaichi._kernels import fill_field  # pylint: disable=C0415

            fill_field(self, val)
        else:
            from gstaichi._funcs import field_fill_gstaichi_scope  # pylint: disable=C0415

            field_fill_gstaichi_scope(self, val)

    @python_scope
    def to_numpy(self, dtype=None):
        """Converts this field to a `numpy.ndarray`.

        Args:
            dtype (optional): NumPy dtype of the result; defaults to the numpy
                equivalent of ``self.dtype``.

        Returns:
            numpy.ndarray: A new array of shape ``self.shape``.
        """
        if self.parent()._snode.ptr.type == _ti_core.SNodeType.dynamic:
            warn(
                "You are trying to convert a dynamic snode to a numpy array, be aware that inactive items in the snode will be converted to zeros in the resulting array."
            )
        if dtype is None:
            dtype = to_numpy_type(self.dtype)
        import numpy as np  # pylint: disable=C0415

        arr = np.zeros(shape=self.shape, dtype=dtype)
        from gstaichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, arr)
        gstaichi.lang.runtime_ops.sync()
        return arr

    @python_scope
    def to_torch(self, device=None):
        """Converts this field to a `torch.tensor`.

        Args:
            device (torch.device, optional): Device of the returned tensor.

        Returns:
            torch.Tensor: A new tensor of shape ``self.shape``.
        """
        import torch  # pylint: disable=C0415

        # pylint: disable=E1101
        arr = torch.zeros(size=self.shape, dtype=to_pytorch_type(self.dtype), device=device)
        from gstaichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, arr)
        gstaichi.lang.runtime_ops.sync()
        return arr

    @python_scope
    def to_paddle(self, place=None):
        """Converts this field to a `paddle.Tensor`.

        Args:
            place (optional): Paddle place of the returned tensor.

        Returns:
            paddle.Tensor: A new tensor of shape ``self.shape``.
        """
        import paddle  # pylint: disable=C0415

        # pylint: disable=E1101
        # paddle.empty() doesn't support argument `place`
        arr = paddle.to_tensor(paddle.zeros(self.shape, to_paddle_type(self.dtype)), place=place)
        from gstaichi._kernels import tensor_to_ext_arr  # pylint: disable=C0415

        tensor_to_ext_arr(self, arr)
        gstaichi.lang.runtime_ops.sync()
        return arr

    @python_scope
    def _from_external_arr(self, arr):
        """Copies an external array (numpy / torch / paddle) into this field.

        Raises:
            ValueError: If ``arr.shape`` differs from ``self.shape``.
        """
        # Normalizing both shapes to tuples (e.g. torch.Size) makes one comparison
        # cover both a rank mismatch and a per-axis size mismatch.
        if tuple(self.shape) != tuple(arr.shape):
            raise ValueError(f"ti.field shape {self.shape} does not match" f" the numpy array shape {arr.shape}")
        from gstaichi._kernels import ext_arr_to_tensor  # pylint: disable=C0415

        ext_arr_to_tensor(arr, self)
        gstaichi.lang.runtime_ops.sync()

    @python_scope
    def from_numpy(self, arr):
        """Copies the data from a `numpy.ndarray` into this field.

        Non C-contiguous arrays are copied into a contiguous buffer first.
        """
        if not arr.flags.c_contiguous:
            import numpy as np  # pylint: disable=C0415

            arr = np.ascontiguousarray(arr)
        self._from_external_arr(arr)

    @staticmethod
    def _check_integer_indices(padded_key):
        """Rejects non-integer indices such as slices (e.g. ``x[0, :]``).

        Raises:
            TypeError: On the first index that is neither ``int`` nor ``np.integer``.
        """
        import numpy as np  # pylint: disable=C0415

        for index in padded_key:
            if not isinstance(index, (int, np.integer)):
                raise TypeError(
                    f"Detected illegal element of type: {type(index)}. "
                    f"Please be aware that slicing a ti.field is not supported so far."
                )

    @python_scope
    def __setitem__(self, key, value):
        """Writes one element in Python scope.

        Raises:
            AssertionError: If the number of indices differs from the field's rank.
            TypeError: If any index is not an integer (e.g. a slice).
        """
        self._initialize_host_accessors()
        padded_key = self._pad_key(key)
        # Validate here too, so slices get the same clear error as in __getitem__
        # instead of an opaque failure inside the native setter.
        self._check_integer_indices(padded_key)
        self.host_accessors[0].setter(value, *padded_key)

    @python_scope
    def __getitem__(self, key):
        """Reads one element in Python scope.

        Raises:
            AssertionError: If the number of indices differs from the field's rank.
            TypeError: If any index is not an integer (e.g. a slice).
        """
        self._initialize_host_accessors()
        # Check for potential slicing behaviour, for instance: x[0, :]
        padded_key = self._pad_key(key)
        self._check_integer_indices(padded_key)
        return self.host_accessors[0].getter(*padded_key)

    def __repr__(self):
        # make interactive shell happy, prevent materialization
        return "<ti.field>"
387
+
388
+
389
class SNodeHostAccessor:
    """Python-scope element reader/writer bound to a single SNode.

    After construction the instance exposes two callables:

    * ``getter(*key)`` returns the element at ``key``;
    * ``setter(value, *key)`` writes ``value`` at ``key``.

    ``key`` must already contain exactly ``_ti_core.get_max_num_indices()`` indices.
    """

    def __init__(self, snode):
        dtype = snode.data_type()
        if _ti_core.is_real(dtype):
            read_func, write_func = snode.read_float, snode.write_float
        else:
            read_func = snode.read_int if _ti_core.is_signed(dtype) else snode.read_uint

            def write_func(key, value):
                # Non-negative values use the unsigned writer, negative ones the signed writer.
                writer = snode.write_uint if value >= 0 else snode.write_int
                writer(key, value)

        def getter(*key):
            assert len(key) == _ti_core.get_max_num_indices()
            return read_func(key)

        def setter(value, *key):
            assert len(key) == _ti_core.get_max_num_indices()
            write_func(key, value)
            runtime = impl.get_runtime()
            tape = runtime.target_tape
            # Only when a tape with a grad checker is active (and grads have not
            # been replaced) is the write checked and recorded on the tape.
            if not (tape and tape.grad_checker and not runtime.grad_replaced):
                return
            for checked in tape.grad_checker.to_check:
                assert snode != checked.snode.ptr, "Overwritten is prohibitive when doing grad check."
            tape.insert(write_func, (key, value))

        self.getter = getter
        self.setter = setter
426
+
427
+
428
class SNodeHostAccess:
    """Plain record pairing a host accessor with the key it is applied at."""

    def __init__(self, accessor, key):
        # Both values are stored as given; no validation happens here.
        self.key = key
        self.accessor = accessor
432
+
433
+
434
class BitpackedFields:
    """GsTaichi bitpacked fields, where fields with quantized types are packed together.

    Args:
        max_num_bits (int): Maximum number of bits all fields inside can occupy in total. Only 32 or 64 is allowed.
    """

    def __init__(self, max_num_bits):
        # (member ptr, value returned by the builder's add_member) per placed member.
        self.fields = []
        self.bit_struct_type_builder = _ti_core.BitStructTypeBuilder(max_num_bits)

    def place(self, *args, shared_exponent=False):
        """Places a list of fields with quantized types inside.

        Args:
            *args (List[Field]): A list of fields with quantized types to place.
            shared_exponent (bool): Whether the fields have a shared exponent.

        Raises:
            GsTaichiSyntaxError: If ``shared_exponent`` is set but fewer than two
                field members were placed.
        """
        builder = self.bit_struct_type_builder
        if shared_exponent:
            builder.begin_placing_shared_exponent()
        num_placed = 0
        for placed_field in args:
            assert isinstance(placed_field, Field)
            for member in placed_field._get_field_members():
                member_ptr = member.ptr
                self.fields.append((member_ptr, builder.add_member(member_ptr.get_dt())))
                num_placed += 1
        if not shared_exponent:
            return
        builder.end_placing_shared_exponent()
        if num_placed <= 1:
            raise GsTaichiSyntaxError("At least 2 fields need to be placed when shared_exponent=True")
464
+
465
+
466
# Public API of this module.
__all__ = [
    "BitpackedFields",
    "Field",
    "ScalarField",
]