gstaichi 0.1.25.dev0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp313-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,411 @@
1
+ # type: ignore
2
+
3
+ import numpy as np
4
+
5
+ import gstaichi.lang
6
+ from gstaichi._lib import core as _ti_core
7
+ from gstaichi.lang import impl, ops
8
+ from gstaichi.lang.exception import (
9
+ GsTaichiRuntimeTypeError,
10
+ GsTaichiSyntaxError,
11
+ )
12
+ from gstaichi.lang.matrix import Matrix, MatrixType
13
+ from gstaichi.lang.struct import Struct, StructType
14
+ from gstaichi.lang.util import cook_dtype, in_python_scope, python_scope
15
+ from gstaichi.types import (
16
+ ndarray_type,
17
+ primitive_types,
18
+ sparse_matrix_builder,
19
+ texture_type,
20
+ )
21
+ from gstaichi.types.compound_types import CompoundType
22
+ from gstaichi.types.utils import is_signed
23
+
24
+
25
class ArgPack:
    """ The `ArgPack` Type Class.

    The `ArgPack` operates as a dictionary-like data pack, storing members as (key, value) pairs. Members stored can
    range from scalars and matrices to other dictionary-like structures. Distinguished from structs, `ArgPack` can
    accommodate buffer types such as `NdarrayType` and `TextureType` from GsTaichi. However, unlike `ti.Struct` which
    serves as a data container, `ArgPack` functions as a reference container. It's important to note that `ArgPack`
    cannot be nested within other types except for another `ArgPack`, and can only be utilized as kernel parameters.

    Args:
        annotations (Dict[str, Union[Dict, Matrix, Struct]]): \
            The keys and types for `ArgPack` members.
        dtype: \
            The compiled argument-pack type (the `dtype` attribute of an `ArgPackType`), used to allocate the \
            device-side argpack.
        entries (Dict[str, Union[Dict, Matrix, Struct]]): \
            The keys and corresponding values for `ArgPack` members.

    Returns:
        An instance of this `ArgPack`.

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)
        >>> pack_type = ti.ArgPackType(v=vec3, t=ti.f32)
        >>> a = pack_type(v=vec3([0, 0, 0]), t=1.0)
        >>> print(a.items)
        dict_items([('v', [0. 0. 0.]), ('t', 1.0)])
    """

    # Counter used to give each per-instance subclass created by `_register_members` a unique name.
    _instance_count = 0

    def __init__(self, annotations, dtype, *args, **kwargs):
        """Build an argument pack and write its members into a freshly allocated device argpack.

        Raises:
            GsTaichiSyntaxError: If entries are not given as a single dict or as keyword arguments,
                or if their keys differ from the keys of `annotations`.
        """
        # Accept either a single positional dict or keyword arguments.
        if len(args) == 1 and not kwargs and isinstance(args[0], dict):
            # Copy: in kernel scope the values are rewritten in place below, and that must not
            # leak back into the caller's dict.
            self.__entries = dict(args[0])
        elif len(args) == 0:
            self.__entries = kwargs
        else:
            raise GsTaichiSyntaxError(
                "Custom argument packs need to be initialized using either dictionary or keyword arguments"
            )
        if annotations.keys() != self.__entries.keys():
            raise GsTaichiSyntaxError("ArgPack annotations keys not equals to entries keys.")
        self.__annotations = annotations
        for k, v in self.__entries.items():
            self.__entries[k] = v if in_python_scope() else impl.expr_init(v)
        self._register_members()
        self.__dtype = dtype
        self.__argpack = impl.get_runtime().prog.create_argpack(self.__dtype)
        for i, (k, v) in enumerate(self.__entries.items()):
            self._write_to_device(self.__annotations[k], type(v), v, self._calc_element_true_index(i))

    def __del__(self):
        """Release the device-side argpack, if one was allocated and the runtime is still alive."""
        # __init__ may have raised before the argpack was created; nothing to release then.
        argpack = getattr(self, "_ArgPack__argpack", None)
        if argpack is None:
            return
        runtime = impl.get_runtime() if impl is not None else None
        if runtime is not None and runtime.prog is not None:
            runtime.prog.delete_argpack(argpack)

    @property
    def keys(self):
        """Returns the list of member names in string format.

        Example::

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere_pack = ti.ArgPackType(center=vec3, radius=ti.f32)
            >>> sphere = sphere_pack(center=vec3([0, 0, 0]), radius=1.0)
            >>> sphere.keys
            ['center', 'radius']
        """
        return list(self.__entries.keys())

    @property
    def _members(self):
        # Member values, in declaration order.
        return list(self.__entries.values())

    @property
    def _annotations(self):
        # Member annotations, in declaration order.
        return list(self.__annotations.values())

    @property
    def items(self):
        """Returns the items in this argument pack.

        Example::

            >>> vec3 = ti.types.vector(3, ti.f32)
            >>> sphere_pack = ti.ArgPackType(center=vec3, radius=ti.f32)
            >>> sphere = sphere_pack(center=vec3([0, 0, 0]), radius=1.0)
            >>> sphere.items
            dict_items([('center', [0. 0. 0.]), ('radius', 1.0)])
        """
        return self.__entries.items()

    def __getitem__(self, key):
        return self.__entries[key]

    def __setitem__(self, key, value):
        """Store `value` under `key` and write it to the matching slot of the device argpack."""
        self.__entries[key] = value
        index = self._calc_element_true_index(list(self.__annotations).index(key))
        self._write_to_device(self.__annotations[key], type(value), value, index)

    def _set_entries(self, value):
        """Assign every member from `value`, which is another ArgPack or a dict with the same keys."""
        if isinstance(value, dict):
            # Convert (and key-check) the dict using this pack's own annotations and dtype.
            value = ArgPack(self.__annotations, self.__dtype, value)
        for k in self.keys:
            self[k] = value[k]

    @staticmethod
    def _make_getter(key):
        def getter(self):
            """Get an entry from custom argument pack by name."""
            return self[key]

        return getter

    @staticmethod
    def _make_setter(key):
        @python_scope
        def setter(self, value):
            self[key] = value

        return setter

    def _register_members(self):
        """Expose each member as an attribute by swapping in a per-instance subclass with properties."""
        # https://stackoverflow.com/questions/48448074/adding-a-property-to-an-existing-object-instance
        cls = self.__class__
        new_cls_name = cls.__name__ + str(cls._instance_count)
        cls._instance_count += 1
        properties = {k: property(cls._make_getter(k), cls._make_setter(k)) for k in self.keys}
        self.__class__ = type(new_cls_name, (cls,), properties)

    def __len__(self):
        """Get the number of entries in a custom argument pack."""
        return len(self.__entries)

    def __iter__(self):
        """Iterate over the member values (not the keys, unlike a dict)."""
        # Must return an iterator; a bare dict_values view makes iter() raise TypeError.
        return iter(self.__entries.values())

    def __str__(self):
        """Print support: a `<ti.ArgPack ...>` summary inside a kernel, the `to_dict()` form otherwise."""
        if impl.inside_kernel():
            item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.items])
            return f"<ti.ArgPack {item_str}>"
        return str(self.to_dict())

    def __repr__(self):
        return str(self.to_dict())

    def to_dict(self):
        """Converts the ArgPack to a dictionary.

        Nested ArgPacks become nested dicts and matrices become lists.

        Returns:
            Dict: The result dictionary.
        """
        res_dict = {
            k: v.to_dict() if isinstance(v, ArgPack) else v.to_list() if isinstance(v, Matrix) else v
            for k, v in self.__entries.items()
        }
        return res_dict

    def _calc_element_true_index(self, old_index):
        """Map a member's declaration index to its slot index in the device argpack struct.

        Buffer-typed members (sparse matrix builders, ndarrays, textures) get no struct slot,
        so every such member declared before `old_index` shifts the slot index down by one.
        """
        buffer_types = (
            sparse_matrix_builder,
            ndarray_type.NdarrayType,
            texture_type.TextureType,
            texture_type.RWTextureType,
        )
        preceding = list(self.__annotations.values())[:old_index]
        return old_index - sum(1 for anno in preceding if isinstance(anno, buffer_types))

    def _write_to_device(self, needed, provided, v, index):
        """Validate `v` against the annotation `needed` and write it into slot `index` of the device argpack.

        Args:
            needed: The member's annotation (ArgPackType, primitive dtype, MatrixType, StructType or buffer type).
            provided: The Python type of `v`; used only in error messages.
            v: The value to write.
            index: Slot index, as computed by `_calc_element_true_index`.

        Raises:
            GsTaichiRuntimeTypeError: If `v` does not match `needed`.
            ValueError: If `needed` is unsupported, or is a matrix whose dtype is neither integer nor real.
        """
        if isinstance(needed, ArgPackType):
            if not isinstance(v, ArgPack):
                raise GsTaichiRuntimeTypeError.get(index, str(needed), str(provided))
            self.__argpack.set_arg_nested_argpack(index, v.__argpack)
            return

        # Note: do not use sth like "needed == f32". That would be slow.
        if id(needed) in primitive_types.real_type_ids:
            if not isinstance(v, (float, int, np.floating, np.integer)):
                raise GsTaichiRuntimeTypeError.get(index, needed.to_string(), provided)
            self.__argpack.set_arg_float((index,), float(v))
        elif id(needed) in primitive_types.integer_type_ids:
            if not isinstance(v, (int, np.integer)):
                raise GsTaichiRuntimeTypeError.get(index, needed.to_string(), provided)
            if is_signed(cook_dtype(needed)):
                self.__argpack.set_arg_int((index,), int(v))
            else:
                self.__argpack.set_arg_uint((index,), int(v))
        elif isinstance(needed, (sparse_matrix_builder, ndarray_type.NdarrayType)):
            # Buffer members occupy no struct slot; nothing is written here.
            pass
        elif isinstance(needed, (texture_type.TextureType, texture_type.RWTextureType)) and isinstance(
            v, gstaichi.lang._texture.Texture
        ):
            # Same as above, but only when an actual Texture was supplied.
            pass
        elif isinstance(needed, MatrixType):
            if needed.dtype in primitive_types.real_types:

                def cast_func(x):
                    if not isinstance(x, (int, float, np.integer, np.floating)):
                        raise GsTaichiRuntimeTypeError.get(index, needed.dtype.to_string(), type(x))
                    return float(x)

            elif needed.dtype in primitive_types.integer_types:

                def cast_func(x):
                    if not isinstance(x, (int, np.integer)):
                        raise GsTaichiRuntimeTypeError.get(index, needed.dtype.to_string(), type(x))
                    return int(x)

            else:
                raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")

            # Flatten row-major, validate each element, then rebuild a matrix of the exact type.
            if needed.ndim == 2:
                v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
            else:
                v = [cast_func(v[i]) for i in range(needed.n)]
            v = needed(*v)
            needed.set_argpack_struct_args(v, self.__argpack, (index,))
        elif isinstance(needed, StructType):
            if not isinstance(v, needed):
                raise GsTaichiRuntimeTypeError.get(index, str(needed), provided)
            needed.set_argpack_struct_args(v, self.__argpack, (index,))
        else:
            raise ValueError(f"Argument type mismatch. Expecting {needed}, got {type(v)}.")
257
+
258
class _IntermediateArgPack(ArgPack):
    """Argument pack used internally by the compiler.

    Unlike `ArgPack`, construction neither runs `expr_init` on the entries nor writes them into
    the device-side argpack: it records entries, annotations and dtype, exposes members as
    attributes, and allocates an argpack handle.

    Args:
        annotations (Dict[str, Union[Expr, Matrix, Struct]]): keys and types for struct members.
        dtype: compiled argpack type handed to `create_argpack`.
        entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
    """

    def __init__(self, annotations, dtype, *args, **kwargs):
        # Entries come either from one positional dict or from keyword arguments.
        if args:
            from_single_dict = len(args) == 1 and kwargs == {} and isinstance(args[0], dict)
            if not from_single_dict:
                raise GsTaichiSyntaxError(
                    "Custom argument packs need to be initialized using either dictionary or keyword arguments"
                )
            member_values = args[0]
        else:
            member_values = kwargs
        # Explicit mangled names: inside this class, `self.__x` would mangle to `_IntermediateArgPack__x`.
        self._ArgPack__entries = member_values
        if annotations.keys() != member_values.keys():
            raise GsTaichiSyntaxError("ArgPack annotations keys not equals to entries keys.")
        self._ArgPack__annotations = annotations
        self._register_members()
        self._ArgPack__dtype = dtype
        self._ArgPack__argpack = impl.get_runtime().prog.create_argpack(dtype)

    def __del__(self):
        # Deliberately skips ArgPack.__del__, which would call delete_argpack.
        # NOTE(review): the handle from create_argpack in __init__ is therefore never released here —
        # confirm that the runtime owns/frees it.
        pass
285
+
286
+
287
class ArgPackType(CompoundType):
    """Compound type describing the members of an `ArgPack`.

    Each keyword argument maps a member name to its annotation. After construction:

    * ``members`` maps each name to its annotation (primitive dtypes are normalized via `cook_dtype`).
    * ``dtype`` is the compiled argpack type built from the struct-backed members. Buffer-typed
      members (sparse matrix builders, ndarrays, textures) are recorded in ``members`` but add no
      struct element.
    """

    def __init__(self, **kwargs):
        self.members = {}
        elements = []
        for k, dtype in kwargs.items():
            if isinstance(dtype, StructType):
                self.members[k] = dtype
                elements.append([dtype.dtype, k])
            elif isinstance(dtype, ArgPackType):
                self.members[k] = dtype
                elements.append(
                    [
                        _ti_core.DataTypeCxx(
                            _ti_core.get_type_factory_instance().get_struct_type_for_argpack_ptr(dtype.dtype)
                        ),
                        k,
                    ]
                )
            elif isinstance(dtype, MatrixType):
                # Convert MatrixType to StructType, one field per element (row-major for 2D).
                if dtype.ndim == 1:
                    elements_ = [(dtype.dtype, f"{k}_{i}") for i in range(dtype.n)]
                else:
                    elements_ = [(dtype.dtype, f"{k}_{i}_{j}") for i in range(dtype.n) for j in range(dtype.m)]
                self.members[k] = dtype
                elements.append([_ti_core.get_type_factory_instance().get_struct_type(elements_), k])
            elif isinstance(
                dtype,
                (
                    sparse_matrix_builder,
                    ndarray_type.NdarrayType,
                    texture_type.RWTextureType,
                    texture_type.TextureType,
                ),
            ):
                # Buffer members are recorded but contribute no struct element.
                self.members[k] = dtype
            else:
                dtype = cook_dtype(dtype)
                self.members[k] = dtype
                elements.append([dtype, k])
        if len(elements) == 0:
            # Use i32 as a placeholder for empty argpacks. Reuse the last member name when there is
            # one (as before); with no members at all there is no loop variable, so use a fixed name.
            placeholder_name = list(kwargs)[-1] if kwargs else "_placeholder"
            elements.append([primitive_types.i32, placeholder_name])
        self.dtype = _ti_core.get_type_factory_instance().get_argpack_type(elements)

    def __call__(self, *args, **kwargs):
        """Create an instance of this argument pack type.

        Positional arguments fill members in declaration order. Any remaining members are read from
        keyword arguments and default to None.
        """
        d = {}
        # iterate over the members of this argument pack
        for index, (name, dtype) in enumerate(self.members.items()):
            if index < len(args):  # set from args
                data = args[index]
            else:  # set from kwargs
                data = kwargs.get(name, None)

            # If dtype is CompoundType and data is a scalar, it cannot be
            # cast in the self.cast call later. We need an initialization here.
            if isinstance(dtype, CompoundType) and not isinstance(data, (dict, ArgPack, Struct)):
                data = dtype(data)

            d[name] = data

        entries = ArgPack(self.members, self.dtype, d)
        return self.cast(entries)

    def __instancecheck__(self, instance):
        """An ArgPack matches when its keys (in order) and annotations match; nested packs are checked recursively."""
        if not isinstance(instance, ArgPack):
            return False
        if list(self.members.keys()) != list(instance._ArgPack__entries.keys()):
            return False
        for k, v in self.members.items():
            if isinstance(v, ArgPackType):
                if not isinstance(instance._ArgPack__entries[k], v):
                    return False
            elif instance._ArgPack__annotations[k] != v:
                return False
        return True

    def cast(self, pack):
        """Return a new `ArgPack` whose entries are converted to this type's member annotations.

        Raises:
            GsTaichiSyntaxError: If `pack` does not have exactly this type's member names.
        """
        source = pack._ArgPack__entries
        # sanity check members
        if self.members.keys() != source.keys():
            raise GsTaichiSyntaxError("Incompatible arguments for custom argument pack members!")
        buffer_types = (
            ndarray_type.NdarrayType,
            texture_type.RWTextureType,
            texture_type.TextureType,
            sparse_matrix_builder,
        )
        entries = {}
        for k, dtype in self.members.items():
            value = source[k]
            if isinstance(dtype, MatrixType):
                entries[k] = dtype(value)
            elif isinstance(dtype, CompoundType):
                # Also covers nested ArgPackType members, since ArgPackType subclasses CompoundType.
                entries[k] = dtype.cast(value)
            elif isinstance(dtype, buffer_types):
                # Buffers are passed by reference, not converted.
                entries[k] = value
            elif in_python_scope():
                entries[k] = int(value) if dtype in primitive_types.integer_types else float(value)
            else:
                entries[k] = ops.cast(value, dtype)
        return ArgPack(self.members, self.dtype, entries)

    def from_gstaichi_object(self, arg_load_dict: dict):
        """Build a compiler-internal pack from already-loaded member values keyed by member name."""
        d = {name: arg_load_dict[name] for name in self.members}
        pack = _IntermediateArgPack(self.members, self.dtype, d)
        pack._ArgPack__dtype = self.dtype
        return pack

    def __str__(self):
        """Python scope argpack type print support."""
        item_str = ", ".join([str(k) + "=" + str(v) for k, v in self.members.items()])
        return f"<ti.ArgPackType {item_str}>"
409
+
410
+
411
# Public names exported by `from ... import *` for this module.
__all__ = [
    "ArgPack",
]
@@ -0,0 +1,5 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang.ast.ast_transformer_utils import ASTTransformerContext
4
+ from gstaichi.lang.ast.checkers import KernelSimplicityASTChecker
5
+ from gstaichi.lang.ast.transform import transform_tree