gstaichi 0.1.25.dev0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1382 @@
1
+ import ast
2
+ import dataclasses
3
+ import functools
4
+ import inspect
5
+ import json
6
+ import operator
7
+ import os
8
+ import pathlib
9
+ import re
10
+ import sys
11
+ import textwrap
12
+ import time
13
+ import types
14
+ import typing
15
+ import warnings
16
+ from typing import Any, Callable, Type
17
+
18
+ import numpy as np
19
+
20
+ import gstaichi.lang
21
+ import gstaichi.lang._ndarray
22
+ import gstaichi.lang._texture
23
+ import gstaichi.types.annotations
24
+ from gstaichi import _logging
25
+ from gstaichi._lib import core as _ti_core
26
+ from gstaichi._lib.core.gstaichi_python import (
27
+ ASTBuilder,
28
+ FunctionKey,
29
+ KernelCxx,
30
+ KernelLaunchContext,
31
+ )
32
+ from gstaichi.lang import impl, ops, runtime_ops
33
+ from gstaichi.lang._template_mapper import GsTaichiCallableTemplateMapper
34
+ from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
35
+ from gstaichi.lang.any_array import AnyArray
36
+ from gstaichi.lang.argpack import ArgPack, ArgPackType
37
+ from gstaichi.lang.ast import (
38
+ ASTTransformerContext,
39
+ KernelSimplicityASTChecker,
40
+ transform_tree,
41
+ )
42
+ from gstaichi.lang.ast.ast_transformer_utils import ReturnStatus
43
+ from gstaichi.lang.exception import (
44
+ GsTaichiCompilationError,
45
+ GsTaichiRuntimeError,
46
+ GsTaichiRuntimeTypeError,
47
+ GsTaichiSyntaxError,
48
+ GsTaichiTypeError,
49
+ handle_exception_from_cpp,
50
+ )
51
+ from gstaichi.lang.expr import Expr
52
+ from gstaichi.lang.kernel_arguments import KernelArgument
53
+ from gstaichi.lang.matrix import MatrixType
54
+ from gstaichi.lang.shell import _shell_pop_print
55
+ from gstaichi.lang.struct import StructType
56
+ from gstaichi.lang.util import cook_dtype, has_paddle, has_pytorch
57
+ from gstaichi.types import (
58
+ ndarray_type,
59
+ primitive_types,
60
+ sparse_matrix_builder,
61
+ template,
62
+ texture_type,
63
+ )
64
+ from gstaichi.types.compound_types import CompoundType
65
+ from gstaichi.types.enums import AutodiffMode, Layout
66
+ from gstaichi.types.utils import is_signed
67
+
68
+ CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]
69
+
70
+
71
class GsTaichiCallable:
    """
    GsTaichiCallable is used to enable wrapping a bindable function with a class.

    Design requirements for GsTaichiCallable:
    - wrap/contain a reference to a class Func instance, and allow (the GsTaichiCallable) being passed around
      like a normal function pointer
    - expose attributes of the wrapped class Func, such as `_is_real_function`, `_primal`, etc
    - allow for (now limited) strong typing, and enable type checkers, such as pyright/mypy
    - currently GsTaichiCallable is a shared type used for all functions marked with @ti.func, @ti.kernel,
      python functions (?)
    - note: the current type-checking implementation does not distinguish between different type flavors of
      GsTaichiCallable, with different values of `_is_real_function`, `_primal`, etc
    - handle not only class-less functions, but also class-instance methods (where determining the `self`
      reference is a challenge)

    Let's take the following example:

        def test_ptr_class_func():
            @ti.data_oriented
            class MyClass:
                def __init__(self):
                    self.a = ti.field(dtype=ti.f32, shape=(3))

                def add2numbers_py(self, x, y):
                    return x + y

                @ti.func
                def add2numbers_func(self, x, y):
                    return x + y

                @ti.kernel
                def func(self):
                    a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)
                    a[0] = add_py(2, 3)
                    a[1] = add_func(3, 7)

    (taken from test_ptr_assign.py).

    When the @ti.func decorator is parsed, the function `add2numbers_func` exists, but there is not yet
    any `self` - it is not possible for the method to be bound to a `self` instance. However, the @ti.func
    annotation runs the kernel_impl.py::func function --- it is at this point that GsTaichi's original
    code creates a class Func instance (that wraps add2numbers_func), and immediately we create a
    GsTaichiCallable instance that wraps the Func instance. Effectively, we have two layers of wrapping:
    GsTaichiCallable -> Func -> function pointer (actual function definition).

    Later on, when we call self.add2numbers_py, here:

        a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)

    ... we want to call the bound method, `self.add2numbers_py`.
    An actual python function reference, created by doing somevar = MyClass.add2numbers, can automatically
    bind to self when called from self in this way (however, add2numbers_py is actually a class Func
    instance, wrapping the python function reference -- now also wrapped by a GsTaichiCallable instance --
    returned by the kernel_impl.py::func function, run by @ti.func).
    In order to be able to add strongly typed attributes to the wrapped python function, we need to wrap
    the wrapped python function in a class. The wrapped python function, wrapped in a GsTaichiCallable
    class (which is callable, and will execute the underlying double-wrapped python function), will NOT
    automatically bind: when we invoke GsTaichiCallable, the wrapped function is invoked. The wrapped
    function is unbound, and so `self` is not automatically passed in as an argument, and things break.

    To address this we use the `__get__` method in our function wrapper, i.e. GsTaichiCallable, and have
    the `__get__` method return a `BoundGsTaichiCallable` object. The `__get__` method handles running
    the binding for us, and effectively binds the `BoundGsTaichiCallable` object to the `self` object,
    by passing the instance as an argument into `BoundGsTaichiCallable.__init__`.

    `BoundGsTaichiCallable` can then be used as a normal bound func - even though it's just an object
    instance - using its `__call__` method. Effectively, at the time of actually invoking the underlying
    python function, we have 3 layers of wrapper instances:
        BoundGsTaichiCallable -> GsTaichiCallable -> Func -> python function reference/definition
    """

    def __init__(self, fn: Callable, wrapper: Callable) -> None:
        # The raw Python function; functools.update_wrapper below copies its
        # __name__/__doc__/__module__ etc. onto this wrapper object.
        self.fn: Callable = fn
        # The object actually invoked on call (e.g. a Func instance).
        self.wrapper: Callable = wrapper
        # Flavor flags; callers such as func()/pyfunc() set these after construction.
        self._is_real_function: bool = False
        self._is_gstaichi_function: bool = False
        self._is_wrapped_kernel: bool = False
        self._is_classkernel: bool = False
        self._primal: Kernel | None = None
        self._adjoint: Kernel | None = None
        self.grad: Kernel | None = None
        self._is_staticmethod: bool = False
        functools.update_wrapper(self, fn)

    def __call__(self, *args, **kwargs):
        # Delegate straight to the wrapped callable.
        return self.wrapper.__call__(*args, **kwargs)

    def __get__(self, instance, owner):
        # Descriptor protocol: attribute access through an instance yields a
        # bound wrapper, mimicking normal Python method binding (see class
        # docstring). Access through the class returns the callable itself.
        if instance is None:
            return self
        return BoundGsTaichiCallable(instance, self)
164
+
165
+
166
class BoundGsTaichiCallable:
    """A :class:`GsTaichiCallable` bound to a concrete instance.

    Produced by ``GsTaichiCallable.__get__`` so that method-style access
    (``obj.method``) yields an object whose ``__call__`` prepends ``obj``
    as the leading (``self``) argument, emulating normal method binding.
    Attribute reads and writes that are not this object's own three slots
    are forwarded to the underlying GsTaichiCallable.
    """

    def __init__(self, instance: Any, gstaichi_callable: "GsTaichiCallable"):
        self.wrapper = gstaichi_callable.wrapper
        self.instance = instance
        self.gstaichi_callable = gstaichi_callable

    def __call__(self, *args, **kwargs):
        # Inject the bound instance as the first argument.
        return self.wrapper(self.instance, *args, **kwargs)

    def __getattr__(self, k: str) -> Any:
        # Only reached when normal lookup fails: delegate to the wrapped callable.
        return getattr(self.gstaichi_callable, k)

    def __setattr__(self, k: str, v: Any) -> None:
        # Note: these have to match the name of any attributes on this class.
        if k not in ("wrapper", "instance", "gstaichi_callable"):
            setattr(self.gstaichi_callable, k, v)
        else:
            object.__setattr__(self, k, v)
185
+
186
+
187
def func(fn: Callable, is_real_function: bool = False) -> GsTaichiCallable:
    """Marks a function as callable in GsTaichi-scope.

    This decorator transforms a Python function into a GsTaichi one:
    GsTaichi will JIT compile it into native instructions.

    Args:
        fn (Callable): The Python function to be decorated
        is_real_function (bool): Whether the function is a real function

    Returns:
        Callable: The decorated function

    Example::

        >>> @ti.func
        >>> def foo(x):
        >>>     return x + 2
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     print(foo(40))  # 42
    """
    # Decorating a real function adds one extra stack frame, so the
    # class-membership probe must look one level further up the stack.
    probe_depth = 3 + is_real_function
    is_classfunc = _inside_class(level_of_class_stackframe=probe_depth)

    wrapped = GsTaichiCallable(fn, Func(fn, _classfunc=is_classfunc, is_real_function=is_real_function))
    wrapped._is_gstaichi_function = True
    wrapped._is_real_function = is_real_function
    return wrapped
217
+
218
+
219
def real_func(fn: Callable) -> GsTaichiCallable:
    """Shorthand decorator: equivalent to ``func(fn, is_real_function=True)``."""
    return func(fn, is_real_function=True)
221
+
222
+
223
def pyfunc(fn: Callable) -> GsTaichiCallable:
    """Marks a function as callable in both GsTaichi and Python scopes.

    When called inside the GsTaichi scope, GsTaichi will JIT compile it
    into native instructions. Otherwise it will be invoked directly as a
    Python function.

    See also :func:`~gstaichi.lang.kernel_impl.func`.

    Args:
        fn (Callable): The Python function to be decorated

    Returns:
        Callable: The decorated function
    """
    inside_class = _inside_class(level_of_class_stackframe=3)
    wrapped = GsTaichiCallable(fn, Func(fn, _classfunc=inside_class, _pyfunc=True))
    wrapped._is_gstaichi_function = True
    wrapped._is_real_function = False
    return wrapped
244
+
245
+
246
def _get_tree_and_ctx(
    self: "Func | Kernel",
    args: tuple[Any, ...],
    excluded_parameters=(),
    is_kernel: bool = True,
    arg_features=None,
    ast_builder: ASTBuilder | None = None,
    is_real_function: bool = False,
) -> tuple[ast.Module, ASTTransformerContext]:
    """Parse the Python source of ``self.func`` and build its transformer context.

    Args:
        self: The ``Func`` or ``Kernel`` whose wrapped Python function is parsed.
        args: Concrete call arguments; used below to inject template values and
            flattened dataclass fields into the compile-time global namespace.
        excluded_parameters: Parameter names the AST transformer should skip.
        is_kernel: True when compiling a kernel, False for a GsTaichi function.
        arg_features: Per-instantiation features from the template mapper, if any.
        ast_builder: C++ AST builder to attach to the context (may be None).
        is_real_function: True when compiling a "real" (separately compiled) function.

    Returns:
        A ``(tree, ctx)`` pair: the parsed ``ast.Module`` and the
        ``ASTTransformerContext`` used to drive ``transform_tree``.
    """
    file = getsourcefile(self.func)
    src, start_lineno = getsourcelines(self.func)
    # textwrap.fill with a huge width keeps each source line intact while
    # expanding tabs to 4 spaces; the dedent below then strips the common
    # leading indentation so ast.parse accepts method bodies.
    src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    func_body = tree.body[0]
    # Strip the decorator list (e.g. @ti.func/@ti.kernel) so decorators are
    # not re-applied when the transformed function is processed.
    func_body.decorator_list = []  # type: ignore , kick that can down the road...

    global_vars = _get_global_vars(self.func)

    if is_kernel or is_real_function:
        # inject template parameters into globals
        for i in self.template_slot_locations:
            template_var_name = self.arguments[i].name
            global_vars[template_var_name] = args[i]
        # Flatten dataclass-typed arguments: each field's value becomes a
        # global named "__ti_<param>_<field>" (see unpack_ndarray_struct /
        # expand_func_arguments, which use the same naming scheme).
        parameters = inspect.signature(self.func).parameters
        for arg_i, (param_name, param) in enumerate(parameters.items()):
            if dataclasses.is_dataclass(param.annotation):
                for member_field in dataclasses.fields(param.annotation):
                    child_value = getattr(args[arg_i], member_field.name)
                    flat_name = f"__ti_{param_name}_{member_field.name}"
                    global_vars[flat_name] = child_value

    return tree, ASTTransformerContext(
        excluded_parameters=excluded_parameters,
        is_kernel=is_kernel,
        func=self,
        arg_features=arg_features,
        global_vars=global_vars,
        argument_data=args,
        src=src,
        start_lineno=start_lineno,
        file=file,
        ast_builder=ast_builder,
        is_real_function=is_real_function,
    )
291
+
292
+
293
def expand_func_arguments(arguments: list[KernelArgument]) -> list[KernelArgument]:
    """Expand each dataclass-annotated argument into one argument per field.

    A parameter ``p`` whose annotation is a dataclass with fields ``a``/``b``
    is replaced by flat arguments named ``__ti_p_a`` / ``__ti_p_b`` carrying
    the corresponding field types; all other arguments pass through unchanged.
    """
    expanded: list[KernelArgument] = []
    for argument in arguments:
        annotation = argument.annotation
        if not dataclasses.is_dataclass(annotation):
            expanded.append(argument)
            continue
        for field in dataclasses.fields(annotation):
            expanded.append(
                KernelArgument(
                    _annotation=field.type,
                    _name=f"__ti_{argument.name}_{field.name}",
                )
            )
    return expanded
306
+
307
+
308
+ def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
309
+ if is_func:
310
+ self.arguments = expand_func_arguments(self.arguments)
311
+ fused_args = [argument.default for argument in self.arguments]
312
+ ret: list[Any] = [argument.default for argument in self.arguments]
313
+ len_args = len(args)
314
+
315
+ if len_args > len(fused_args):
316
+ arg_str = ", ".join([str(arg) for arg in args])
317
+ expected_str = ", ".join([f"{arg.name} : {arg.annotation}" for arg in self.arguments])
318
+ msg = f"Too many arguments. Expected ({expected_str}), got ({arg_str})."
319
+ raise GsTaichiSyntaxError(msg)
320
+
321
+ for i, arg in enumerate(args):
322
+ fused_args[i] = arg
323
+
324
+ for key, value in kwargs.items():
325
+ found = False
326
+ for i, arg in enumerate(self.arguments):
327
+ if key == arg.name:
328
+ if i < len_args:
329
+ raise GsTaichiSyntaxError(f"Multiple values for argument '{key}'.")
330
+ fused_args[i] = value
331
+ found = True
332
+ break
333
+ if not found:
334
+ raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")
335
+
336
+ for i, arg in enumerate(fused_args):
337
+ if arg is inspect.Parameter.empty:
338
+ if self.arguments[i].annotation is inspect._empty:
339
+ raise GsTaichiSyntaxError(f"Parameter `{self.arguments[i].name}` missing.")
340
+ else:
341
+ raise GsTaichiSyntaxError(
342
+ f"Parameter `{self.arguments[i].name} : {self.arguments[i].annotation}` missing."
343
+ )
344
+
345
+ return tuple(fused_args)
346
+
347
+
348
def unpack_ndarray_struct(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
    """Rewrite simple ``name.attr`` accesses into flattened ``__ti_name_attr`` names.

    Only a plain ``Name.attr`` pattern whose flattened spelling appears in
    *struct_locals* is rewritten; deeper attribute chains (``a.b.c``) and
    other value expressions are deliberately left untouched.
    """

    class _FlattenAttrAccess(ast.NodeTransformer):
        def visit_Attribute(self, node: ast.Attribute):
            value = node.value
            if isinstance(value, ast.Name):
                candidate = f"__ti_{value.id}_{node.attr}"
                if candidate in struct_locals:
                    replacement = ast.Name(id=candidate, ctx=node.ctx)
                    return ast.copy_location(replacement, node)
            # Intentionally no generic_visit: children of attribute nodes
            # are not descended into, matching the flattening contract.
            return node

    rewritten = _FlattenAttrAccess().visit(tree)
    ast.fix_missing_locations(rewritten)
    return rewritten
366
+
367
+
368
def extract_struct_locals_from_context(ctx: ASTTransformerContext):
    """Compute the flattened names produced by dataclass parameter expansion.

    - Uses ctx.func.func to get the function signature.
    - Every parameter annotated with a dataclass contributes one name per
      field. E.g. ``my_struct: MyStruct``, where MyStruct contains a, b, c,
      becomes::

          {"__ti_my_struct_a", "__ti_my_struct_b", "__ti_my_struct_c"}
    """
    assert ctx.func is not None
    parameters = inspect.signature(ctx.func.func).parameters
    return {
        f"__ti_{param_name}_{field.name}"
        for param_name, parameter in parameters.items()
        if dataclasses.is_dataclass(parameter.annotation)
        for field in dataclasses.fields(parameter.annotation)
    }
386
+
387
+
388
class Func:
    """A function callable from GsTaichi scope (the target of ``@ti.func``).

    Wraps the raw Python function together with metadata extracted from its
    signature. Non-real functions are AST-transformed and inlined into the
    calling kernel; real functions are compiled separately per template
    instantiation (cached in ``self.compiled`` / ``self.gstaichi_functions``)
    and invoked via :meth:`func_call_rvalue`.
    """

    # Monotonic counter used to assign each Func a unique id.
    function_counter = 0

    def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False) -> None:
        self.func = _func
        self.func_id = Func.function_counter
        Func.function_counter += 1
        # instance_id -> compiled function body closure (real functions only).
        self.compiled = {}
        self.classfunc = _classfunc
        self.pyfunc = _pyfunc
        self.is_real_function = is_real_function
        self.arguments: list[KernelArgument] = []
        # Arguments as declared, before dataclass expansion by _process_args.
        self.orig_arguments: list[KernelArgument] = []
        self.return_type: tuple[Type, ...] | None = None
        self.extract_arguments()
        # Indices of parameters annotated as templates.
        self.template_slot_locations: list[int] = []
        for i, arg in enumerate(self.arguments):
            if arg.annotation == template or isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = GsTaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
        self.gstaichi_functions = {}  # The |Function| class in C++
        self.has_print = False

    def __call__(self, *args, **kwargs) -> Any:
        """Invoke the function.

        Outside GsTaichi scope this runs the plain Python function (only
        allowed for ``@ti.pyfunc``). Inside a kernel, real functions are
        compiled per-instantiation and called; ordinary functions are
        transformed and inlined into the current kernel's AST.
        """
        args = _process_args(self, is_func=True, args=args, kwargs=kwargs)

        if not impl.inside_kernel():
            if not self.pyfunc:
                raise GsTaichiSyntaxError("GsTaichi functions cannot be called from Python-scope.")
            return self.func(*args)

        current_kernel = impl.get_runtime().current_kernel
        if self.is_real_function:
            if current_kernel.autodiff_mode != AutodiffMode.NONE:
                raise GsTaichiSyntaxError("Real function in gradient kernels unsupported.")
            # One compiled variant per template instantiation.
            instance_id, arg_features = self.mapper.lookup(args)
            key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
            if key.instance_id not in self.compiled:
                self.do_compile(key=key, args=args, arg_features=arg_features)
            return self.func_call_rvalue(key=key, args=args)
        # Inline path: transform this function's AST inside the current kernel.
        tree, ctx = _get_tree_and_ctx(
            self,
            is_kernel=False,
            args=args,
            ast_builder=current_kernel.ast_builder(),
            is_real_function=self.is_real_function,
        )

        struct_locals = extract_struct_locals_from_context(ctx)
        tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
        ret = transform_tree(tree, ctx)
        if not self.is_real_function:
            if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
                raise GsTaichiSyntaxError("Function has a return type but does not have a return statement")
        return ret

    def func_call_rvalue(self, key: FunctionKey, args: tuple[Any, ...]) -> Any:
        """Emit a call to the compiled real function and unpack its return value(s).

        Builds the non-template argument expression group, inserts the func
        call into the currently compiling callable's AST, and wraps each
        declared return slot as an ``Expr`` / struct / matrix value.
        """
        # Skip the template args, e.g., |self|
        assert self.is_real_function
        non_template_args = []
        dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        for i, kernel_arg in enumerate(self.arguments):
            anno = kernel_arg.annotation
            if not isinstance(anno, template):
                if id(anno) in primitive_types.type_ids:
                    # Primitive scalar: cast to the annotated dtype.
                    non_template_args.append(ops.cast(args[i], anno))
                elif isinstance(anno, primitive_types.RefType):
                    non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
                elif isinstance(anno, ndarray_type.NdarrayType):
                    if not isinstance(args[i], AnyArray):
                        raise GsTaichiTypeError(
                            f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
                        )
                    non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
                else:
                    non_template_args.append(args[i])
        non_template_args = impl.make_expr_group(non_template_args)
        compiling_callable = impl.get_runtime().compiling_callable
        assert compiling_callable is not None
        func_call = compiling_callable.ast_builder().insert_func_call(
            self.gstaichi_functions[key.instance_id], non_template_args, dbg_info
        )
        if self.return_type is None:
            return None
        func_call = Expr(func_call)
        ret = []

        # Unpack each declared return slot by index.
        for i, return_type in enumerate(self.return_type):
            if id(return_type) in primitive_types.type_ids:
                ret.append(
                    Expr(
                        _ti_core.make_get_element_expr(
                            func_call.ptr, (i,), _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
                        )
                    )
                )
            elif isinstance(return_type, (StructType, MatrixType)):
                ret.append(return_type.from_gstaichi_object(func_call, (i,)))
            else:
                raise GsTaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
        # Single return value is unwrapped; multiple become a tuple.
        if len(ret) == 1:
            return ret[0]
        return tuple(ret)

    def do_compile(self, key: FunctionKey, args: tuple[Any, ...], arg_features: tuple[Any, ...]) -> None:
        """Compile one template instantiation of this real function.

        Registers a deferred body (``func_body``) with the C++ Function
        object; the body swaps the runtime's compiling_callable while the
        tree is transformed, then restores it.
        """
        tree, ctx = _get_tree_and_ctx(
            self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
        )
        fn = impl.get_runtime().prog.create_function(key)

        def func_body():
            old_callable = impl.get_runtime().compiling_callable
            impl.get_runtime().compiling_callable = fn
            ctx.ast_builder = fn.ast_builder()
            transform_tree(tree, ctx)
            impl.get_runtime().compiling_callable = old_callable

        self.gstaichi_functions[key.instance_id] = fn
        self.compiled[key.instance_id] = func_body
        self.gstaichi_functions[key.instance_id].set_function_body(func_body)

    def extract_arguments(self) -> None:
        """Populate ``self.arguments`` / ``self.orig_arguments`` and
        ``self.return_type`` from the Python signature.

        Validates that parameter kinds and annotations are supported,
        raising ``GsTaichiSyntaxError`` otherwise.
        """
        sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect.Signature.empty, None):
            self.return_type = sig.return_annotation
            # A `-> tuple[...]` annotation is unpacked into its element types.
            if (
                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))  # type: ignore
                and self.return_type.__origin__ is tuple  # type: ignore
            ):
                self.return_type = self.return_type.__args__  # type: ignore
            if self.return_type is None:
                return
            # Normalize to a tuple of return types.
            if not isinstance(self.return_type, (list, tuple)):
                self.return_type = (self.return_type,)
            for i, return_type in enumerate(self.return_type):
                if return_type is Ellipsis:
                    raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
        params = sig.parameters
        arg_names = params.keys()
        for i, arg_name in enumerate(arg_names):
            param = params[arg_name]
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise GsTaichiSyntaxError(
                    "GsTaichi functions do not support variable keyword parameters (i.e., **kwargs)"
                )
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise GsTaichiSyntaxError(
                    "GsTaichi functions do not support variable positional parameters (i.e., *args)"
                )
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                raise GsTaichiSyntaxError("GsTaichi functions do not support keyword parameters")
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                raise GsTaichiSyntaxError('GsTaichi functions only support "positional or keyword" parameters')
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                if i == 0 and self.classfunc:
                    # Unannotated `self` of a class function is treated as a template.
                    annotation = template()
                # TODO: pyfunc also need type annotation check when real function is enabled,
                # but that has to happen at runtime when we know which scope it's called from.
                elif not self.pyfunc and self.is_real_function:
                    raise GsTaichiSyntaxError(
                        f"GsTaichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
                    )
            else:
                # Whitelist of supported annotation kinds; anything else raises.
                if isinstance(annotation, ndarray_type.NdarrayType):
                    pass
                elif isinstance(annotation, MatrixType):
                    pass
                elif isinstance(annotation, StructType):
                    pass
                elif id(annotation) in primitive_types.type_ids:
                    pass
                elif type(annotation) == gstaichi.types.annotations.Template:
                    pass
                # NOTE(review): the next branch partially duplicates the
                # Template check above — presumably defensive; confirm.
                elif isinstance(annotation, template) or annotation == gstaichi.types.annotations.Template:
                    pass
                elif isinstance(annotation, primitive_types.RefType):
                    pass
                elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
                    pass
                else:
                    raise GsTaichiSyntaxError(
                        f"Invalid type annotation (argument {i}) of GsTaichi function: {annotation}"
                    )
            self.arguments.append(KernelArgument(annotation, param.name, param.default))
            self.orig_arguments.append(KernelArgument(annotation, param.name, param.default))
574
+
575
+
576
+ def _get_global_vars(_func: Callable) -> dict[str, Any]:
577
+ # Discussions: https://github.com/taichi-dev/gstaichi/issues/282
578
+ global_vars = _func.__globals__.copy()
579
+
580
+ freevar_names = _func.__code__.co_freevars
581
+ closure = _func.__closure__
582
+ if closure:
583
+ freevar_values = list(map(lambda x: x.cell_contents, closure))
584
+ for name, value in zip(freevar_names, freevar_values):
585
+ global_vars[name] = value
586
+
587
+ return global_vars
588
+
589
+
590
+ class Kernel:
591
+ counter = 0
592
+
593
    def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _classkernel=False) -> None:
        """Wrap *_func* as a GsTaichi kernel.

        Args:
            _func: The Python function holding the kernel body.
            autodiff_mode: The autodiff flavor this kernel compiles as;
                must be NONE, VALIDATION, FORWARD, or REVERSE.
            _classkernel: True when the kernel is a method of a
                @data_oriented class (first parameter is ``self``).
        """
        self.func = _func
        # Monotonic id distinguishing kernels regardless of name.
        self.kernel_counter = Kernel.counter
        Kernel.counter += 1
        assert autodiff_mode in (
            AutodiffMode.NONE,
            AutodiffMode.VALIDATION,
            AutodiffMode.FORWARD,
            AutodiffMode.REVERSE,
        )
        self.autodiff_mode = autodiff_mode
        self.grad: Kernel | None = None
        self.arguments: list[KernelArgument] = []
        self.return_type = None
        self.classkernel = _classkernel
        # Populates self.arguments / self.return_type from the signature.
        self.extract_arguments()
        # Indices of parameters annotated as templates.
        self.template_slot_locations: list[int] = []
        for i, arg in enumerate(self.arguments):
            if arg.annotation == template or isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = GsTaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
        impl.get_runtime().kernels.append(self)
        # reset() initializes self.runtime and the compiled-kernel cache.
        self.reset()
        self.kernel_cpp = None
        # Cache of compiled C++ kernels, keyed by (func, instance id, autodiff mode)
        # per the CompiledKernelKeyType alias.
        self.compiled_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
        self.has_print = False
619
+
620
+ def ast_builder(self) -> ASTBuilder:
621
+ assert self.kernel_cpp is not None
622
+ return self.kernel_cpp.ast_builder()
623
+
624
+ def reset(self) -> None:
625
+ self.runtime = impl.get_runtime()
626
+ self.compiled_kernels = {}
627
+
628
+ def extract_arguments(self) -> None:
629
+ sig = inspect.signature(self.func)
630
+ if sig.return_annotation not in (inspect._empty, None):
631
+ self.return_type = sig.return_annotation
632
+ if (
633
+ isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias)) # type: ignore
634
+ and self.return_type.__origin__ is tuple
635
+ ):
636
+ self.return_type = self.return_type.__args__
637
+ if not isinstance(self.return_type, (list, tuple)):
638
+ self.return_type = (self.return_type,)
639
+ for return_type in self.return_type:
640
+ if return_type is Ellipsis:
641
+ raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
642
+ params = sig.parameters
643
+ arg_names = params.keys()
644
+ for i, arg_name in enumerate(arg_names):
645
+ param = params[arg_name]
646
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
647
+ raise GsTaichiSyntaxError(
648
+ "GsTaichi kernels do not support variable keyword parameters (i.e., **kwargs)"
649
+ )
650
+ if param.kind == inspect.Parameter.VAR_POSITIONAL:
651
+ raise GsTaichiSyntaxError(
652
+ "GsTaichi kernels do not support variable positional parameters (i.e., *args)"
653
+ )
654
+ if param.default is not inspect.Parameter.empty:
655
+ raise GsTaichiSyntaxError("GsTaichi kernels do not support default values for arguments")
656
+ if param.kind == inspect.Parameter.KEYWORD_ONLY:
657
+ raise GsTaichiSyntaxError("GsTaichi kernels do not support keyword parameters")
658
+ if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
659
+ raise GsTaichiSyntaxError('GsTaichi kernels only support "positional or keyword" parameters')
660
+ annotation = param.annotation
661
+ if param.annotation is inspect.Parameter.empty:
662
+ if i == 0 and self.classkernel: # The |self| parameter
663
+ annotation = template()
664
+ else:
665
+ raise GsTaichiSyntaxError("GsTaichi kernels parameters must be type annotated")
666
+ else:
667
+ if isinstance(
668
+ annotation,
669
+ (
670
+ template,
671
+ ndarray_type.NdarrayType,
672
+ texture_type.TextureType,
673
+ texture_type.RWTextureType,
674
+ ),
675
+ ):
676
+ pass
677
+ elif id(annotation) in primitive_types.type_ids:
678
+ pass
679
+ elif isinstance(annotation, sparse_matrix_builder):
680
+ pass
681
+ elif isinstance(annotation, MatrixType):
682
+ pass
683
+ elif isinstance(annotation, StructType):
684
+ pass
685
+ elif isinstance(annotation, ArgPackType):
686
+ pass
687
+ elif annotation == template:
688
+ pass
689
+ elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
690
+ pass
691
+ else:
692
+ raise GsTaichiSyntaxError(
693
+ f"Invalid type annotation (argument {i}) of GsTaichi kernel: {annotation}"
694
+ )
695
+ self.arguments.append(KernelArgument(annotation, param.name, param.default))
696
+
697
+ def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features):
698
+ if key is None:
699
+ key = (self.func, 0, self.autodiff_mode)
700
+ self.runtime.materialize()
701
+
702
+ if key in self.compiled_kernels:
703
+ return
704
+
705
+ kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}"
706
+ _logging.trace(f"Compiling kernel {kernel_name} in {self.autodiff_mode}...")
707
+
708
+ tree, ctx = _get_tree_and_ctx(
709
+ self,
710
+ args=args,
711
+ excluded_parameters=self.template_slot_locations,
712
+ arg_features=arg_features,
713
+ )
714
+
715
+ if self.autodiff_mode != AutodiffMode.NONE:
716
+ KernelSimplicityASTChecker(self.func).visit(tree)
717
+
718
+ # Do not change the name of 'gstaichi_ast_generator'
719
+ # The warning system needs this identifier to remove unnecessary messages
720
+ def gstaichi_ast_generator(kernel_cxx: Kernel): # not sure if this type is correct, seems doubtful
721
+ nonlocal tree
722
+ if self.runtime.inside_kernel:
723
+ raise GsTaichiSyntaxError(
724
+ "Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
725
+ "Please check if you have direct/indirect invocation of kernels within kernels. "
726
+ "Note that some methods provided by the GsTaichi standard library may invoke kernels, "
727
+ "and please move their invocations to Python-scope."
728
+ )
729
+ self.kernel_cpp = kernel_cxx
730
+ self.runtime.inside_kernel = True
731
+ self.runtime._current_kernel = self
732
+ assert self.runtime.compiling_callable is None
733
+ self.runtime.compiling_callable = kernel_cxx
734
+ try:
735
+ ctx.ast_builder = kernel_cxx.ast_builder()
736
+
737
+ def ast_to_dict(node: ast.AST | list | primitive_types._python_primitive_types):
738
+ if isinstance(node, ast.AST):
739
+ fields = {k: ast_to_dict(v) for k, v in ast.iter_fields(node)}
740
+ return {
741
+ "type": node.__class__.__name__,
742
+ "fields": fields,
743
+ "lineno": getattr(node, "lineno", None),
744
+ "col_offset": getattr(node, "col_offset", None),
745
+ }
746
+ if isinstance(node, list):
747
+ return [ast_to_dict(x) for x in node]
748
+ return node # Basic types (str, int, None, etc.)
749
+
750
+ if os.environ.get("TI_DUMP_AST", "") == "1":
751
+ target_dir = pathlib.Path("/tmp/ast")
752
+ target_dir.mkdir(parents=True, exist_ok=True)
753
+
754
+ start = time.time()
755
+ ast_str = ast.dump(tree, indent=2)
756
+ output_file = target_dir / f"{kernel_name}_ast.txt"
757
+ output_file.write_text(ast_str)
758
+ elapsed_txt = time.time() - start
759
+
760
+ start = time.time()
761
+ json_str = json.dumps(ast_to_dict(tree), indent=2)
762
+ output_file = target_dir / f"{kernel_name}_ast.json"
763
+ output_file.write_text(json_str)
764
+ elapsed_json = time.time() - start
765
+
766
+ output_file = target_dir / f"{kernel_name}_gen_time.json"
767
+ output_file.write_text(
768
+ json.dumps({"elapsed_txt": elapsed_txt, "elapsed_json": elapsed_json}, indent=2)
769
+ )
770
+ struct_locals = extract_struct_locals_from_context(ctx)
771
+ tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
772
+ transform_tree(tree, ctx)
773
+ if not ctx.is_real_function:
774
+ if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
775
+ raise GsTaichiSyntaxError("Kernel has a return type but does not have a return statement")
776
+ finally:
777
+ self.runtime.inside_kernel = False
778
+ self.runtime._current_kernel = None
779
+ self.runtime.compiling_callable = None
780
+
781
+ gstaichi_kernel = impl.get_runtime().prog.create_kernel(gstaichi_ast_generator, kernel_name, self.autodiff_mode)
782
+ assert key not in self.compiled_kernels
783
+ self.compiled_kernels[key] = gstaichi_kernel
784
+
785
+ def launch_kernel(self, t_kernel: KernelCxx, *args) -> Any:
786
+ assert len(args) == len(self.arguments), f"{len(self.arguments)} arguments needed but {len(args)} provided"
787
+
788
+ tmps = []
789
+ callbacks = []
790
+
791
+ actual_argument_slot = 0
792
+ launch_ctx = t_kernel.make_launch_context()
793
+ max_arg_num = 512
794
+ exceed_max_arg_num = False
795
+
796
+ def set_arg_ndarray(indices: tuple[int, ...], v: gstaichi.lang._ndarray.Ndarray) -> None:
797
+ v_primal = v.arr
798
+ v_grad = v.grad.arr if v.grad else None
799
+ if v_grad is None:
800
+ launch_ctx.set_arg_ndarray(indices, v_primal) # type: ignore , solvable probably, just not today
801
+ else:
802
+ launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad) # type: ignore
803
+
804
+ def set_arg_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
805
+ launch_ctx.set_arg_texture(indices, v.tex)
806
+
807
+ def set_arg_rw_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
808
+ launch_ctx.set_arg_rw_texture(indices, v.tex)
809
+
810
+ def set_arg_ext_array(indices: tuple[int, ...], v: Any, needed: ndarray_type.NdarrayType) -> None:
811
+ # v is things like torch Tensor and numpy array
812
+ # Not adding type for this, since adds additional dependencies
813
+ #
814
+ # Element shapes are already specialized in GsTaichi codegen.
815
+ # The shape information for element dims are no longer needed.
816
+ # Therefore we strip the element shapes from the shape vector,
817
+ # so that it only holds "real" array shapes.
818
+ is_soa = needed.layout == Layout.SOA
819
+ array_shape = v.shape
820
+ if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
821
+ warnings.warn("Ndarray index might be out of int32 boundary but int64 indexing is not supported yet.")
822
+ if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
823
+ element_dim = 0
824
+ else:
825
+ element_dim = needed.dtype.ndim
826
+ array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
827
+ if isinstance(v, np.ndarray):
828
+ # numpy
829
+ if v.flags.c_contiguous:
830
+ launch_ctx.set_arg_external_array_with_shape(indices, int(v.ctypes.data), v.nbytes, array_shape, 0)
831
+ elif v.flags.f_contiguous:
832
+ # TODO: A better way that avoids copying is saving strides info.
833
+ tmp = np.ascontiguousarray(v)
834
+ # Purpose: DO NOT GC |tmp|!
835
+ tmps.append(tmp)
836
+
837
+ def callback(original, updated):
838
+ np.copyto(original, np.asfortranarray(updated))
839
+
840
+ callbacks.append(functools.partial(callback, v, tmp))
841
+ launch_ctx.set_arg_external_array_with_shape(
842
+ indices, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
843
+ )
844
+ else:
845
+ raise ValueError(
846
+ "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
847
+ "before passing it into gstaichi kernel."
848
+ )
849
+ elif has_pytorch():
850
+ import torch # pylint: disable=C0415
851
+
852
+ if isinstance(v, torch.Tensor):
853
+ if not v.is_contiguous():
854
+ raise ValueError(
855
+ "Non contiguous tensors are not supported, please call tensor.contiguous() before "
856
+ "passing it into gstaichi kernel."
857
+ )
858
+ gstaichi_arch = self.runtime.prog.config().arch
859
+
860
+ def get_call_back(u, v):
861
+ def call_back():
862
+ u.copy_(v)
863
+
864
+ return call_back
865
+
866
+ # FIXME: only allocate when launching grad kernel
867
+ if v.requires_grad and v.grad is None:
868
+ v.grad = torch.zeros_like(v)
869
+
870
+ if v.requires_grad:
871
+ if not isinstance(v.grad, torch.Tensor):
872
+ raise ValueError(
873
+ f"Expecting torch.Tensor for gradient tensor, but getting {v.grad.__class__.__name__} instead"
874
+ )
875
+ if not v.grad.is_contiguous():
876
+ raise ValueError(
877
+ "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into gstaichi kernel."
878
+ )
879
+
880
+ tmp = v
881
+ if (str(v.device) != "cpu") and not (
882
+ str(v.device).startswith("cuda") and gstaichi_arch == _ti_core.Arch.cuda
883
+ ):
884
+ # Getting a torch CUDA tensor on GsTaichi non-cuda arch:
885
+ # We just replace it with a CPU tensor and by the end of kernel execution we'll use the
886
+ # callback to copy the values back to the original CUDA tensor.
887
+ host_v = v.to(device="cpu", copy=True)
888
+ tmp = host_v
889
+ callbacks.append(get_call_back(v, host_v))
890
+
891
+ launch_ctx.set_arg_external_array_with_shape(
892
+ indices,
893
+ int(tmp.data_ptr()),
894
+ tmp.element_size() * tmp.nelement(),
895
+ array_shape,
896
+ int(v.grad.data_ptr()) if v.grad is not None else 0,
897
+ )
898
+ else:
899
+ raise GsTaichiRuntimeTypeError(
900
+ f"Argument {needed} cannot be converted into required type {type(v)}"
901
+ )
902
+ elif has_paddle():
903
+ # Do we want to continue to support paddle? :thinking_face:
904
+ # #maybeprunable
905
+ import paddle # pylint: disable=C0415 # type: ignore
906
+
907
+ if isinstance(v, paddle.Tensor):
908
+ # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
909
+ def get_call_back(u, v):
910
+ def call_back():
911
+ u.copy_(v, False)
912
+
913
+ return call_back
914
+
915
+ tmp = v.value().get_tensor()
916
+ gstaichi_arch = self.runtime.prog.config().arch
917
+ if v.place.is_gpu_place():
918
+ if gstaichi_arch != _ti_core.Arch.cuda:
919
+ # Paddle cuda tensor on GsTaichi non-cuda arch
920
+ host_v = v.cpu()
921
+ tmp = host_v.value().get_tensor()
922
+ callbacks.append(get_call_back(v, host_v))
923
+ elif v.place.is_cpu_place():
924
+ if gstaichi_arch == _ti_core.Arch.cuda:
925
+ # Paddle cpu tensor on GsTaichi cuda arch
926
+ gpu_v = v.cuda()
927
+ tmp = gpu_v.value().get_tensor()
928
+ callbacks.append(get_call_back(v, gpu_v))
929
+ else:
930
+ # Paddle do support many other backends like XPU, NPU, MLU, IPU
931
+ raise GsTaichiRuntimeTypeError(f"GsTaichi do not support backend {v.place} that Paddle support")
932
+ launch_ctx.set_arg_external_array_with_shape(
933
+ indices, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
934
+ )
935
+ else:
936
+ raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
937
+ else:
938
+ raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
939
+
940
+ def set_arg_matrix(indices: tuple[int, ...], v, needed) -> None:
941
+ def cast_float(x: float | np.floating | np.integer | int) -> float:
942
+ if not isinstance(x, (int, float, np.integer, np.floating)):
943
+ raise GsTaichiRuntimeTypeError(
944
+ f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
945
+ )
946
+ return float(x)
947
+
948
+ def cast_int(x: int | np.integer) -> int:
949
+ if not isinstance(x, (int, np.integer)):
950
+ raise GsTaichiRuntimeTypeError(
951
+ f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
952
+ )
953
+ return int(x)
954
+
955
+ cast_func = None
956
+ if needed.dtype in primitive_types.real_types:
957
+ cast_func = cast_float
958
+ elif needed.dtype in primitive_types.integer_types:
959
+ cast_func = cast_int
960
+ else:
961
+ raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")
962
+
963
+ if needed.ndim == 2:
964
+ v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
965
+ else:
966
+ v = [cast_func(v[i]) for i in range(needed.n)]
967
+ v = needed(*v)
968
+ needed.set_kernel_struct_args(v, launch_ctx, indices)
969
+
970
+ def set_arg_sparse_matrix_builder(indices: tuple[int, ...], v) -> None:
971
+ # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
972
+ launch_ctx.set_arg_uint(indices, v._get_ndarray_addr())
973
+
974
+ set_later_list = []
975
+
976
+ def recursive_set_args(needed_arg_type: Type, provided_arg_type: Type, v: Any, indices: tuple[int, ...]) -> int:
977
+ """
978
+ Returns the number of kernel args set
979
+ e.g. templates don't set kernel args, so returns 0
980
+ a single ndarray is 1 kernel arg, so returns 1
981
+ a struct of 3 ndarrays would set 3 kernel args, so return 3
982
+ """
983
+ in_argpack = len(indices) > 1
984
+ nonlocal actual_argument_slot, exceed_max_arg_num, set_later_list
985
+ if actual_argument_slot >= max_arg_num:
986
+ exceed_max_arg_num = True
987
+ return 0
988
+ actual_argument_slot += 1
989
+ if isinstance(needed_arg_type, ArgPackType):
990
+ if not isinstance(v, ArgPack):
991
+ raise GsTaichiRuntimeTypeError.get(indices, str(needed_arg_type), str(provided_arg_type))
992
+ idx_new = 0
993
+ for j, (name, anno) in enumerate(needed_arg_type.members.items()):
994
+ idx_new += recursive_set_args(anno, type(v[name]), v[name], indices + (idx_new,))
995
+ launch_ctx.set_arg_argpack(indices, v._ArgPack__argpack) # type: ignore
996
+ return 1
997
+ # Note: do not use sth like "needed == f32". That would be slow.
998
+ if id(needed_arg_type) in primitive_types.real_type_ids:
999
+ if not isinstance(v, (float, int, np.floating, np.integer)):
1000
+ raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
1001
+ if in_argpack:
1002
+ return 1
1003
+ launch_ctx.set_arg_float(indices, float(v))
1004
+ return 1
1005
+ if id(needed_arg_type) in primitive_types.integer_type_ids:
1006
+ if not isinstance(v, (int, np.integer)):
1007
+ raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
1008
+ if in_argpack:
1009
+ return 1
1010
+ if is_signed(cook_dtype(needed_arg_type)):
1011
+ launch_ctx.set_arg_int(indices, int(v))
1012
+ else:
1013
+ launch_ctx.set_arg_uint(indices, int(v))
1014
+ return 1
1015
+ if isinstance(needed_arg_type, sparse_matrix_builder):
1016
+ if in_argpack:
1017
+ set_later_list.append((set_arg_sparse_matrix_builder, (v,)))
1018
+ return 0
1019
+ set_arg_sparse_matrix_builder(indices, v)
1020
+ return 1
1021
+ if dataclasses.is_dataclass(needed_arg_type):
1022
+ assert provided_arg_type == needed_arg_type
1023
+ idx = 0
1024
+ for j, field in enumerate(dataclasses.fields(needed_arg_type)):
1025
+ assert not isinstance(field.type, str)
1026
+ field_value = getattr(v, field.name)
1027
+ idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
1028
+ return idx
1029
+ if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, gstaichi.lang._ndarray.Ndarray):
1030
+ if in_argpack:
1031
+ set_later_list.append((set_arg_ndarray, (v,)))
1032
+ return 0
1033
+ set_arg_ndarray(indices, v)
1034
+ return 1
1035
+ if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, gstaichi.lang._texture.Texture):
1036
+ if in_argpack:
1037
+ set_later_list.append((set_arg_texture, (v,)))
1038
+ return 0
1039
+ set_arg_texture(indices, v)
1040
+ return 1
1041
+ if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
1042
+ v, gstaichi.lang._texture.Texture
1043
+ ):
1044
+ if in_argpack:
1045
+ set_later_list.append((set_arg_rw_texture, (v,)))
1046
+ return 0
1047
+ set_arg_rw_texture(indices, v)
1048
+ return 1
1049
+ if isinstance(needed_arg_type, ndarray_type.NdarrayType):
1050
+ if in_argpack:
1051
+ set_later_list.append((set_arg_ext_array, (v, needed_arg_type)))
1052
+ return 0
1053
+ set_arg_ext_array(indices, v, needed_arg_type)
1054
+ return 1
1055
+ if isinstance(needed_arg_type, MatrixType):
1056
+ if in_argpack:
1057
+ return 1
1058
+ set_arg_matrix(indices, v, needed_arg_type)
1059
+ return 1
1060
+ if isinstance(needed_arg_type, StructType):
1061
+ if in_argpack:
1062
+ return 1
1063
+ # Unclear how to make the following pass typing checks
1064
+ # StructType implements __instancecheck__, which should be a classmethod, but
1065
+ # is currently an instance method
1066
+ # TODO: look into this more deeply at some point
1067
+ if not isinstance(v, needed_arg_type): # type: ignore
1068
+ raise GsTaichiRuntimeTypeError(
1069
+ f"Argument {provided_arg_type} cannot be converted into required type {needed_arg_type}"
1070
+ )
1071
+ needed_arg_type.set_kernel_struct_args(v, launch_ctx, indices)
1072
+ return 1
1073
+ if needed_arg_type == template or isinstance(needed_arg_type, template):
1074
+ return 0
1075
+ raise ValueError(f"Argument type mismatch. Expecting {needed_arg_type}, got {type(v)}.")
1076
+
1077
+ template_num = 0
1078
+ i_out = 0
1079
+ for i_in, val in enumerate(args):
1080
+ needed_ = self.arguments[i_in].annotation
1081
+ if needed_ == template or isinstance(needed_, template):
1082
+ template_num += 1
1083
+ i_out += 1
1084
+ continue
1085
+ i_out += recursive_set_args(needed_, type(val), val, (i_out - template_num,))
1086
+
1087
+ for i, (set_arg_func, params) in enumerate(set_later_list):
1088
+ set_arg_func((len(args) - template_num + i,), *params)
1089
+
1090
+ if exceed_max_arg_num:
1091
+ raise GsTaichiRuntimeError(
1092
+ f"The number of elements in kernel arguments is too big! Do not exceed {max_arg_num} on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
1093
+ )
1094
+
1095
+ try:
1096
+ prog = impl.get_runtime().prog
1097
+ # Compile kernel (& Online Cache & Offline Cache)
1098
+ compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
1099
+ # Launch kernel
1100
+ prog.launch_kernel(compiled_kernel_data, launch_ctx)
1101
+ except Exception as e:
1102
+ e = handle_exception_from_cpp(e)
1103
+ if impl.get_runtime().print_full_traceback:
1104
+ raise e
1105
+ raise e from None
1106
+
1107
+ ret = None
1108
+ ret_dt = self.return_type
1109
+ has_ret = ret_dt is not None
1110
+
1111
+ if has_ret or self.has_print:
1112
+ runtime_ops.sync()
1113
+
1114
+ if has_ret:
1115
+ ret = []
1116
+ for i, ret_type in enumerate(ret_dt):
1117
+ ret.append(self.construct_kernel_ret(launch_ctx, ret_type, (i,)))
1118
+ if len(ret_dt) == 1:
1119
+ ret = ret[0]
1120
+ if callbacks:
1121
+ for c in callbacks:
1122
+ c()
1123
+
1124
+ return ret
1125
+
1126
+ def construct_kernel_ret(self, launch_ctx: KernelLaunchContext, ret_type: Any, index: tuple[int, ...] = ()):
1127
+ if isinstance(ret_type, CompoundType):
1128
+ return ret_type.from_kernel_struct_ret(launch_ctx, index)
1129
+ if ret_type in primitive_types.integer_types:
1130
+ if is_signed(cook_dtype(ret_type)):
1131
+ return launch_ctx.get_struct_ret_int(index)
1132
+ return launch_ctx.get_struct_ret_uint(index)
1133
+ if ret_type in primitive_types.real_types:
1134
+ return launch_ctx.get_struct_ret_float(index)
1135
+ raise GsTaichiRuntimeTypeError(f"Invalid return type on index={index}")
1136
+
1137
+ def ensure_compiled(self, *args: tuple[Any, ...]) -> tuple[Callable, int, AutodiffMode]:
1138
+ instance_id, arg_features = self.mapper.lookup(args)
1139
+ key = (self.func, instance_id, self.autodiff_mode)
1140
+ self.materialize(key=key, args=args, arg_features=arg_features)
1141
+ return key
1142
+
1143
+ # For small kernels (< 3us), the performance can be pretty sensitive to overhead in __call__
1144
+ # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU)
1145
+ @_shell_pop_print
1146
+ def __call__(self, *args, **kwargs) -> Any:
1147
+ args = _process_args(self, is_func=False, args=args, kwargs=kwargs)
1148
+
1149
+ # Transform the primal kernel to forward mode grad kernel
1150
+ # then recover to primal when exiting the forward mode manager
1151
+ if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
1152
+ # TODO: if we would like to compute 2nd-order derivatives by forward-on-reverse in a nested context manager fashion,
1153
+ # i.e., a `Tape` nested in the `FwdMode`, we can transform the kernels with `mode_original == AutodiffMode.REVERSE` only,
1154
+ # to avoid duplicate computation for 1st-order derivatives
1155
+ self.runtime.fwd_mode_manager.insert(self)
1156
+
1157
+ # Both the class kernels and the plain-function kernels are unified now.
1158
+ # In both cases, |self.grad| is another Kernel instance that computes the
1159
+ # gradient. For class kernels, args[0] is always the kernel owner.
1160
+
1161
+ # No need to capture grad kernels because they are already bound with their primal kernels
1162
+ if (
1163
+ self.autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION)
1164
+ and self.runtime.target_tape
1165
+ and not self.runtime.grad_replaced
1166
+ ):
1167
+ self.runtime.target_tape.insert(self, args)
1168
+
1169
+ if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg().opt_level == 0:
1170
+ _logging.warn("""opt_level = 1 is enforced to enable gradient computation.""")
1171
+ impl.current_cfg().opt_level = 1
1172
+ key = self.ensure_compiled(*args)
1173
+ kernel_cpp = self.compiled_kernels[key]
1174
+ return self.launch_kernel(kernel_cpp, *args)
1175
+
1176
+
1177
# For a GsTaichi class definition like below:
#
# @ti.data_oriented
# class X:
#   @ti.kernel
#   def foo(self):
#     ...
#
# When ti.kernel runs, the stackframe's |code_context| of Python 3.8(+) is
# different from that of Python 3.7 and below. In 3.8+, it is 'class X:',
# whereas in <=3.7, it is '@ti.data_oriented'. More interestingly, if the class
# inherits, i.e. class X(object):, then in both versions, |code_context| is
# 'class X(object):'...
#
# Patterns matched (via re.match) against the first statement of a candidate
# stack frame by _inside_class() to decide whether a kernel is being defined
# inside a class body.
_KERNEL_CLASS_STACKFRAME_STMT_RES = [
    re.compile(r"@(\w+\.)?data_oriented"),
    re.compile(r"class "),
]
1194
+
1195
+
1196
+ def _inside_class(level_of_class_stackframe: int) -> bool:
1197
+ try:
1198
+ maybe_class_frame = sys._getframe(level_of_class_stackframe)
1199
+ statement_list = inspect.getframeinfo(maybe_class_frame)[3]
1200
+ if statement_list is None:
1201
+ return False
1202
+ first_statment = statement_list[0].strip()
1203
+ for pat in _KERNEL_CLASS_STACKFRAME_STMT_RES:
1204
+ if pat.match(first_statment):
1205
+ return True
1206
+ except:
1207
+ pass
1208
+ return False
1209
+
1210
+
1211
def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False) -> GsTaichiCallable:
    """Create the primal/adjoint Kernel pair for ``_func`` and wrap it.

    Whether ``_func`` is being defined inside a class body is detected by
    stack-frame inspection; the ``+ 1`` accounts for this function's own frame.
    """
    # Can decorators determine if a function is being defined inside a class?
    # https://stackoverflow.com/a/8793684/12003165
    is_classkernel = _inside_class(level_of_class_stackframe + 1)

    if verbose:
        print(f"kernel={_func.__name__} is_classkernel={is_classkernel}")

    primal = Kernel(_func, autodiff_mode=AutodiffMode.NONE, _classkernel=is_classkernel)
    adjoint = Kernel(_func, autodiff_mode=AutodiffMode.REVERSE, _classkernel=is_classkernel)
    # Having |primal| contains |grad| makes the tape work.
    primal.grad = adjoint

    wrapped: GsTaichiCallable
    if not is_classkernel:

        @functools.wraps(_func)
        def _call_primal(*args, **kwargs):
            # Re-raise GsTaichi errors without the (long) internal traceback
            # unless full tracebacks were requested.
            try:
                return primal(*args, **kwargs)
            except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
                if impl.get_runtime().print_full_traceback:
                    raise e
                raise type(e)("\n" + str(e)) from None

        wrapped = GsTaichiCallable(_func, _call_primal)
        wrapped.grad = adjoint
    else:
        # For class kernels, their primal/adjoint callables are constructed
        # when the kernel is accessed via the instance inside
        # _BoundedDifferentiableMethod.
        # This is because we need to bind the kernel or |grad| to the instance
        # owning the kernel, which is not known until the kernel is accessed.
        #
        # See also: _BoundedDifferentiableMethod, data_oriented.
        @functools.wraps(_func)
        def _missing_data_oriented(*args, **kwargs):
            # If we reach here (we should never), it means the class is not decorated
            # with @ti.data_oriented, otherwise getattr would have intercepted the call.
            clsobj = type(args[0])
            assert not hasattr(clsobj, "_data_oriented")
            raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")

        wrapped = GsTaichiCallable(_func, _missing_data_oriented)

    wrapped._is_wrapped_kernel = True
    wrapped._is_classkernel = is_classkernel
    wrapped._primal = primal
    wrapped._adjoint = adjoint
    return wrapped
1260
+
1261
+
1262
def kernel(fn: Callable):
    """Decorator that turns a Python function into a GsTaichi kernel.

    The decorated function is JIT compiled by GsTaichi into native CPU/GPU
    instructions (e.g. a series of CUDA kernels); its top-level ``for`` loops
    are automatically parallelized across a CPU thread pool or the GPU. A
    gradient kernel is generated automatically by the AutoDiff system.

    See also https://docs.taichi-lang.org/docs/syntax#kernel.

    Args:
        fn (Callable): the Python function to be decorated

    Returns:
        Callable: The decorated function

    Example::

        >>> x = ti.field(ti.i32, shape=(4, 8))
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     # Assigns all the elements of `x` in parallel.
        >>>     for i in x:
        >>>         x[i] = i
    """
    # NOTE: the stack-frame level (3) is relative to the user's call site; do
    # not add wrapper frames around this call.
    return _kernel_impl(fn, level_of_class_stackframe=3)
1291
+
1292
+
1293
class _BoundedDifferentiableMethod:
    """Binds a class-level GsTaichi kernel to a specific instance.

    Created by data_oriented's __getattribute__ hook when a class kernel is
    accessed on an instance; forwards calls to the primal kernel (prepending
    the owning instance unless the kernel is a staticmethod) and exposes the
    adjoint via grad().
    """

    def __init__(self, kernel_owner: Any, wrapped_kernel_func: GsTaichiCallable | BoundGsTaichiCallable):
        # The owning class must have been decorated with @ti.data_oriented,
        # which sets the _data_oriented marker (see data_oriented below).
        clsobj = type(kernel_owner)
        if not getattr(clsobj, "_data_oriented", False):
            raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
        self._kernel_owner = kernel_owner
        self._primal = wrapped_kernel_func._primal
        self._adjoint = wrapped_kernel_func._adjoint
        self._is_staticmethod = wrapped_kernel_func._is_staticmethod
        # Filled in by the data_oriented accessor after construction.
        self.__name__: str | None = None

    def __call__(self, *args, **kwargs):
        """Invoke the primal kernel, prepending the owner unless staticmethod.

        GsTaichi errors are re-raised without internal traceback unless
        print_full_traceback is set on the runtime.
        """
        try:
            assert self._primal is not None
            if self._is_staticmethod:
                return self._primal(*args, **kwargs)
            return self._primal(self._kernel_owner, *args, **kwargs)

        except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
            if impl.get_runtime().print_full_traceback:
                raise e
            raise type(e)("\n" + str(e)) from None

    def grad(self, *args, **kwargs) -> Kernel:
        """Invoke the adjoint (reverse-mode) kernel bound to the owner."""
        assert self._adjoint is not None
        return self._adjoint(self._kernel_owner, *args, **kwargs)
1319
+
1320
+
1321
def data_oriented(cls):
    """Marks a class as GsTaichi compatible.

    To allow for modularized code, GsTaichi provides this decorator so that
    GsTaichi kernels can be defined inside a class.

    See also https://docs.taichi-lang.org/docs/odop

    Example::

        >>> @ti.data_oriented
        >>> class TiArray:
        >>>     def __init__(self, n):
        >>>         self.x = ti.field(ti.f32, shape=n)
        >>>
        >>>     @ti.kernel
        >>>     def inc(self):
        >>>         for i in self.x:
        >>>             self.x[i] += 1.0
        >>>
        >>> a = TiArray(32)
        >>> a.inc()

    Args:
        cls (Class): the class to be decorated

    Returns:
        The decorated class.
    """

    # Installed as cls.__getattribute__: intercepts attribute access so that
    # wrapped class kernels get bound to the instance at lookup time.
    def _getattr(self, item):
        # NOTE(review): only cls's own __dict__ is consulted here, so a kernel
        # inherited from a base class yields method=None on this lookup (the
        # property/staticmethod checks below then both evaluate False) —
        # presumably intentional; confirm if inheritance of kernels is expected.
        method = cls.__dict__.get(item, None)
        is_property = method.__class__ == property
        is_staticmethod = method.__class__ == staticmethod
        if is_property:
            # Defer calling the getter until we know whether it wraps a kernel.
            x = method.fget
        else:
            x = super(cls, self).__getattribute__(item)
        if hasattr(x, "_is_wrapped_kernel"):
            if inspect.ismethod(x):
                wrapped = x.__func__
            else:
                wrapped = x
            assert isinstance(wrapped, (BoundGsTaichiCallable, GsTaichiCallable))
            wrapped._is_staticmethod = is_staticmethod
            if wrapped._is_classkernel:
                # Bind the kernel to this instance; see _BoundedDifferentiableMethod.
                ret = _BoundedDifferentiableMethod(self, wrapped)
                ret.__name__ = wrapped.__name__  # type: ignore
                if is_property:
                    return ret()
                return ret
        if is_property:
            # Plain (non-kernel) property: invoke its getter now.
            return x(self)
        return x

    cls.__getattribute__ = _getattr
    # Marker checked by _BoundedDifferentiableMethod and wrapped_classkernel.
    cls._data_oriented = True

    return cls
1380
# Public API of this module. `func`, `pyfunc` and `real_func` are defined
# earlier in the file (outside this chunk).
__all__ = ["data_oriented", "func", "kernel", "pyfunc", "real_func"]