gstaichi 0.1.18.dev1__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  2. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  3. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  4. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  5. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  6. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  7. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  8. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  9. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  10. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  11. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  12. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  13. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  14. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  15. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  16. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  17. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  18. gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  19. gstaichi-0.1.18.dev1.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  20. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
  21. gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
  22. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
  23. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
  24. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  25. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
  26. gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  27. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  28. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-link.lib +0 -0
  29. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  30. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  31. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  32. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  33. gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools.lib +0 -0
  34. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  35. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  36. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  37. gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  38. gstaichi-0.1.18.dev1.data/data/lib/glfw3.lib +0 -0
  39. gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
  40. gstaichi-0.1.18.dev1.dist-info/RECORD +198 -0
  41. gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
  42. gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
  43. gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
  44. gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
  45. taichi/CHANGELOG.md +15 -0
  46. taichi/__init__.py +44 -0
  47. taichi/__main__.py +5 -0
  48. taichi/_funcs.py +706 -0
  49. taichi/_kernels.py +420 -0
  50. taichi/_lib/__init__.py +3 -0
  51. taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
  52. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
  53. taichi/_lib/c_api/include/taichi/taichi.h +29 -0
  54. taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
  55. taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
  56. taichi/_lib/c_api/include/taichi/taichi_cuda.h +36 -0
  57. taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
  58. taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
  59. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
  60. taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
  61. taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
  62. taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
  63. taichi/_lib/c_api/runtime/slim_libdevice.10.bc +0 -0
  64. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
  65. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
  66. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
  67. taichi/_lib/core/__init__.py +0 -0
  68. taichi/_lib/core/py.typed +0 -0
  69. taichi/_lib/core/taichi_python.cp310-win_amd64.pyd +0 -0
  70. taichi/_lib/core/taichi_python.pyi +3077 -0
  71. taichi/_lib/runtime/runtime_cuda.bc +0 -0
  72. taichi/_lib/runtime/runtime_x64.bc +0 -0
  73. taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  74. taichi/_lib/utils.py +249 -0
  75. taichi/_logging.py +131 -0
  76. taichi/_main.py +552 -0
  77. taichi/_snode/__init__.py +5 -0
  78. taichi/_snode/fields_builder.py +189 -0
  79. taichi/_snode/snode_tree.py +34 -0
  80. taichi/_ti_module/__init__.py +3 -0
  81. taichi/_ti_module/cppgen.py +309 -0
  82. taichi/_ti_module/module.py +145 -0
  83. taichi/_version.py +1 -0
  84. taichi/_version_check.py +100 -0
  85. taichi/ad/__init__.py +3 -0
  86. taichi/ad/_ad.py +530 -0
  87. taichi/algorithms/__init__.py +3 -0
  88. taichi/algorithms/_algorithms.py +117 -0
  89. taichi/aot/__init__.py +12 -0
  90. taichi/aot/_export.py +28 -0
  91. taichi/aot/conventions/__init__.py +3 -0
  92. taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
  93. taichi/aot/conventions/gfxruntime140/dr.py +244 -0
  94. taichi/aot/conventions/gfxruntime140/sr.py +613 -0
  95. taichi/aot/module.py +253 -0
  96. taichi/aot/utils.py +151 -0
  97. taichi/assets/.git +1 -0
  98. taichi/assets/Go-Regular.ttf +0 -0
  99. taichi/assets/static/imgs/ti_gallery.png +0 -0
  100. taichi/examples/minimal.py +28 -0
  101. taichi/experimental.py +16 -0
  102. taichi/graph/__init__.py +3 -0
  103. taichi/graph/_graph.py +292 -0
  104. taichi/lang/__init__.py +50 -0
  105. taichi/lang/_ndarray.py +348 -0
  106. taichi/lang/_ndrange.py +152 -0
  107. taichi/lang/_texture.py +172 -0
  108. taichi/lang/_wrap_inspect.py +189 -0
  109. taichi/lang/any_array.py +99 -0
  110. taichi/lang/argpack.py +411 -0
  111. taichi/lang/ast/__init__.py +5 -0
  112. taichi/lang/ast/ast_transformer.py +1806 -0
  113. taichi/lang/ast/ast_transformer_utils.py +328 -0
  114. taichi/lang/ast/checkers.py +106 -0
  115. taichi/lang/ast/symbol_resolver.py +57 -0
  116. taichi/lang/ast/transform.py +9 -0
  117. taichi/lang/common_ops.py +310 -0
  118. taichi/lang/exception.py +80 -0
  119. taichi/lang/expr.py +180 -0
  120. taichi/lang/field.py +464 -0
  121. taichi/lang/impl.py +1246 -0
  122. taichi/lang/kernel_arguments.py +157 -0
  123. taichi/lang/kernel_impl.py +1415 -0
  124. taichi/lang/matrix.py +1877 -0
  125. taichi/lang/matrix_ops.py +341 -0
  126. taichi/lang/matrix_ops_utils.py +190 -0
  127. taichi/lang/mesh.py +687 -0
  128. taichi/lang/misc.py +807 -0
  129. taichi/lang/ops.py +1489 -0
  130. taichi/lang/runtime_ops.py +13 -0
  131. taichi/lang/shell.py +35 -0
  132. taichi/lang/simt/__init__.py +5 -0
  133. taichi/lang/simt/block.py +94 -0
  134. taichi/lang/simt/grid.py +7 -0
  135. taichi/lang/simt/subgroup.py +191 -0
  136. taichi/lang/simt/warp.py +96 -0
  137. taichi/lang/snode.py +487 -0
  138. taichi/lang/source_builder.py +150 -0
  139. taichi/lang/struct.py +855 -0
  140. taichi/lang/util.py +381 -0
  141. taichi/linalg/__init__.py +8 -0
  142. taichi/linalg/matrixfree_cg.py +310 -0
  143. taichi/linalg/sparse_cg.py +59 -0
  144. taichi/linalg/sparse_matrix.py +303 -0
  145. taichi/linalg/sparse_solver.py +123 -0
  146. taichi/math/__init__.py +11 -0
  147. taichi/math/_complex.py +204 -0
  148. taichi/math/mathimpl.py +886 -0
  149. taichi/profiler/__init__.py +6 -0
  150. taichi/profiler/kernel_metrics.py +260 -0
  151. taichi/profiler/kernel_profiler.py +592 -0
  152. taichi/profiler/memory_profiler.py +15 -0
  153. taichi/profiler/scoped_profiler.py +36 -0
  154. taichi/shaders/Circles_vk.frag +29 -0
  155. taichi/shaders/Circles_vk.vert +45 -0
  156. taichi/shaders/Circles_vk_frag.spv +0 -0
  157. taichi/shaders/Circles_vk_vert.spv +0 -0
  158. taichi/shaders/Lines_vk.frag +9 -0
  159. taichi/shaders/Lines_vk.vert +11 -0
  160. taichi/shaders/Lines_vk_frag.spv +0 -0
  161. taichi/shaders/Lines_vk_vert.spv +0 -0
  162. taichi/shaders/Mesh_vk.frag +71 -0
  163. taichi/shaders/Mesh_vk.vert +68 -0
  164. taichi/shaders/Mesh_vk_frag.spv +0 -0
  165. taichi/shaders/Mesh_vk_vert.spv +0 -0
  166. taichi/shaders/Particles_vk.frag +95 -0
  167. taichi/shaders/Particles_vk.vert +73 -0
  168. taichi/shaders/Particles_vk_frag.spv +0 -0
  169. taichi/shaders/Particles_vk_vert.spv +0 -0
  170. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  171. taichi/shaders/SceneLines_vk.frag +9 -0
  172. taichi/shaders/SceneLines_vk.vert +12 -0
  173. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  174. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  175. taichi/shaders/SetImage_vk.frag +21 -0
  176. taichi/shaders/SetImage_vk.vert +15 -0
  177. taichi/shaders/SetImage_vk_frag.spv +0 -0
  178. taichi/shaders/SetImage_vk_vert.spv +0 -0
  179. taichi/shaders/Triangles_vk.frag +16 -0
  180. taichi/shaders/Triangles_vk.vert +29 -0
  181. taichi/shaders/Triangles_vk_frag.spv +0 -0
  182. taichi/shaders/Triangles_vk_vert.spv +0 -0
  183. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  184. taichi/sparse/__init__.py +3 -0
  185. taichi/sparse/_sparse_grid.py +77 -0
  186. taichi/tools/__init__.py +12 -0
  187. taichi/tools/diagnose.py +124 -0
  188. taichi/tools/np2ply.py +364 -0
  189. taichi/tools/vtk.py +38 -0
  190. taichi/types/__init__.py +19 -0
  191. taichi/types/annotations.py +47 -0
  192. taichi/types/compound_types.py +90 -0
  193. taichi/types/enums.py +49 -0
  194. taichi/types/ndarray_type.py +147 -0
  195. taichi/types/primitive_types.py +203 -0
  196. taichi/types/quant.py +88 -0
  197. taichi/types/texture_type.py +85 -0
  198. taichi/types/utils.py +13 -0
taichi/lang/kernel_impl.py (new file)
@@ -0,0 +1,1415 @@
# type: ignore

import ast
import dataclasses
import functools
import inspect
import json
import operator
import os
import pathlib
import re
import sys
import textwrap
import time
import types
import typing
import warnings
import weakref
from typing import Any, Callable, Type, Union

import numpy as np

import taichi.lang
import taichi.lang._ndarray
import taichi.lang._texture
import taichi.lang.expr
import taichi.lang.snode
import taichi.types.annotations
from taichi import _logging
from taichi._lib import core as _ti_core
from taichi._lib.core.taichi_python import ASTBuilder
from taichi.lang import impl, ops, runtime_ops
from taichi.lang._wrap_inspect import getsourcefile, getsourcelines
from taichi.lang.any_array import AnyArray
from taichi.lang.argpack import ArgPack, ArgPackType
from taichi.lang.ast import (
    ASTTransformerContext,
    KernelSimplicityASTChecker,
    transform_tree,
)
from taichi.lang.ast.ast_transformer_utils import ReturnStatus
from taichi.lang.exception import (
    TaichiCompilationError,
    TaichiRuntimeError,
    TaichiRuntimeTypeError,
    TaichiSyntaxError,
    TaichiTypeError,
    handle_exception_from_cpp,
)
from taichi.lang.expr import Expr
from taichi.lang.kernel_arguments import KernelArgument
from taichi.lang.matrix import MatrixType
from taichi.lang.shell import _shell_pop_print
from taichi.lang.struct import StructType
from taichi.lang.util import cook_dtype, has_paddle, has_pytorch, to_taichi_type
from taichi.types import (
    ndarray_type,
    primitive_types,
    sparse_matrix_builder,
    template,
    texture_type,
)
from taichi.types.compound_types import CompoundType
from taichi.types.enums import AutodiffMode, Layout
from taichi.types.utils import is_signed


def func(fn: Callable, is_real_function: bool = False):
    """Marks a function as callable in Taichi-scope.

    This decorator transforms a Python function into a Taichi one. Taichi
    will JIT compile it into native instructions.

    Args:
        fn (Callable): The Python function to be decorated
        is_real_function (bool): Whether the function is a real function

    Returns:
        Callable: The decorated function

    Example::

        >>> @ti.func
        >>> def foo(x):
        >>>     return x + 2
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     print(foo(40))  # 42
    """
    is_classfunc = _inside_class(level_of_class_stackframe=3 + is_real_function)

    fun = Func(fn, _classfunc=is_classfunc, is_real_function=is_real_function)

    @functools.wraps(fn)
    def decorated(*args, **kwargs):
        return fun.__call__(*args, **kwargs)

    decorated._is_taichi_function = True
    decorated._is_real_function = is_real_function
    decorated.func = fun
    return decorated


def real_func(fn: Callable):
    return func(fn, is_real_function=True)


def pyfunc(fn: Callable):
    """Marks a function as callable in both Taichi and Python scopes.

    When called inside the Taichi scope, Taichi will JIT compile it into
    native instructions. Otherwise it will be invoked directly as a
    Python function.

    See also :func:`~taichi.lang.kernel_impl.func`.

    Args:
        fn (Callable): The Python function to be decorated

    Returns:
        Callable: The decorated function
    """
    is_classfunc = _inside_class(level_of_class_stackframe=3)
    fun = Func(fn, _classfunc=is_classfunc, _pyfunc=True)

    @functools.wraps(fn)
    def decorated(*args, **kwargs):
        return fun.__call__(*args, **kwargs)

    decorated._is_taichi_function = True
    decorated._is_real_function = False
    decorated.func = fun
    return decorated

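# A minimal sketch of the dual-scope behaviour documented above (illustrative
# only; assumes `import taichi as ti` and a prior `ti.init()` call):
#
#     >>> @ti.pyfunc
#     >>> def clamp01(x):
#     >>>     return min(max(x, 0.0), 1.0)
#     >>>
#     >>> clamp01(1.5)  # Python-scope: runs as plain Python -> 1.0
#     >>>
#     >>> @ti.kernel
#     >>> def run() -> ti.f32:
#     >>>     return clamp01(1.5)  # Taichi-scope: JIT compiled -> 1.0
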
def _get_tree_and_ctx(
    self: "Func | Kernel",
    excluded_parameters=(),
    is_kernel: bool = True,
    arg_features=None,
    args=None,
    ast_builder: ASTBuilder | None = None,
    is_real_function: bool = False,
):
    file = getsourcefile(self.func)
    src, start_lineno = getsourcelines(self.func)
    src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    func_body = tree.body[0]
    func_body.decorator_list = []

    global_vars = _get_global_vars(self.func)

    if is_kernel or is_real_function:
        # inject template parameters into globals
        for i in self.template_slot_locations:
            template_var_name = self.arguments[i].name
            global_vars[template_var_name] = args[i]
        parameters = inspect.signature(self.func).parameters
        for arg_i, (param_name, param) in enumerate(parameters.items()):
            if dataclasses.is_dataclass(param.annotation):
                for member_field in dataclasses.fields(param.annotation):
                    child_value = getattr(args[arg_i], member_field.name)
                    flat_name = f"__ti_{param_name}_{member_field.name}"
                    global_vars[flat_name] = child_value

    return tree, ASTTransformerContext(
        excluded_parameters=excluded_parameters,
        is_kernel=is_kernel,
        func=self,
        arg_features=arg_features,
        global_vars=global_vars,
        argument_data=args,
        src=src,
        start_lineno=start_lineno,
        file=file,
        ast_builder=ast_builder,
        is_real_function=is_real_function,
    )


def expand_func_arguments(arguments: list[KernelArgument]) -> list[KernelArgument]:
    new_arguments = []
    for argument in arguments:
        if dataclasses.is_dataclass(argument.annotation):
            for field in dataclasses.fields(argument.annotation):
                new_argument = KernelArgument(
                    _annotation=field.type,
                    _name=f"__ti_{argument.name}_{field.name}",
                )
                new_arguments.append(new_argument)
        else:
            new_arguments.append(argument)
    return new_arguments

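# A sketch of the dataclass expansion performed by `expand_func_arguments`
# above (illustrative; `MyStruct` and its fields are hypothetical):
#
#     >>> @dataclasses.dataclass
#     >>> class MyStruct:
#     >>>     a: ti.types.ndarray()
#     >>>     b: ti.types.ndarray()
#
# A parameter `s: MyStruct` is replaced by two flattened arguments named
# `__ti_s_a` and `__ti_s_b`, following the `__ti_{param}_{field}` scheme used
# throughout this module.
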
def _process_args(self: "Func | Kernel", is_func: bool, args, kwargs):
    if is_func:
        self.arguments = expand_func_arguments(self.arguments)
    fused_args = [argument.default for argument in self.arguments]
    len_args = len(args)

    if len_args > len(fused_args):
        arg_str = ", ".join([str(arg) for arg in args])
        expected_str = ", ".join([f"{arg.name} : {arg.annotation}" for arg in self.arguments])
        msg = f"Too many arguments. Expected ({expected_str}), got ({arg_str})."
        raise TaichiSyntaxError(msg)

    for i, arg in enumerate(args):
        fused_args[i] = arg

    for key, value in kwargs.items():
        found = False
        for i, arg in enumerate(self.arguments):
            if key == arg.name:
                if i < len_args:
                    raise TaichiSyntaxError(f"Multiple values for argument '{key}'.")
                fused_args[i] = value
                found = True
                break
        if not found:
            raise TaichiSyntaxError(f"Unexpected argument '{key}'.")

    for i, arg in enumerate(fused_args):
        if arg is inspect.Parameter.empty:
            if self.arguments[i].annotation is inspect._empty:
                raise TaichiSyntaxError(f"Parameter `{self.arguments[i].name}` missing.")
            else:
                raise TaichiSyntaxError(
                    f"Parameter `{self.arguments[i].name} : {self.arguments[i].annotation}` missing."
                )

    return tuple(fused_args)

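# A sketch of the positional/keyword fusion in `_process_args` above
# (illustrative; assumes a kernel defined as `def k(a: ti.i32, b: ti.i32)`):
#
#     >>> k(1, b=2)  # fused_args becomes (1, 2)
#     >>> k(1, a=2)  # TaichiSyntaxError: Multiple values for argument 'a'.
#     >>> k(1, c=2)  # TaichiSyntaxError: Unexpected argument 'c'.
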
def unpack_ndarray_struct(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
    class AttributeToNameTransformer(ast.NodeTransformer):
        def visit_Attribute(self, node: ast.AST):
            if isinstance(node.value, ast.Attribute):
                return node
            if not isinstance(node.value, ast.Name):
                return node
            base_id = node.value.id
            attr_name = node.attr
            new_id = f"__ti_{base_id}_{attr_name}"
            if new_id not in struct_locals:
                return node
            return ast.copy_location(ast.Name(id=new_id, ctx=node.ctx), node)

    transformer = AttributeToNameTransformer()
    new_tree = transformer.visit(tree)
    ast.fix_missing_locations(new_tree)
    return new_tree


def extract_struct_locals_from_context(ctx: ASTTransformerContext):
    """
    - Uses ctx.func.func to get the function signature.
    - Searches this for any dataclasses:
        - If it finds any dataclasses, it converts them into expanded names.
        - E.g. my_struct: MyStruct, where MyStruct contains a, b, c, would become:
          {"__ti_my_struct_a", "__ti_my_struct_b", "__ti_my_struct_c"}
    """
    assert ctx.func is not None
    sig = inspect.signature(ctx.func.func)
    parameters = sig.parameters
    struct_locals = set()
    for param_name, parameter in parameters.items():
        if dataclasses.is_dataclass(parameter.annotation):
            for field in dataclasses.fields(parameter.annotation):
                child_name = f"__ti_{param_name}_{field.name}"
                struct_locals.add(child_name)
    return struct_locals

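# Taken together, the two helpers above flatten dataclass parameters: the
# expanded names are computed from the signature, and attribute accesses are
# then rewritten into plain names. A sketch (hypothetical struct `s` with
# fields `a` and `b`):
#
#     >>> struct_locals = {"__ti_s_a", "__ti_s_b"}
#     >>> tree = ast.parse("x = s.a + s.b")
#     >>> tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
#     >>> ast.unparse(tree)
#     'x = __ti_s_a + __ti_s_b'
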
class Func:
    function_counter = 0

    def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False):
        self.func = _func
        self.func_id = Func.function_counter
        Func.function_counter += 1
        self.compiled = {}
        self.classfunc = _classfunc
        self.pyfunc = _pyfunc
        self.is_real_function = is_real_function
        self.arguments: list[KernelArgument] = []
        self.orig_arguments: list[KernelArgument] = []
        self.return_type: tuple[Type, ...] | None = None
        self.extract_arguments()
        self.template_slot_locations: list[int] = []
        for i, arg in enumerate(self.arguments):
            if arg.annotation == template or isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = TaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
        self.taichi_functions = {}  # The |Function| class in C++
        self.has_print = False

    def __call__(self, *args, **kwargs):
        args = _process_args(self, is_func=True, args=args, kwargs=kwargs)

        if not impl.inside_kernel():
            if not self.pyfunc:
                raise TaichiSyntaxError("Taichi functions cannot be called from Python-scope.")
            return self.func(*args)

        current_kernel = impl.get_runtime().current_kernel
        if self.is_real_function:
            if current_kernel.autodiff_mode != AutodiffMode.NONE:
                raise TaichiSyntaxError("Real functions are unsupported in gradient kernels.")
            instance_id, arg_features = self.mapper.lookup(args)
            key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
            if key.instance_id not in self.compiled:
                self.do_compile(key=key, args=args, arg_features=arg_features)
            return self.func_call_rvalue(key=key, args=args)
        tree, ctx = _get_tree_and_ctx(
            self,
            is_kernel=False,
            args=args,
            ast_builder=current_kernel.ast_builder(),
            is_real_function=self.is_real_function,
        )

        struct_locals = extract_struct_locals_from_context(ctx)
        tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
        ret = transform_tree(tree, ctx)
        if not self.is_real_function:
            if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
                raise TaichiSyntaxError("Function has a return type but does not have a return statement")
        return ret

    def func_call_rvalue(self, key, args):
        # Skip the template args, e.g., |self|
        assert self.is_real_function
        non_template_args = []
        dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        for i, kernel_arg in enumerate(self.arguments):
            anno = kernel_arg.annotation
            if not isinstance(anno, template):
                if id(anno) in primitive_types.type_ids:
                    non_template_args.append(ops.cast(args[i], anno))
                elif isinstance(anno, primitive_types.RefType):
                    non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
                elif isinstance(anno, ndarray_type.NdarrayType):
                    if not isinstance(args[i], AnyArray):
                        raise TaichiTypeError(
                            f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
                        )
                    non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
                else:
                    non_template_args.append(args[i])
        non_template_args = impl.make_expr_group(non_template_args)
        compiling_callable = impl.get_runtime().compiling_callable
        assert compiling_callable is not None
        func_call = compiling_callable.ast_builder().insert_func_call(
            self.taichi_functions[key.instance_id], non_template_args, dbg_info
        )
        if self.return_type is None:
            return None
        func_call = Expr(func_call)
        ret = []

        for i, return_type in enumerate(self.return_type):
            if id(return_type) in primitive_types.type_ids:
                ret.append(
                    Expr(
                        _ti_core.make_get_element_expr(
                            func_call.ptr, (i,), _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
                        )
                    )
                )
            elif isinstance(return_type, (StructType, MatrixType)):
                ret.append(return_type.from_taichi_object(func_call, (i,)))
            else:
                raise TaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
        if len(ret) == 1:
            return ret[0]
        return tuple(ret)

    def do_compile(self, key, args, arg_features):
        tree, ctx = _get_tree_and_ctx(
            self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
        )
        fn = impl.get_runtime().prog.create_function(key)

        def func_body():
            old_callable = impl.get_runtime().compiling_callable
            impl.get_runtime().compiling_callable = fn
            ctx.ast_builder = fn.ast_builder()
            transform_tree(tree, ctx)
            impl.get_runtime().compiling_callable = old_callable

        self.taichi_functions[key.instance_id] = fn
        self.compiled[key.instance_id] = func_body
        self.taichi_functions[key.instance_id].set_function_body(func_body)

    def extract_arguments(self) -> None:
        sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect.Signature.empty, None):
            self.return_type = sig.return_annotation
            if (
                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))
                and self.return_type.__origin__ is tuple
            ):
                self.return_type = self.return_type.__args__
            if not isinstance(self.return_type, (list, tuple)):
                self.return_type = (self.return_type,)
            for i, return_type in enumerate(self.return_type):
                if return_type is Ellipsis:
                    raise TaichiSyntaxError("Ellipsis is not supported in return type annotations")
        params = sig.parameters
        arg_names = params.keys()
        for i, arg_name in enumerate(arg_names):
            param = params[arg_name]
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise TaichiSyntaxError("Taichi functions do not support variable keyword parameters (i.e., **kwargs)")
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise TaichiSyntaxError("Taichi functions do not support variable positional parameters (i.e., *args)")
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                raise TaichiSyntaxError("Taichi functions do not support keyword parameters")
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                raise TaichiSyntaxError('Taichi functions only support "positional or keyword" parameters')
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                if i == 0 and self.classfunc:
                    annotation = template()
                # TODO: pyfunc also needs a type-annotation check when real function is enabled,
                # but that has to happen at runtime when we know which scope it's called from.
                elif not self.pyfunc and self.is_real_function:
                    raise TaichiSyntaxError(
                        f"Taichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
                    )
            else:
                if isinstance(annotation, ndarray_type.NdarrayType):
                    pass
                elif isinstance(annotation, MatrixType):
                    pass
                elif isinstance(annotation, StructType):
                    pass
                elif id(annotation) in primitive_types.type_ids:
                    pass
                elif type(annotation) == taichi.types.annotations.Template:
                    pass
                elif isinstance(annotation, template) or annotation == taichi.types.annotations.Template:
                    pass
                elif isinstance(annotation, primitive_types.RefType):
                    pass
                elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
                    pass
                else:
                    raise TaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi function: {annotation}")
            self.arguments.append(KernelArgument(annotation, param.name, param.default))
            self.orig_arguments.append(KernelArgument(annotation, param.name, param.default))


AnnotationType = Union[
    template,
    ArgPackType,
    "texture_type.TextureType",
    "texture_type.RWTextureType",
    ndarray_type.NdarrayType,
    sparse_matrix_builder,
    Any,
]


class TaichiCallableTemplateMapper:
    """
    This should probably be renamed to something like FeatureMapper or
    FeatureExtractor, since:
    - it's not specific to templates
    - it extracts what are later called 'features'; for ndarray, for example, this includes:
        - element type
        - number of dimensions
        - needs grad (or not)
    - these are returned as a heterogeneous tuple, whose contents depend on the type
    """

    def __init__(self, arguments: list[KernelArgument], template_slot_locations: list[int]) -> None:
        self.arguments = arguments
        self.num_args = len(arguments)
        self.template_slot_locations = template_slot_locations
        self.mapping = {}

    @staticmethod
    def extract_arg(arg, annotation: AnnotationType, arg_name: str):
        if annotation == template or isinstance(annotation, template):
            if isinstance(arg, taichi.lang.snode.SNode):
                return arg.ptr
            if isinstance(arg, taichi.lang.expr.Expr):
                return arg.ptr.get_underlying_ptr_address()
            if isinstance(arg, _ti_core.Expr):
                return arg.get_underlying_ptr_address()
            if isinstance(arg, tuple):
                return tuple(TaichiCallableTemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
            if isinstance(arg, taichi.lang._ndarray.Ndarray):
                raise TaichiRuntimeTypeError(
                    "Ndarray shouldn't be passed in via `ti.template()`; please annotate your kernel using "
                    "`ti.types.ndarray(...)` instead"
                )

            if isinstance(arg, (list, tuple, dict, set)) or hasattr(arg, "_data_oriented"):
                # [Composite arguments] Return a weak reference to the object.
                # The Taichi kernel will cache the extracted arguments, so we can't simply return
                # the original argument; a weak reference to the original value is returned instead
                # to avoid memory leaks.

                # TODO(zhanlue): replace "tuple(args)" with a hash of the argument values.
                # This can resolve the following issues:
                # 1. An invalid weak-ref leaves a dead (dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
                # 2. Different argument instances with the same type and value get templatized into separate kernels.
                return weakref.ref(arg)

            # [Primitive arguments] Return the value
            return arg
        if isinstance(annotation, ArgPackType):
            if not isinstance(arg, ArgPack):
                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be an argument pack, got {type(arg)}")
            return tuple(
                TaichiCallableTemplateMapper.extract_arg(arg[name], dtype, arg_name)
                for index, (name, dtype) in enumerate(annotation.members.items())
            )
        if dataclasses.is_dataclass(annotation):
            _res_l = []
            for field in dataclasses.fields(annotation):
                field_value = getattr(arg, field.name)
                # Build each flattened name from the original arg_name, one per field.
                field_arg_name = f"__ti_{arg_name}_{field.name}"
                field_extracted = TaichiCallableTemplateMapper.extract_arg(field_value, field.type, field_arg_name)
                _res_l.append(field_extracted)
            return tuple(_res_l)
        if isinstance(annotation, texture_type.TextureType):
            if not isinstance(arg, taichi.lang._texture.Texture):
                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
            if arg.num_dims != annotation.num_dimensions:
                raise TaichiRuntimeTypeError(
                    f"TextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
                )
            return (arg.num_dims,)
        if isinstance(annotation, texture_type.RWTextureType):
            if not isinstance(arg, taichi.lang._texture.Texture):
                raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
            if arg.num_dims != annotation.num_dimensions:
                raise TaichiRuntimeTypeError(
                    f"RWTextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
                )
            if arg.fmt != annotation.fmt:
                raise TaichiRuntimeTypeError(
                    f"RWTextureType format mismatch for argument {arg_name}: expected {annotation.fmt}, got {arg.fmt}"
                )
            # (penguinliong) '0' is the assumed LOD level. We currently don't
            # support mip-mapping.
            return arg.num_dims, arg.fmt, 0
        if isinstance(annotation, ndarray_type.NdarrayType):
            if isinstance(arg, taichi.lang._ndarray.Ndarray):
                annotation.check_matched(arg.get_type(), arg_name)
                needs_grad = (arg.grad is not None) if annotation.needs_grad is None else annotation.needs_grad
                assert arg.shape is not None
                return arg.element_type, len(arg.shape), needs_grad, annotation.boundary
            if isinstance(arg, AnyArray):
                ty = arg.get_type()
                annotation.check_matched(arg.get_type(), arg_name)
                return ty.element_type, len(arg.shape), ty.needs_grad, annotation.boundary
            # external arrays
            shape = getattr(arg, "shape", None)
            if shape is None:
                raise TaichiRuntimeTypeError(f"Invalid type for argument {arg_name}, got {arg}")
            shape = tuple(shape)
            element_shape: tuple[int, ...] = ()
            dtype = to_taichi_type(arg.dtype)
            if isinstance(annotation.dtype, MatrixType):
                if annotation.ndim is not None:
                    if len(shape) != annotation.dtype.ndim + annotation.ndim:
                        raise ValueError(
                            f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim} element_dim={annotation.dtype.ndim}, "
                            f"but an array with {len(shape)} dimensions was provided"
                        )
                else:
                    if len(shape) < annotation.dtype.ndim:
                        raise ValueError(
                            f"Invalid value for argument {arg_name} - required element_dim={annotation.dtype.ndim}, "
                            f"but an array with {len(shape)} dimensions was provided"
                        )
                element_shape = shape[-annotation.dtype.ndim :]
                anno_element_shape = annotation.dtype.get_shape()
                if None not in anno_element_shape and element_shape != anno_element_shape:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required element_shape={anno_element_shape}, "
                        f"but an array with element shape {element_shape} was provided"
                    )
            elif annotation.dtype is not None:
                # User specified scalar dtype
                if annotation.dtype != dtype:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required array has dtype={annotation.dtype.to_string()}, "
                        f"but an array with dtype={dtype.to_string()} was provided"
                    )

            if annotation.ndim is not None and len(shape) != annotation.ndim:
                raise ValueError(
                    f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim}, "
                    f"but an array with {len(shape)} dimensions was provided"
                )
            needs_grad = (
                getattr(arg, "requires_grad", False) if annotation.needs_grad is None else annotation.needs_grad
            )
            element_type = (
                _ti_core.get_type_factory_instance().get_tensor_type(element_shape, dtype)
                if len(element_shape) != 0
                else arg.dtype
            )
            return element_type, len(shape) - len(element_shape), needs_grad, annotation.boundary
        if isinstance(annotation, sparse_matrix_builder):
            return arg.dtype
        # Use '#' as a placeholder because other kinds of arguments are not involved in template instantiation
        return "#"

    def extract(self, args):
        extracted = []
        for arg, kernel_arg in zip(args, self.arguments):
            extracted.append(self.extract_arg(arg, kernel_arg.annotation, kernel_arg.name))
        return tuple(extracted)

    def lookup(self, args):
        if len(args) != self.num_args:
            raise TypeError(f"{self.num_args} argument(s) needed but {len(args)} provided.")

        key = self.extract(args)
        if key not in self.mapping:
            count = len(self.mapping)
            self.mapping[key] = count
        return self.mapping[key], key

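# A sketch of how `TaichiCallableTemplateMapper.lookup` assigns instance ids
# (illustrative; `mapper`, `arr_a` and `arr_b` are hypothetical, with both
# ndarrays sharing the same dtype and ndim):
#
#     >>> inst_a, feats_a = mapper.lookup((arr_a,))
#     >>> inst_b, feats_b = mapper.lookup((arr_b,))
#     >>> inst_a == inst_b  # identical features -> same cached instance id
#     True
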
def _get_global_vars(_func):
    # Discussions: https://github.com/taichi-dev/taichi/issues/282
    global_vars = _func.__globals__.copy()

    freevar_names = _func.__code__.co_freevars
    closure = _func.__closure__
    if closure:
        freevar_values = list(map(lambda x: x.cell_contents, closure))
        for name, value in zip(freevar_names, freevar_values):
            global_vars[name] = value

    return global_vars

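# The closure handling above relies on plain CPython semantics; a quick
# sketch:
#
#     >>> def outer():
#     >>>     y = 3
#     >>>     def inner(x):
#     >>>         return x + y
#     >>>     return inner
#     >>>
#     >>> f = outer()
#     >>> f.__code__.co_freevars           # ('y',)
#     >>> f.__closure__[0].cell_contents   # 3
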
class Kernel:
    counter = 0

    def __init__(self, _func: Callable, autodiff_mode, _classkernel=False):
        self.func = _func
        self.kernel_counter = Kernel.counter
        Kernel.counter += 1
        assert autodiff_mode in (
            AutodiffMode.NONE,
            AutodiffMode.VALIDATION,
            AutodiffMode.FORWARD,
            AutodiffMode.REVERSE,
        )
        self.autodiff_mode = autodiff_mode
        self.grad: Kernel | None = None
        self.arguments: list[KernelArgument] = []
        self.return_type = None
        self.classkernel = _classkernel
        self.extract_arguments()
        self.template_slot_locations = []
        for i, arg in enumerate(self.arguments):
            if arg.annotation == template or isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = TaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
        impl.get_runtime().kernels.append(self)
        self.reset()
        self.kernel_cpp = None
        self.compiled_kernels = {}
        self.has_print = False

    def ast_builder(self) -> ASTBuilder:
        assert self.kernel_cpp is not None
        return self.kernel_cpp.ast_builder()

    def reset(self):
        self.runtime = impl.get_runtime()
        self.compiled_kernels = {}

    def extract_arguments(self):
        sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect._empty, None):
            self.return_type = sig.return_annotation
            if (
                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))
                and self.return_type.__origin__ is tuple
            ):
                self.return_type = self.return_type.__args__
            if not isinstance(self.return_type, (list, tuple)):
                self.return_type = (self.return_type,)
            for return_type in self.return_type:
                if return_type is Ellipsis:
                    raise TaichiSyntaxError("Ellipsis is not supported in return type annotations")
        params = sig.parameters
        arg_names = params.keys()
        for i, arg_name in enumerate(arg_names):
            param = params[arg_name]
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise TaichiSyntaxError("Taichi kernels do not support variable keyword parameters (i.e., **kwargs)")
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise TaichiSyntaxError("Taichi kernels do not support variable positional parameters (i.e., *args)")
            if param.default is not inspect.Parameter.empty:
                raise TaichiSyntaxError("Taichi kernels do not support default values for arguments")
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                raise TaichiSyntaxError("Taichi kernels do not support keyword parameters")
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                raise TaichiSyntaxError('Taichi kernels only support "positional or keyword" parameters')
            annotation = param.annotation
            if param.annotation is inspect.Parameter.empty:
                if i == 0 and self.classkernel:  # The |self| parameter
                    annotation = template()
                else:
                    raise TaichiSyntaxError("Taichi kernel parameters must be type annotated")
            else:
                if isinstance(
                    annotation,
                    (
                        template,
                        ndarray_type.NdarrayType,
                        texture_type.TextureType,
                        texture_type.RWTextureType,
                    ),
                ):
                    pass
                elif id(annotation) in primitive_types.type_ids:
                    pass
                elif isinstance(annotation, sparse_matrix_builder):
                    pass
                elif isinstance(annotation, MatrixType):
                    pass
                elif isinstance(annotation, StructType):
                    pass
                elif isinstance(annotation, ArgPackType):
                    pass
                elif annotation == template:
                    pass
                elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
                    pass
                else:
                    raise TaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
            self.arguments.append(KernelArgument(annotation, param.name, param.default))

    def materialize(self, key, args: list[Any], arg_features):
        if key is None:
            key = (self.func, 0, self.autodiff_mode)
        self.runtime.materialize()

        if key in self.compiled_kernels:
            return

        kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}"
        _logging.trace(f"Compiling kernel {kernel_name} in {self.autodiff_mode}...")

        tree, ctx = _get_tree_and_ctx(
            self,
            args=args,
            excluded_parameters=self.template_slot_locations,
            arg_features=arg_features,
        )

        if self.autodiff_mode != AutodiffMode.NONE:
            KernelSimplicityASTChecker(self.func).visit(tree)

        # Do not change the name of 'taichi_ast_generator':
        # the warning system needs this identifier to remove unnecessary messages.
        def taichi_ast_generator(kernel_cxx: Kernel):  # not sure if this type is correct, seems doubtful
            nonlocal tree
            if self.runtime.inside_kernel:
                raise TaichiSyntaxError(
                    "Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
                    "Please check if you have direct/indirect invocation of kernels within kernels. "
                    "Note that some methods provided by the Taichi standard library may invoke kernels, "
                    "so please move their invocations to Python-scope."
                )
            self.kernel_cpp = kernel_cxx
            self.runtime.inside_kernel = True
            self.runtime._current_kernel = self
            assert self.runtime.compiling_callable is None
            self.runtime.compiling_callable = kernel_cxx
            try:
                ctx.ast_builder = kernel_cxx.ast_builder()

                def ast_to_dict(node):
                    if isinstance(node, ast.AST):
                        fields = {k: ast_to_dict(v) for k, v in ast.iter_fields(node)}
                        return {
                            "type": node.__class__.__name__,
                            "fields": fields,
                            "lineno": getattr(node, "lineno", None),
                            "col_offset": getattr(node, "col_offset", None),
                        }
                    if isinstance(node, list):
                        return [ast_to_dict(x) for x in node]
                    return node  # Basic types (str, int, None, etc.)

                if os.environ.get("TI_DUMP_AST", "") == "1":
                    target_dir = pathlib.Path("/tmp/ast")
                    target_dir.mkdir(parents=True, exist_ok=True)

                    start = time.time()
                    ast_str = ast.dump(tree, indent=2)
                    output_file = target_dir / f"{kernel_name}_ast.txt"
                    output_file.write_text(ast_str)
                    elapsed_txt = time.time() - start

                    start = time.time()
                    json_str = json.dumps(ast_to_dict(tree), indent=2)
                    output_file = target_dir / f"{kernel_name}_ast.json"
                    output_file.write_text(json_str)
                    elapsed_json = time.time() - start

                    output_file = target_dir / f"{kernel_name}_gen_time.json"
                    output_file.write_text(
                        json.dumps({"elapsed_txt": elapsed_txt, "elapsed_json": elapsed_json}, indent=2)
                    )
                struct_locals = extract_struct_locals_from_context(ctx)
                tree = unpack_ndarray_struct(tree, struct_locals=struct_locals)
                transform_tree(tree, ctx)
                if not ctx.is_real_function:
                    if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
                        raise TaichiSyntaxError("Kernel has a return type but does not have a return statement")
            finally:
                self.runtime.inside_kernel = False
                self.runtime._current_kernel = None
                self.runtime.compiling_callable = None

        taichi_kernel = impl.get_runtime().prog.create_kernel(taichi_ast_generator, kernel_name, self.autodiff_mode)
        assert key not in self.compiled_kernels
        self.compiled_kernels[key] = taichi_kernel

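    # The TI_DUMP_AST hook in `materialize` above can be exercised like this
    # (an illustrative sketch; `my_kernel` is a hypothetical kernel, and the
    # dump directory /tmp/ast is hard-coded above):
    #
    #     >>> os.environ["TI_DUMP_AST"] = "1"  # set before the first call compiles the kernel
    #     >>> my_kernel()
    #     >>> # /tmp/ast/ now contains <kernel_name>_ast.txt, _ast.json and _gen_time.json
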
    def launch_kernel(self, t_kernel, *args):
        assert len(args) == len(self.arguments), f"{len(self.arguments)} arguments needed but {len(args)} provided"

        tmps = []
        callbacks = []

        actual_argument_slot = 0
        launch_ctx = t_kernel.make_launch_context()
        max_arg_num = 64
        exceed_max_arg_num = False

        def set_arg_ndarray(indices, v):
            v_primal = v.arr
            v_grad = v.grad.arr if v.grad else None
            if v_grad is None:
                launch_ctx.set_arg_ndarray(indices, v_primal)
            else:
                launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad)

        def set_arg_texture(indices, v):
            launch_ctx.set_arg_texture(indices, v.tex)

        def set_arg_rw_texture(indices, v):
            launch_ctx.set_arg_rw_texture(indices, v.tex)

        def set_arg_ext_array(indices, v, needed):
            # Element shapes are already specialized in Taichi codegen, so the
            # shape information for the element dims is no longer needed.
            # Therefore we strip the element shapes from the shape vector,
            # so that it only holds "real" array shapes.
            is_soa = needed.layout == Layout.SOA
            array_shape = v.shape
            if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
                warnings.warn("Ndarray index might be out of int32 boundary but int64 indexing is not supported yet.")
            if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
                element_dim = 0
            else:
                element_dim = needed.dtype.ndim
                array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
            if isinstance(v, np.ndarray):
                # numpy
                if v.flags.c_contiguous:
                    launch_ctx.set_arg_external_array_with_shape(indices, int(v.ctypes.data), v.nbytes, array_shape, 0)
                elif v.flags.f_contiguous:
                    # TODO: A better way that avoids copying is saving strides info.
                    tmp = np.ascontiguousarray(v)
                    # Purpose: DO NOT GC |tmp|!
                    tmps.append(tmp)

                    def callback(original, updated):
                        np.copyto(original, np.asfortranarray(updated))

                    callbacks.append(functools.partial(callback, v, tmp))
                    launch_ctx.set_arg_external_array_with_shape(
                        indices, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
                    )
                else:
                    raise ValueError(
                        "Non-contiguous numpy arrays are not supported; please call np.ascontiguousarray(arr) "
                        "before passing it into a taichi kernel."
                    )
            elif has_pytorch():
                import torch  # pylint: disable=C0415

                if isinstance(v, torch.Tensor):
                    if not v.is_contiguous():
                        raise ValueError(
                            "Non-contiguous tensors are not supported; please call tensor.contiguous() before "
                            "passing it into a taichi kernel."
                        )
                    taichi_arch = self.runtime.prog.config().arch

                    def get_call_back(u, v):
                        def call_back():
                            u.copy_(v)

                        return call_back

                    # FIXME: only allocate when launching grad kernel
                    if v.requires_grad and v.grad is None:
                        v.grad = torch.zeros_like(v)

                    if v.requires_grad:
                        if not isinstance(v.grad, torch.Tensor):
                            raise ValueError(
                                f"Expecting torch.Tensor for gradient tensor, but getting {v.grad.__class__.__name__} instead"
                            )
                        if not v.grad.is_contiguous():
                            raise ValueError(
                                "Non-contiguous gradient tensors are not supported; please call tensor.grad.contiguous() before passing it into a taichi kernel."
                            )

                    tmp = v
                    if (str(v.device) != "cpu") and not (
                        str(v.device).startswith("cuda") and taichi_arch == _ti_core.Arch.cuda
                    ):
                        # Getting a torch CUDA tensor on a Taichi non-CUDA arch:
                        # we just replace it with a CPU tensor, and by the end of kernel execution we use the
                        # callback to copy the values back to the original CUDA tensor.
                        host_v = v.to(device="cpu", copy=True)
                        tmp = host_v
                        callbacks.append(get_call_back(v, host_v))

                    launch_ctx.set_arg_external_array_with_shape(
                        indices,
                        int(tmp.data_ptr()),
                        tmp.element_size() * tmp.nelement(),
                        array_shape,
                        int(v.grad.data_ptr()) if v.grad is not None else 0,
                    )
                else:
                    raise TaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {type(v)}")
            elif has_paddle():
                import paddle  # pylint: disable=C0415 # type: ignore

                if isinstance(v, paddle.Tensor):
                    # For now, paddle.fluid.core.Tensor._ptr() is only available on the develop branch
                    def get_call_back(u, v):
                        def call_back():
                            u.copy_(v, False)

                        return call_back

                    tmp = v.value().get_tensor()
                    taichi_arch = self.runtime.prog.config().arch
                    if v.place.is_gpu_place():
                        if taichi_arch != _ti_core.Arch.cuda:
                            # Paddle CUDA tensor on a Taichi non-CUDA arch
                            host_v = v.cpu()
                            tmp = host_v.value().get_tensor()
                            callbacks.append(get_call_back(v, host_v))
                    elif v.place.is_cpu_place():
                        if taichi_arch == _ti_core.Arch.cuda:
                            # Paddle CPU tensor on the Taichi CUDA arch
                            gpu_v = v.cuda()
                            tmp = gpu_v.value().get_tensor()
                            callbacks.append(get_call_back(v, gpu_v))
                    else:
                        # Paddle supports many other backends (XPU, NPU, MLU, IPU) that Taichi does not.
                        raise TaichiRuntimeTypeError(f"Taichi does not support the Paddle backend {v.place}")
                    launch_ctx.set_arg_external_array_with_shape(
                        indices, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
                    )
                else:
                    raise TaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
            else:
                raise TaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")

        def set_arg_matrix(indices, v, needed):
            def cast_float(x):
                if not isinstance(x, (int, float, np.integer, np.floating)):
                    raise TaichiRuntimeTypeError(
                        f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
                    )
                return float(x)

            def cast_int(x):
                if not isinstance(x, (int, np.integer)):
                    raise TaichiRuntimeTypeError(
                        f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
                    )
                return int(x)

            cast_func = None
            if needed.dtype in primitive_types.real_types:
                cast_func = cast_float
            elif needed.dtype in primitive_types.integer_types:
                cast_func = cast_int
            else:
                raise ValueError(f"Matrix dtype {needed.dtype} is not an integer type or a real type.")

            if needed.ndim == 2:
                v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
            else:
                v = [cast_func(v[i]) for i in range(needed.n)]
            v = needed(*v)
            needed.set_kernel_struct_args(v, launch_ctx, indices)

        def set_arg_sparse_matrix_builder(indices, v):
            # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
            launch_ctx.set_arg_uint(indices, v._get_ndarray_addr())

        set_later_list = []

        def recursive_set_args(needed_arg_type, provided_arg_type, v, indices):
            """
            Returns the number of kernel args set.
            E.g. templates don't set kernel args, so they return 0;
            a single ndarray is 1 kernel arg, so it returns 1;
            a struct of 3 ndarrays sets 3 kernel args, so it returns 3.
            """
            in_argpack = len(indices) > 1
            nonlocal actual_argument_slot, exceed_max_arg_num, set_later_list
            if actual_argument_slot >= max_arg_num:
                exceed_max_arg_num = True
                return 0
            actual_argument_slot += 1
            if isinstance(needed_arg_type, ArgPackType):
                if not isinstance(v, ArgPack):
                    raise TaichiRuntimeTypeError.get(indices, str(needed_arg_type), str(provided_arg_type))
                idx_new = 0
                for j, (name, anno) in enumerate(needed_arg_type.members.items()):
                    idx_new += recursive_set_args(anno, type(v[name]), v[name], indices + (idx_new,))
                launch_ctx.set_arg_argpack(indices, v._ArgPack__argpack)  # type: ignore
                return 1
            # Note: do not use sth like "needed == f32". That would be slow.
            if id(needed_arg_type) in primitive_types.real_type_ids:
                if not isinstance(v, (float, int, np.floating, np.integer)):
                    raise TaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                if in_argpack:
                    return 1
                launch_ctx.set_arg_float(indices, float(v))
                return 1
            if id(needed_arg_type) in primitive_types.integer_type_ids:
                if not isinstance(v, (int, np.integer)):
                    raise TaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                if in_argpack:
                    return 1
                if is_signed(cook_dtype(needed_arg_type)):
                    launch_ctx.set_arg_int(indices, int(v))
                else:
                    launch_ctx.set_arg_uint(indices, int(v))
                return 1
            if isinstance(needed_arg_type, sparse_matrix_builder):
                if in_argpack:
                    set_later_list.append((set_arg_sparse_matrix_builder, (v,)))
                    return 0
                set_arg_sparse_matrix_builder(indices, v)
                return 1
            if dataclasses.is_dataclass(needed_arg_type):
                assert provided_arg_type == needed_arg_type
                idx = 0
                for j, field in enumerate(dataclasses.fields(needed_arg_type)):
                    assert not isinstance(field.type, str)
                    field_value = getattr(v, field.name)
                    idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
                return idx
            if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray):
                if in_argpack:
                    set_later_list.append((set_arg_ndarray, (v,)))
                    return 0
                set_arg_ndarray(indices, v)
                return 1
            if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture):
                if in_argpack:
                    set_later_list.append((set_arg_texture, (v,)))
                    return 0
                set_arg_texture(indices, v)
                return 1
            if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture):
                if in_argpack:
                    set_later_list.append((set_arg_rw_texture, (v,)))
                    return 0
                set_arg_rw_texture(indices, v)
                return 1
            if isinstance(needed_arg_type, ndarray_type.NdarrayType):
                if in_argpack:
                    set_later_list.append((set_arg_ext_array, (v, needed_arg_type)))
                    return 0
                set_arg_ext_array(indices, v, needed_arg_type)
                return 1
            if isinstance(needed_arg_type, MatrixType):
                if in_argpack:
                    return 1
                set_arg_matrix(indices, v, needed_arg_type)
                return 1
            if isinstance(needed_arg_type, StructType):
                if in_argpack:
                    return 1
                if not isinstance(v, needed_arg_type):
                    raise TaichiRuntimeTypeError(
                        f"Argument {provided_arg_type} cannot be converted into required type {needed_arg_type}"
                    )
                needed_arg_type.set_kernel_struct_args(v, launch_ctx, indices)
                return 1
            if needed_arg_type == template or isinstance(needed_arg_type, template):
                return 0
            raise ValueError(f"Argument type mismatch. Expecting {needed_arg_type}, got {type(v)}.")

        template_num = 0
        i_out = 0
        for i_in, val in enumerate(args):
            needed_ = self.arguments[i_in].annotation
            if needed_ == template or isinstance(needed_, template):
                template_num += 1
                i_out += 1
                continue
            i_out += recursive_set_args(needed_, type(val), val, (i_out - template_num,))

        for i, (set_arg_func, params) in enumerate(set_later_list):
            set_arg_func((len(args) - template_num + i,), *params)

        if exceed_max_arg_num:
            raise TaichiRuntimeError(
                f"The number of elements in kernel arguments is too big! Do not exceed {max_arg_num} on the {_ti_core.arch_name(impl.current_cfg().arch)} backend."
            )

        try:
            prog = impl.get_runtime().prog
            # Compile kernel (& Online Cache & Offline Cache)
            compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
            # Launch kernel
            prog.launch_kernel(compiled_kernel_data, launch_ctx)
        except Exception as e:
            e = handle_exception_from_cpp(e)
            if impl.get_runtime().print_full_traceback:
                raise e
            raise e from None

        ret = None
        ret_dt = self.return_type
        has_ret = ret_dt is not None

        if has_ret or self.has_print:
            runtime_ops.sync()

        if has_ret:
            ret = []
            for i, ret_type in enumerate(ret_dt):
                ret.append(self.construct_kernel_ret(launch_ctx, ret_type, (i,)))
            if len(ret_dt) == 1:
                ret = ret[0]
        if callbacks:
            for c in callbacks:
                c()

        return ret

1165
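+     # A hedged usage sketch, not part of this module: the dispatch above maps
+     # each Python-side argument onto one launch-context index slot, except that
+     # template arguments consume no slot and argpack members are deferred via
+     # set_later_list. Assuming the public `ti` API and an initialized backend:
+     #
+     #   >>> arr = ti.ndarray(ti.f32, shape=8)
+     #   >>>
+     #   >>> @ti.kernel
+     #   >>> def scale(a: ti.types.ndarray(), k: ti.f32):
+     #   >>>     for i in range(8):
+     #   >>>         a[i] = a[i] * k
+     #   >>>
+     #   >>> scale(arr, 2.0)  # arr fills slot 0 (NdarrayType branch), k fills slot 1
+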
+     def construct_kernel_ret(self, launch_ctx, ret_type, index=()):
+         if isinstance(ret_type, CompoundType):
+             return ret_type.from_kernel_struct_ret(launch_ctx, index)
+         if ret_type in primitive_types.integer_types:
+             if is_signed(cook_dtype(ret_type)):
+                 return launch_ctx.get_struct_ret_int(index)
+             return launch_ctx.get_struct_ret_uint(index)
+         if ret_type in primitive_types.real_types:
+             return launch_ctx.get_struct_ret_float(index)
+         raise TaichiRuntimeTypeError(f"Invalid return type on index={index}")
+
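+     # A hedged illustration, assuming the public `ti` API: a scalar annotation
+     # such as `-> ti.i32` is read back through get_struct_ret_int above, while
+     # a compound annotation (e.g. `-> ti.types.vector(3, ti.f32)`) takes the
+     # CompoundType path via from_kernel_struct_ret.
+     #
+     #   >>> @ti.kernel
+     #   >>> def total() -> ti.i32:
+     #   >>>     s = 0
+     #   >>>     for i in range(10):
+     #   >>>         s += i
+     #   >>>     return s
+     #   >>>
+     #   >>> total()  # 45
+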
+     def ensure_compiled(self, *args):
+         instance_id, arg_features = self.mapper.lookup(args)
+         key = (self.func, instance_id, self.autodiff_mode)
+         self.materialize(key=key, args=args, arg_features=arg_features)
+         return key
+
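+     # A hedged note on the cache key above: `mapper.lookup` folds template
+     # arguments into `instance_id`, so structurally different template
+     # arguments to the same kernel yield distinct keys and distinct compiled
+     # variants. Sketch, assuming the public `ti` API:
+     #
+     #   >>> a = ti.field(ti.i32, shape=4)
+     #   >>> b = ti.field(ti.f32, shape=4)
+     #   >>>
+     #   >>> @ti.kernel
+     #   >>> def fill(f: ti.template()):
+     #   >>>     for i in f:
+     #   >>>         f[i] = 1
+     #   >>>
+     #   >>> fill(a)  # compiles and caches one variant
+     #   >>> fill(b)  # different template argument -> a second variant
+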
+     # For small kernels (< 3 us), performance can be quite sensitive to the
+     # overhead in __call__, so this part needs to be fast (i.e. < 3 us on a
+     # 4 GHz x64 CPU).
+     @_shell_pop_print
+     def __call__(self, *args, **kwargs):
+         args = _process_args(self, is_func=False, args=args, kwargs=kwargs)
+
+         # Transform the primal kernel into a forward-mode grad kernel, then
+         # recover the primal when exiting the forward-mode manager.
+         if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
+             # TODO: if we would like to compute 2nd-order derivatives by forward-on-reverse
+             # in a nested context-manager fashion, i.e. a `Tape` nested in the `FwdMode`,
+             # we can transform only the kernels with `mode_original == AutodiffMode.REVERSE`,
+             # to avoid duplicate computation of 1st-order derivatives.
+             self.runtime.fwd_mode_manager.insert(self)
+
+         # Both the class kernels and the plain-function kernels are unified now.
+         # In both cases, |self.grad| is another Kernel instance that computes the
+         # gradient. For class kernels, args[0] is always the kernel owner.
+
+         # No need to capture grad kernels: they are already bound to their primal kernels.
+         if (
+             self.autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION)
+             and self.runtime.target_tape
+             and not self.runtime.grad_replaced
+         ):
+             self.runtime.target_tape.insert(self, args)
+
+         if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg().opt_level == 0:
+             _logging.warn("""opt_level = 1 is enforced to enable gradient computation.""")
+             impl.current_cfg().opt_level = 1
+         key = self.ensure_compiled(*args)
+         kernel_cpp = self.compiled_kernels[key]
+         return self.launch_kernel(kernel_cpp, *args)
+
+
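+ # A hedged usage sketch for the tape capture in __call__ above, assuming the
+ # public `ti.ad` API: kernels invoked under an active Tape are recorded so
+ # their adjoints can be replayed in reverse order when the Tape exits.
+ #
+ #   >>> x = ti.field(ti.f32, shape=(), needs_grad=True)
+ #   >>> loss = ti.field(ti.f32, shape=(), needs_grad=True)
+ #   >>>
+ #   >>> @ti.kernel
+ #   >>> def square():
+ #   >>>     loss[None] = x[None] * x[None]
+ #   >>>
+ #   >>> x[None] = 3.0
+ #   >>> with ti.ad.Tape(loss=loss):
+ #   >>>     square()
+ #   >>> x.grad[None]  # 6.0
+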
+ # For a Taichi class definition like below:
+ #
+ # @ti.data_oriented
+ # class X:
+ #     @ti.kernel
+ #     def foo(self):
+ #         ...
+ #
+ # When ti.kernel runs, the stack frame's |code_context| in Python 3.8(+) is
+ # different from that in Python 3.7 and below. In 3.8+ it is 'class X:',
+ # whereas in <=3.7 it is '@ti.data_oriented'. More interestingly, if the class
+ # inherits, i.e. class X(object):, then |code_context| is 'class X(object):'
+ # in both versions.
+ _KERNEL_CLASS_STACKFRAME_STMT_RES = [
+     re.compile(r"@(\w+\.)?data_oriented"),
+     re.compile(r"class "),
+ ]
+
+
+ def _inside_class(level_of_class_stackframe):
+     try:
+         maybe_class_frame = sys._getframe(level_of_class_stackframe)
+         statement_list = inspect.getframeinfo(maybe_class_frame)[3]
+         if statement_list is None:
+             return False
+         first_statement = statement_list[0].strip()
+         for pat in _KERNEL_CLASS_STACKFRAME_STMT_RES:
+             if pat.match(first_statement):
+                 return True
+     except Exception:
+         pass
+     return False
+
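+ # A minimal standalone sketch of the frame-inspection trick above (hedged;
+ # `probe` is a hypothetical name, not part of this module). The printed
+ # context line is what _KERNEL_CLASS_STACKFRAME_STMT_RES is matched against;
+ # its exact text varies across Python versions, as described above.
+ #
+ #   >>> import sys, inspect
+ #   >>>
+ #   >>> def probe(fn):
+ #   >>>     frame = sys._getframe(1)  # frame of the scope applying the decorator
+ #   >>>     print(inspect.getframeinfo(frame)[3])  # source context at that frame
+ #   >>>     return fn
+ #   >>>
+ #   >>> class X:
+ #   >>>     @probe
+ #   >>>     def f(self):
+ #   >>>         pass
+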
+
+ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False):
+     # Can decorators determine if a function is being defined inside a class?
+     # https://stackoverflow.com/a/8793684/12003165
+     is_classkernel = _inside_class(level_of_class_stackframe + 1)
+
+     if verbose:
+         print(f"kernel={_func.__name__} is_classkernel={is_classkernel}")
+     primal = Kernel(_func, autodiff_mode=AutodiffMode.NONE, _classkernel=is_classkernel)
+     adjoint = Kernel(_func, autodiff_mode=AutodiffMode.REVERSE, _classkernel=is_classkernel)
+     # Having |primal| contain |grad| makes the tape work.
+     primal.grad = adjoint
+
+     if is_classkernel:
+         # For class kernels, the primal/adjoint callables are constructed when
+         # the kernel is accessed via the instance inside
+         # _BoundedDifferentiableMethod.
+         # This is because we need to bind the kernel or |grad| to the instance
+         # owning the kernel, which is not known until the kernel is accessed.
+         #
+         # See also: _BoundedDifferentiableMethod, data_oriented.
+         @functools.wraps(_func)
+         def wrapped(*args, **kwargs):
+             # If we ever reach here (we should not), the class was not decorated
+             # with @ti.data_oriented; otherwise __getattribute__ would have
+             # intercepted the call.
+             clsobj = type(args[0])
+             assert not hasattr(clsobj, "_data_oriented")
+             raise TaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
+
+     else:
+
+         @functools.wraps(_func)
+         def wrapped(*args, **kwargs):
+             try:
+                 return primal(*args, **kwargs)
+             except (TaichiCompilationError, TaichiRuntimeError) as e:
+                 if impl.get_runtime().print_full_traceback:
+                     raise e
+                 raise type(e)("\n" + str(e)) from None
+
+         wrapped.grad = adjoint
+
+     wrapped._is_wrapped_kernel = True
+     wrapped._is_classkernel = is_classkernel
+     wrapped._primal = primal
+     wrapped._adjoint = adjoint
+     return wrapped
+
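+ # A hedged sketch of the plain-function path above, assuming the public `ti`
+ # API: the returned wrapper keeps the reverse-mode kernel on `.grad`, so a
+ # primal/adjoint pair can also be driven manually:
+ #
+ #   >>> x = ti.field(ti.f32, shape=4, needs_grad=True)
+ #   >>> loss = ti.field(ti.f32, shape=(), needs_grad=True)
+ #   >>>
+ #   >>> @ti.kernel
+ #   >>> def reduce_sum():
+ #   >>>     for i in x:
+ #   >>>         loss[None] += x[i]
+ #   >>>
+ #   >>> reduce_sum()
+ #   >>> loss.grad[None] = 1.0  # seed the adjoint
+ #   >>> reduce_sum.grad()      # runs the REVERSE-mode kernel; fills x.grad
+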
+
+ def kernel(fn: Callable):
+     """Marks a function as a Taichi kernel.
+
+     A Taichi kernel is a function written in Python that gets JIT compiled by
+     Taichi into native CPU/GPU instructions (e.g. a series of CUDA kernels).
+     The top-level ``for`` loops are automatically parallelized and distributed
+     to either a CPU thread pool or massively parallel GPUs.
+
+     A kernel's gradient kernel is generated automatically by the AutoDiff system.
+
+     See also https://docs.taichi-lang.org/docs/syntax#kernel.
+
+     Args:
+         fn (Callable): the Python function to be decorated
+
+     Returns:
+         Callable: The decorated function
+
+     Example::
+
+         >>> x = ti.field(ti.i32, shape=(4, 8))
+         >>>
+         >>> @ti.kernel
+         >>> def run():
+         >>>     # Assigns all the elements of `x` in parallel.
+         >>>     for i in x:
+         >>>         x[i] = i
+     """
+     return _kernel_impl(fn, level_of_class_stackframe=3)
+
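+ # Hedged note: the wrapper returned by ti.kernel is an ordinary Python
+ # callable; compilation happens lazily on the first call (per argument
+ # signature, via ensure_compiled), so only the first launch pays the JIT cost.
+ #
+ #   >>> run()  # first call: compile + launch
+ #   >>> run()  # cache hit: launch only
+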
+
+ class _BoundedDifferentiableMethod:
+     def __init__(self, kernel_owner, wrapped_kernel_func):
+         clsobj = type(kernel_owner)
+         if not getattr(clsobj, "_data_oriented", False):
+             raise TaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
+         self._kernel_owner = kernel_owner
+         self._primal = wrapped_kernel_func._primal
+         self._adjoint = wrapped_kernel_func._adjoint
+         self._is_staticmethod = wrapped_kernel_func._is_staticmethod
+         self.__name__: str | None = None
+
+     def __call__(self, *args, **kwargs):
+         try:
+             if self._is_staticmethod:
+                 return self._primal(*args, **kwargs)
+             return self._primal(self._kernel_owner, *args, **kwargs)
+         except (TaichiCompilationError, TaichiRuntimeError) as e:
+             if impl.get_runtime().print_full_traceback:
+                 raise e
+             raise type(e)("\n" + str(e)) from None
+
+     def grad(self, *args, **kwargs):
+         return self._adjoint(self._kernel_owner, *args, **kwargs)
+
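+ # A hedged usage sketch, assuming the public `ti` API: attribute access on a
+ # @ti.data_oriented instance returns a _BoundedDifferentiableMethod, so both
+ # the primal and its adjoint can be called through the instance:
+ #
+ #   >>> a = TiArray(32)  # see the data_oriented docstring below
+ #   >>> a.inc()          # primal; __call__ injects self._kernel_owner
+ #   >>> a.inc.grad()     # adjoint (requires fields created with needs_grad=True)
+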
+
+ def data_oriented(cls):
+     """Marks a class as Taichi compatible.
+
+     To allow for modularized code, Taichi provides this decorator so that
+     Taichi kernels can be defined inside a class.
+
+     See also https://docs.taichi-lang.org/docs/odop
+
+     Example::
+
+         >>> @ti.data_oriented
+         >>> class TiArray:
+         >>>     def __init__(self, n):
+         >>>         self.x = ti.field(ti.f32, shape=n)
+         >>>
+         >>>     @ti.kernel
+         >>>     def inc(self):
+         >>>         for i in self.x:
+         >>>             self.x[i] += 1.0
+         >>>
+         >>> a = TiArray(32)
+         >>> a.inc()
+
+     Args:
+         cls (Class): the class to be decorated
+
+     Returns:
+         The decorated class.
+     """
+
+     def _getattr(self, item):
+         method = cls.__dict__.get(item, None)
+         is_property = method.__class__ == property
+         is_staticmethod = method.__class__ == staticmethod
+         if is_property:
+             x = method.fget
+         else:
+             x = super(cls, self).__getattribute__(item)
+         if hasattr(x, "_is_wrapped_kernel"):
+             if inspect.ismethod(x):
+                 wrapped = x.__func__
+             else:
+                 wrapped = x
+             wrapped._is_staticmethod = is_staticmethod
+             assert inspect.isfunction(wrapped)
+             if wrapped._is_classkernel:
+                 ret = _BoundedDifferentiableMethod(self, wrapped)
+                 ret.__name__ = wrapped.__name__
+                 if is_property:
+                     return ret()
+                 return ret
+         if is_property:
+             return x(self)
+         return x
+
+     cls.__getattribute__ = _getattr
+     cls._data_oriented = True
+
+     return cls
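+
+ # A hedged illustration of the property/staticmethod handling above: a kernel
+ # wrapped in @property is invoked on attribute access (`return ret()`), and a
+ # @staticmethod kernel skips the owner binding in
+ # _BoundedDifferentiableMethod.__call__.
+ #
+ #   >>> @ti.data_oriented
+ #   >>> class S:
+ #   >>>     @staticmethod
+ #   >>>     @ti.kernel
+ #   >>>     def ping():
+ #   >>>         print("ping")
+ #   >>>
+ #   >>> S().ping()  # dispatched with _is_staticmethod=True; no owner argument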
+
+
+ __all__ = ["data_oriented", "func", "kernel", "pyfunc", "real_func"]