gstaichi 0.0.0__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +51 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +5 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version_check.py +100 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/lcg_python.py +26 -0
  30. gstaichi/examples/lcg_taichi.py +34 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_dataclass_util.py +31 -0
  35. gstaichi/lang/_fast_caching/__init__.py +3 -0
  36. gstaichi/lang/_fast_caching/args_hasher.py +122 -0
  37. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  38. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  39. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  40. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  41. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  42. gstaichi/lang/_fast_caching/src_hasher.py +83 -0
  43. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  44. gstaichi/lang/_ndarray.py +366 -0
  45. gstaichi/lang/_ndrange.py +152 -0
  46. gstaichi/lang/_template_mapper.py +195 -0
  47. gstaichi/lang/_texture.py +172 -0
  48. gstaichi/lang/_wrap_inspect.py +215 -0
  49. gstaichi/lang/any_array.py +99 -0
  50. gstaichi/lang/ast/__init__.py +7 -0
  51. gstaichi/lang/ast/ast_transformer.py +1351 -0
  52. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  53. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  54. gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
  55. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  56. gstaichi/lang/ast/checkers.py +106 -0
  57. gstaichi/lang/ast/symbol_resolver.py +57 -0
  58. gstaichi/lang/ast/transform.py +9 -0
  59. gstaichi/lang/common_ops.py +310 -0
  60. gstaichi/lang/exception.py +80 -0
  61. gstaichi/lang/expr.py +180 -0
  62. gstaichi/lang/field.py +428 -0
  63. gstaichi/lang/impl.py +1259 -0
  64. gstaichi/lang/kernel_arguments.py +155 -0
  65. gstaichi/lang/kernel_impl.py +1386 -0
  66. gstaichi/lang/matrix.py +1835 -0
  67. gstaichi/lang/matrix_ops.py +341 -0
  68. gstaichi/lang/matrix_ops_utils.py +190 -0
  69. gstaichi/lang/mesh.py +687 -0
  70. gstaichi/lang/misc.py +784 -0
  71. gstaichi/lang/ops.py +1494 -0
  72. gstaichi/lang/runtime_ops.py +13 -0
  73. gstaichi/lang/shell.py +35 -0
  74. gstaichi/lang/simt/__init__.py +5 -0
  75. gstaichi/lang/simt/block.py +94 -0
  76. gstaichi/lang/simt/grid.py +7 -0
  77. gstaichi/lang/simt/subgroup.py +191 -0
  78. gstaichi/lang/simt/warp.py +96 -0
  79. gstaichi/lang/snode.py +489 -0
  80. gstaichi/lang/source_builder.py +150 -0
  81. gstaichi/lang/struct.py +810 -0
  82. gstaichi/lang/util.py +312 -0
  83. gstaichi/linalg/__init__.py +10 -0
  84. gstaichi/linalg/matrixfree_cg.py +310 -0
  85. gstaichi/linalg/sparse_cg.py +59 -0
  86. gstaichi/linalg/sparse_matrix.py +303 -0
  87. gstaichi/linalg/sparse_solver.py +123 -0
  88. gstaichi/math/__init__.py +11 -0
  89. gstaichi/math/_complex.py +205 -0
  90. gstaichi/math/mathimpl.py +886 -0
  91. gstaichi/profiler/__init__.py +6 -0
  92. gstaichi/profiler/kernel_metrics.py +260 -0
  93. gstaichi/profiler/kernel_profiler.py +586 -0
  94. gstaichi/profiler/memory_profiler.py +15 -0
  95. gstaichi/profiler/scoped_profiler.py +36 -0
  96. gstaichi/sparse/__init__.py +3 -0
  97. gstaichi/sparse/_sparse_grid.py +77 -0
  98. gstaichi/tools/__init__.py +12 -0
  99. gstaichi/tools/diagnose.py +117 -0
  100. gstaichi/tools/np2ply.py +364 -0
  101. gstaichi/tools/vtk.py +38 -0
  102. gstaichi/types/__init__.py +21 -0
  103. gstaichi/types/annotations.py +52 -0
  104. gstaichi/types/compound_types.py +71 -0
  105. gstaichi/types/enums.py +49 -0
  106. gstaichi/types/ndarray_type.py +169 -0
  107. gstaichi/types/primitive_types.py +206 -0
  108. gstaichi/types/quant.py +88 -0
  109. gstaichi/types/texture_type.py +85 -0
  110. gstaichi/types/utils.py +11 -0
  111. gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
  112. gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
  113. gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
  114. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
  115. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  116. gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
  117. gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  118. gstaichi-0.0.0.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  119. gstaichi-0.0.0.data/data/include/spirv_cross/spirv.h +2568 -0
  120. gstaichi-0.0.0.data/data/include/spirv_cross/spirv.hpp +2579 -0
  121. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  122. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  123. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  124. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  125. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  126. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  127. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  128. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  129. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  130. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  131. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  132. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  133. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  134. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  135. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  136. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  137. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  138. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  139. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  140. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  141. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  142. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  143. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  144. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  145. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  146. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  147. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  148. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  149. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  150. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  151. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  152. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  153. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  154. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  155. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  156. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  157. gstaichi-0.0.0.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  158. gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  159. gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  160. gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  161. gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  162. gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  163. gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  164. gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  165. gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  166. gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  167. gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  168. gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  169. gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  170. gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  171. gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  172. gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  173. gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  174. gstaichi-0.0.0.dist-info/METADATA +97 -0
  175. gstaichi-0.0.0.dist-info/RECORD +178 -0
  176. gstaichi-0.0.0.dist-info/WHEEL +5 -0
  177. gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
  178. gstaichi-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1386 @@
1
+ import ast
2
+ import dataclasses
3
+ import functools
4
+ import inspect
5
+ import json
6
+ import operator
7
+ import os
8
+ import pathlib
9
+ import re
10
+ import sys
11
+ import textwrap
12
+ import time
13
+ import types
14
+ import typing
15
+ import warnings
16
+ from typing import Any, Callable, Type
17
+
18
+ import numpy as np
19
+
20
+ import gstaichi.lang
21
+ import gstaichi.lang._ndarray
22
+ import gstaichi.lang._texture
23
+ import gstaichi.types.annotations
24
+ from gstaichi import _logging
25
+ from gstaichi._lib import core as _ti_core
26
+ from gstaichi._lib.core.gstaichi_python import (
27
+ ASTBuilder,
28
+ CompiledKernelData,
29
+ CompileResult,
30
+ FunctionKey,
31
+ KernelCxx,
32
+ KernelLaunchContext,
33
+ )
34
+ from gstaichi.lang import _kernel_impl_dataclass, impl, ops, runtime_ops
35
+ from gstaichi.lang._fast_caching import src_hasher
36
+ from gstaichi.lang._template_mapper import TemplateMapper
37
+ from gstaichi.lang._wrap_inspect import FunctionSourceInfo, get_source_info_and_src
38
+ from gstaichi.lang.any_array import AnyArray
39
+ from gstaichi.lang.ast import (
40
+ ASTTransformerContext,
41
+ KernelSimplicityASTChecker,
42
+ transform_tree,
43
+ )
44
+ from gstaichi.lang.ast.ast_transformer_utils import ReturnStatus
45
+ from gstaichi.lang.exception import (
46
+ GsTaichiCompilationError,
47
+ GsTaichiRuntimeError,
48
+ GsTaichiRuntimeTypeError,
49
+ GsTaichiSyntaxError,
50
+ GsTaichiTypeError,
51
+ handle_exception_from_cpp,
52
+ )
53
+ from gstaichi.lang.expr import Expr
54
+ from gstaichi.lang.kernel_arguments import ArgMetadata
55
+ from gstaichi.lang.matrix import MatrixType
56
+ from gstaichi.lang.shell import _shell_pop_print
57
+ from gstaichi.lang.struct import StructType
58
+ from gstaichi.lang.util import cook_dtype, has_pytorch
59
+ from gstaichi.types import (
60
+ ndarray_type,
61
+ primitive_types,
62
+ sparse_matrix_builder,
63
+ template,
64
+ texture_type,
65
+ )
66
+ from gstaichi.types.compound_types import CompoundType
67
+ from gstaichi.types.enums import AutodiffMode, Layout
68
+ from gstaichi.types.utils import is_signed
69
+
70
# Type alias for the key identifying a compiled kernel: (callable, int, AutodiffMode).
# NOTE(review): the meaning of the int component is not visible here (presumably an
# instance/template id) — confirm against the code that builds these keys.
CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]
71
+
72
+
73
class GsTaichiCallable:
    """
    GsTaichiCallable wraps a bindable callable (e.g. the Func created by @ti.func) in a class.

    Design requirements for GsTaichiCallable:
    - wrap/contain a reference to a class Func instance, and allow (the GsTaichiCallable) being passed around
      like a normal function pointer
    - expose attributes of the wrapped class Func, such as `_is_real_function`, `_primal`, etc
    - allow for (now limited) strong typing, and enable type checkers, such as pyright/mypy
    - currently GsTaichiCallable is a shared type used for all functions marked with @ti.func, @ti.kernel,
      python functions (?)
    - note: the current type-checking implementation does not distinguish between different type flavors of
      GsTaichiCallable, with different values of `_is_real_function`, `_primal`, etc
    - handle not only class-less functions, but also class-instance methods (where determining the `self`
      reference is a challenge)

    Let's take the following example:

        def test_ptr_class_func():
            @ti.data_oriented
            class MyClass:
                def __init__(self):
                    self.a = ti.field(dtype=ti.f32, shape=(3))

                def add2numbers_py(self, x, y):
                    return x + y

                @ti.func
                def add2numbers_func(self, x, y):
                    return x + y

                @ti.kernel
                def func(self):
                    a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)
                    a[0] = add_py(2, 3)
                    a[1] = add_func(3, 7)

    (taken from test_ptr_assign.py).

    When the @ti.func decorator is applied, the function `add2numbers_func` exists, but there is not yet
    any `self`:
    - it is not possible for the method to be bound to a `self` instance at that point
    - the @ti.func decorator runs kernel_impl.py::func, which creates a class Func instance (wrapping
      add2numbers_func) and immediately wraps that Func instance in a GsTaichiCallable
    - effectively, we have two layers of wrapping: GsTaichiCallable -> Func -> python function
      (actual function definition)
    - later on, inside the kernel, we access `self.add2numbers_func` here:

          a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)

      ... and we want to get the bound method `self.add2numbers_func`.
    - a plain python function stored on a class binds to `self` automatically when accessed through an
      instance (functions are descriptors). An arbitrary callable object, such as a GsTaichiCallable
      instance, does NOT bind automatically.
    - we still need the class wrapper so that strongly typed attributes can be attached to the function
    - without binding, invoking the GsTaichiCallable would call the wrapped function without `self` as its
      first argument, and things break

    To address this, GsTaichiCallable implements the descriptor method `__get__`. When the attribute is
    accessed through an instance, `__get__` returns a `BoundGsTaichiCallable`, passing the instance into
    `BoundGsTaichiCallable.__init__`. When accessed through the class (instance is None), the
    GsTaichiCallable itself is returned.

    `BoundGsTaichiCallable` can then be used like a normal bound method - even though it's just an object
    instance - via its `__call__` method. That method calls this GsTaichiCallable's `wrapper` (e.g. the
    Func) directly, with the stored instance prepended to the arguments. Attribute access on it is
    forwarded to this GsTaichiCallable:
        BoundGsTaichiCallable -> (GsTaichiCallable for attributes) -> Func -> python function
    """

    def __init__(self, fn: Callable, wrapper: Callable) -> None:
        """
        Args:
            fn: The original python function; its metadata is copied onto this object.
            wrapper: The object actually invoked by __call__ (e.g. a Func instance).
        """
        self.fn: Callable = fn
        self.wrapper: Callable = wrapper
        # Flags/links below start with neutral values; decorators (func, pyfunc, ...) set them afterwards.
        self._is_real_function: bool = False
        self._is_gstaichi_function: bool = False
        self._is_wrapped_kernel: bool = False
        self._is_classkernel: bool = False
        self._primal: Kernel | None = None
        self._adjoint: Kernel | None = None
        self.grad: Kernel | None = None
        self._is_staticmethod: bool = False
        self.is_pure: bool = False
        # Copies __name__, __qualname__, __doc__, __module__ (and fn.__dict__ entries) from fn and sets
        # __wrapped__. Done last, so entries from fn.__dict__ take precedence over the defaults above.
        functools.update_wrapper(self, fn)

    def __call__(self, *args, **kwargs):
        # Delegate to the wrapper's own __call__ (e.g. Func.__call__).
        return self.wrapper.__call__(*args, **kwargs)

    def __get__(self, instance, owner=None):
        # Descriptor protocol: class access returns the unbound callable, instance access binds it.
        if instance is None:
            return self
        return BoundGsTaichiCallable(instance, self)
167
+
168
+
169
class BoundGsTaichiCallable:
    """
    A GsTaichiCallable bound to a specific instance (the analogue of a Python bound method).

    Created by `GsTaichiCallable.__get__` when a GsTaichiCallable is accessed through an instance.
    Calling it invokes the GsTaichiCallable's `wrapper` with the stored instance prepended to the
    arguments. All other attribute reads and writes are forwarded to the underlying GsTaichiCallable,
    so flags such as `_is_real_function` or `_primal` are shared between the bound and unbound views.
    """

    def __init__(self, instance: Any, gstaichi_callable: "GsTaichiCallable"):
        self.wrapper = gstaichi_callable.wrapper
        self.instance = instance
        self.gstaichi_callable = gstaichi_callable

    def __call__(self, *args, **kwargs):
        return self.wrapper(self.instance, *args, **kwargs)

    def __getattr__(self, k: str) -> Any:
        # Only reached when normal lookup fails, i.e. for attributes not stored on this object.
        # The target is fetched without going through __getattr__ again: an instance created without
        # __init__ (copy.copy / pickle use cls.__new__) has no `gstaichi_callable` yet, and reading it
        # as `self.gstaichi_callable` would re-enter __getattr__ until RecursionError.
        try:
            target = object.__getattribute__(self, "gstaichi_callable")
        except AttributeError:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {k!r}") from None
        return getattr(target, k)

    def __setattr__(self, k: str, v: Any) -> None:
        # Note: these have to match the name of any attributes on this class.
        if k in ("wrapper", "instance", "gstaichi_callable"):
            object.__setattr__(self, k, v)
        else:
            setattr(self.gstaichi_callable, k, v)
188
+
189
+
190
def func(fn: Callable, is_real_function: bool = False) -> GsTaichiCallable:
    """Decorate a Python function so it can be called from GsTaichi scope.

    The returned object JIT-compiles ``fn`` into native instructions when it is
    called from inside a kernel (or another GsTaichi function). Calling it from
    Python scope raises ``GsTaichiSyntaxError``; use :func:`pyfunc` for functions
    that must also run as plain Python.

    Args:
        fn (Callable): The Python function to be decorated.
        is_real_function (bool): If True, ``fn`` is compiled as a separate "real"
            function and invoked through a call, instead of being transformed
            directly into the calling kernel. Prefer :func:`real_func`, whose
            extra call frame this function accounts for.

    Returns:
        GsTaichiCallable: A wrapper around a ``Func`` built from ``fn``.

    Example::

        >>> @ti.func
        >>> def foo(x):
        >>>     return x + 2
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     print(foo(40))  # 42
    """
    # _inside_class inspects the call stack at a fixed depth. When reached via real_func there is one
    # more frame in between, which the `+ is_real_function` term appears to account for. This call
    # must stay directly in this function's body for the frame count to remain correct.
    stack_level = 3 + is_real_function
    is_classfunc = _inside_class(level_of_class_stackframe=stack_level)

    wrapped = GsTaichiCallable(fn, Func(fn, _classfunc=is_classfunc, is_real_function=is_real_function))
    wrapped._is_gstaichi_function = True
    wrapped._is_real_function = is_real_function
    return wrapped
220
+
221
+
222
def real_func(fn: Callable) -> GsTaichiCallable:
    """Decorate ``fn`` as a GsTaichi *real* function.

    Equivalent to ``func(fn, is_real_function=True)``. A real function is compiled
    as its own function and invoked through a function call from the kernel,
    rather than being transformed directly into the caller. Real functions cannot
    be used in kernels with autodiff enabled, and their parameters must be type
    annotated.

    Args:
        fn (Callable): The Python function to be decorated.

    Returns:
        GsTaichiCallable: The decorated function.
    """
    # `func` must be called directly from here: its stack-depth adjustment for
    # is_real_function assumes exactly one extra frame (this one).
    return func(fn, is_real_function=True)
224
+
225
+
226
def pyfunc(fn: Callable) -> GsTaichiCallable:
    """Decorate a function so it is callable from both GsTaichi and Python scope.

    Inside GsTaichi scope the function is JIT-compiled into native instructions;
    outside of it, the original Python function is simply invoked.

    See also :func:`~gstaichi.lang.kernel_impl.func`.

    Args:
        fn (Callable): The Python function to be decorated.

    Returns:
        GsTaichiCallable: The decorated function.
    """
    # Fixed stack depth used by _inside_class (presumably to detect an enclosing
    # class body); keep this call directly in pyfunc so the depth stays valid.
    in_class = _inside_class(level_of_class_stackframe=3)

    wrapped = GsTaichiCallable(fn, Func(fn, _classfunc=in_class, _pyfunc=True))
    wrapped._is_gstaichi_function = True
    wrapped._is_real_function = False
    return wrapped
247
+
248
+
249
def _populate_global_vars_for_templates(
    template_slot_locations: list[int],
    argument_metas: list[ArgMetadata],
    global_vars: dict[str, Any],
    fn: Callable,
    py_args: tuple[Any, ...],
):
    """Inject template arguments (and expanded dataclass fields) into ``global_vars``.

    Globals are used to carry the Python objects bound to template parameters
    into the transformed AST. This function keeps that approach, and additionally
    injects the expanded Python variables of dataclass-annotated parameters.

    Args:
        template_slot_locations: Indices of the template-annotated parameters.
        argument_metas: Per-parameter metadata; ``.name`` is used as the global name.
        global_vars: Mapping that is updated in place.
        fn: The Python function whose signature is inspected for dataclass annotations.
        py_args: Python-side argument values, indexed by parameter position.
    """
    # Template parameters: expose each argument object under its parameter name.
    for slot in template_slot_locations:
        global_vars[argument_metas[slot].name] = py_args[slot]

    # Dataclass-annotated parameters: let the dataclass helper inject their fields.
    signature = inspect.signature(fn)
    for position, param in enumerate(signature.parameters.values()):
        if not dataclasses.is_dataclass(param.annotation):
            continue
        _kernel_impl_dataclass.populate_global_vars_from_dataclass(
            param.name,
            param.annotation,
            py_args[position],
            global_vars=global_vars,
        )
275
+
276
+
277
def _get_tree_and_ctx(
    self: "Func | Kernel",
    args: tuple[Any, ...],
    excluded_parameters=(),
    is_kernel: bool = True,
    arg_features=None,
    ast_builder: "ASTBuilder | None" = None,
    is_real_function: bool = False,
    current_kernel: "Kernel | None" = None,
) -> tuple[ast.Module, ASTTransformerContext]:
    """Parse the source of ``self.func`` and build the AST transformer context for it.

    Args:
        self: The Func or Kernel being compiled.
        args: Python-side argument values; stored as ``argument_data`` in the context
            and, for kernels and real functions, used to inject template/dataclass
            arguments into the globals.
        excluded_parameters: Forwarded unchanged to ASTTransformerContext.
        is_kernel: Whether ``self`` is a kernel.
        arg_features: Forwarded unchanged to ASTTransformerContext.
        ast_builder: Forwarded unchanged to ASTTransformerContext.
        is_real_function: Whether ``self`` is a real function.
        current_kernel: If given, its ``kernel_function_info`` is set to this
            function's source info. If None, the runtime's current kernel is used
            (asserted to exist).

    Returns:
        ``(tree, ctx)``: the parsed module with decorators removed from its first
        definition, and the transformer context. As a side effect, the function's
        source info is added to ``current_kernel.visited_functions``.
    """
    source_info, raw_lines = get_source_info_and_src(self.func)
    # fill() with a very large width is used for its normalisation (tab expansion
    # with tabsize=4, whitespace handling), not for wrapping.
    src = [textwrap.fill(line, tabsize=4, width=9999) for line in raw_lines]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    # Strip decorators (e.g. @ti.kernel) from the parsed definition.
    tree.body[0].decorator_list = []  # type: ignore , kick that can down the road...

    global_vars = _get_global_vars(self.func)
    if is_kernel or is_real_function:
        _populate_global_vars_for_templates(
            template_slot_locations=self.template_slot_locations,
            argument_metas=self.arg_metas,
            global_vars=global_vars,
            fn=self.func,
            py_args=args,
        )

    if current_kernel is None:
        current_kernel = impl.get_runtime()._current_kernel
        assert current_kernel is not None
    else:
        # Called for a kernel: record where its source lives.
        current_kernel.kernel_function_info = source_info
    current_kernel.visited_functions.add(source_info)

    ctx = ASTTransformerContext(
        excluded_parameters=excluded_parameters,
        is_kernel=is_kernel,
        func=self,
        arg_features=arg_features,
        global_vars=global_vars,
        argument_data=args,
        src=src,
        start_lineno=source_info.start_lineno,
        end_lineno=source_info.end_lineno,
        file=source_info.filepath,
        ast_builder=ast_builder,
        is_real_function=is_real_function,
    )
    return tree, ctx
326
+
327
+
328
def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
    """Bind positional and keyword arguments to ``self.arg_metas``, filling in defaults.

    Args:
        self: The Func or Kernel being called; its ``arg_metas`` describe the parameters.
        is_func: If True, ``self.arg_metas`` is first replaced by its
            dataclass-expanded form (this mutates ``self``).
        args: Positional argument values.
        kwargs: Keyword argument values, matched by parameter name.

    Returns:
        One value per parameter, in parameter order.

    Raises:
        GsTaichiSyntaxError: On too many positional arguments, an unknown keyword,
            a keyword that duplicates a positional argument, or missing parameters.
    """
    if is_func:
        self.arg_metas = _kernel_impl_dataclass.expand_func_arguments(self.arg_metas)

    metas = self.arg_metas
    num_positional = len(args)
    bound: list[Any] = [meta.default for meta in metas]

    if num_positional > len(bound):
        got = ", ".join(map(str, args))
        expected = ", ".join(f"{meta.name} : {meta.annotation}" for meta in metas)
        lines = [f"Too many arguments. Expected ({expected}), got ({got})."]
        for idx, value in enumerate(args):
            meta_name = metas[idx].name if idx < len(metas) else "<out of arg metas>"
            lines.append(f" - {idx} arg meta: {meta_name} arg type: {type(value)}")
        lines.append(f"In function: {self.func}")
        raise GsTaichiSyntaxError("\n".join(lines))

    # Positional values overwrite the leading defaults.
    bound[:num_positional] = args

    names = [meta.name for meta in metas]
    for key, value in kwargs.items():
        if key not in names:
            raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")
        slot = names.index(key)
        if slot < num_positional:
            raise GsTaichiSyntaxError(f"Multiple values for argument '{key}'.")
        bound[slot] = value

    # A slot still holding the "no default" sentinel was never supplied.
    missing = []
    for meta, value in zip(metas, bound):
        if value is not inspect.Parameter.empty:
            continue
        if meta.annotation is inspect.Parameter.empty:
            missing.append(f"Parameter `{meta.name}` missing.")
        else:
            missing.append(f"Parameter `{meta.name} : {meta.annotation}` missing.")

    if missing:
        report = ["Error: missing parameters.", *missing, "", "Debug info follows.", "fused args:"]
        report.extend(f" {idx} {value}" for idx, value in enumerate(bound))
        report.append("arg metas:")
        report.extend(f" {idx} {meta}" for idx, meta in enumerate(metas))
        raise GsTaichiSyntaxError("\n".join(report))

    return tuple(bound)
385
+
386
+
387
+ class Func:
388
+ function_counter = 0
389
+
390
def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False) -> None:
    # Wrapped function plus a unique id taken from the class-wide counter (used in FunctionKey).
    self.func = _func
    self.func_id = Func.function_counter
    Func.function_counter += 1
    # instance_id -> func_body closure; filled by do_compile.
    self.compiled = {}
    self.classfunc = _classfunc
    self.pyfunc = _pyfunc
    self.is_real_function = is_real_function
    # Defaults that extract_arguments() may overwrite; they must exist before it runs.
    self.arg_metas: list[ArgMetadata] = []
    self.orig_arguments: list[ArgMetadata] = []
    self.return_type: tuple[Type, ...] | None = None
    self.extract_arguments()
    # Positions of parameters annotated as templates (the `template` type itself or an instance of it).
    self.template_slot_locations: list[int] = [
        slot
        for slot, meta in enumerate(self.arg_metas)
        if meta.annotation == template or isinstance(meta.annotation, template)
    ]
    self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
    # instance_id -> C++ function object (the |Function| class in C++); filled by do_compile.
    self.gstaichi_functions = {}
    self.has_print = False
409
+
410
def __call__(self: "Func", *args, **kwargs) -> Any:
    """Invoke the function.

    - Outside a kernel: only allowed for pyfuncs, which run as plain Python.
    - Real functions: compiled once per template instantiation (see do_compile)
      and invoked via a function call expression.
    - Otherwise: the function's AST is transformed directly into the current
      kernel's AST builder, and the transformation result is returned.

    Raises:
        GsTaichiSyntaxError: Python-scope call of a non-pyfunc, a real function in
            a kernel with autodiff enabled, or a declared return type without a
            return statement.
    """
    args = _process_args(self, is_func=True, args=args, kwargs=kwargs)

    if not impl.inside_kernel():
        if self.pyfunc:
            return self.func(*args)
        raise GsTaichiSyntaxError("GsTaichi functions cannot be called from Python-scope.")

    current_kernel = impl.get_runtime().current_kernel

    if self.is_real_function:
        if current_kernel.autodiff_mode != AutodiffMode.NONE:
            raise GsTaichiSyntaxError("Real function in gradient kernels unsupported.")
        instance_id, arg_features = self.mapper.lookup(args)
        key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
        # Compile each template instantiation only once.
        if key.instance_id not in self.compiled:
            self.do_compile(key=key, args=args, arg_features=arg_features)
        return self.func_call_rvalue(key=key, args=args)

    tree, ctx = _get_tree_and_ctx(
        self,
        is_kernel=False,
        args=args,
        ast_builder=current_kernel.ast_builder(),
        is_real_function=self.is_real_function,
    )
    struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
    tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
    result = transform_tree(tree, ctx)

    missing_return = self.return_type and ctx.returned != ReturnStatus.ReturnedValue
    if not self.is_real_function and missing_return:
        raise GsTaichiSyntaxError("Function has a return type but does not have a return statement")
    return result
443
+
444
def func_call_rvalue(self, key: FunctionKey, args: tuple[Any, ...]) -> Any:
    """Emit a call to the compiled real function for ``key`` and return its result.

    Only valid for real functions (asserted). Non-template arguments are converted
    according to their annotation and passed as an expression group:
    primitive types are cast with ``ops.cast``; ``RefType`` arguments become
    references to ``args[i].ptr``; ndarray arguments (which must be ``AnyArray``)
    are expanded via ``get_external_tensor_real_func_args``; anything else is
    passed through unchanged.

    Returns:
        None if there is no return type; the single value if there is exactly one
        return value; otherwise a tuple of per-element expressions/objects.

    Raises:
        GsTaichiTypeError: If an ndarray parameter receives a non-AnyArray value,
            or a return type is neither primitive, StructType nor MatrixType.
    """
    # Skip the template args, e.g., |self|
    assert self.is_real_function
    non_template_args = []
    dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    for i, kernel_arg in enumerate(self.arg_metas):
        anno = kernel_arg.annotation
        if not isinstance(anno, template):
            if id(anno) in primitive_types.type_ids:
                non_template_args.append(ops.cast(args[i], anno))
            elif isinstance(anno, primitive_types.RefType):
                non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
            elif isinstance(anno, ndarray_type.NdarrayType):
                if not isinstance(args[i], AnyArray):
                    raise GsTaichiTypeError(
                        f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
                    )
                # One ndarray argument may expand into several call arguments.
                non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
            else:
                non_template_args.append(args[i])
    non_template_args = impl.make_expr_group(non_template_args)
    compiling_callable = impl.get_runtime().compiling_callable
    assert compiling_callable is not None
    # Insert the call into whatever callable is currently being compiled.
    func_call = compiling_callable.ast_builder().insert_func_call(
        self.gstaichi_functions[key.instance_id], non_template_args, dbg_info
    )
    if self.return_type is None:
        return None
    func_call = Expr(func_call)
    ret = []

    # Extract the i-th return value from the call expression, by return type.
    for i, return_type in enumerate(self.return_type):
        if id(return_type) in primitive_types.type_ids:
            ret.append(
                Expr(
                    _ti_core.make_get_element_expr(
                        func_call.ptr, (i,), _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
                    )
                )
            )
        elif isinstance(return_type, (StructType, MatrixType)):
            ret.append(return_type.from_gstaichi_object(func_call, (i,)))
        else:
            raise GsTaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
    if len(ret) == 1:
        return ret[0]
    return tuple(ret)
491
+
492
def do_compile(self, key: FunctionKey, args: tuple[Any, ...], arg_features: tuple[Any, ...]) -> None:
    """Create the runtime function object for one template instantiation of this Func.

    The AST transformation itself is not run here: it is wrapped in ``func_body``
    and handed to ``set_function_body`` on the created function object (presumably
    invoked by the C++ side when the body is needed). ``func_body`` is also stored
    in ``self.compiled`` under ``key.instance_id``, which ``__call__`` uses to
    avoid compiling the same instantiation twice.

    Args:
        key: Identifies the function and template instantiation.
        args: Python-side argument values used to build the transformer context.
        arg_features: Template features for this instantiation, forwarded to the context.
    """
    tree, ctx = _get_tree_and_ctx(
        self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
    )
    fn = impl.get_runtime().prog.create_function(key)

    def func_body():
        old_callable = impl.get_runtime().compiling_callable
        impl.get_runtime()._compiling_callable = fn
        try:
            ctx.ast_builder = fn.ast_builder()
            transform_tree(tree, ctx)
        finally:
            # Restore even if the transformation raises; otherwise the runtime would keep
            # pointing at this function and later compilation would use the wrong builder.
            impl.get_runtime()._compiling_callable = old_callable

    self.gstaichi_functions[key.instance_id] = fn
    self.compiled[key.instance_id] = func_body
    self.gstaichi_functions[key.instance_id].set_function_body(func_body)
508
+
509
def extract_arguments(self) -> None:
    """Read return-type and parameter metadata from the signature of ``self.func``.

    A return annotation other than ``None`` is stored in ``self.return_type`` as a tuple.
    ``tuple[...]`` annotations are expanded into their element types. Every parameter is
    validated and then recorded in both ``self.arg_metas`` and ``self.orig_arguments``.

    Raises:
        GsTaichiSyntaxError: on an ``Ellipsis`` return type; on ``**kwargs``, ``*args``,
            keyword-only or positional-only parameters; on a missing annotation for a
            non-pyfunc real function; or on an unsupported annotation.
    """
    sig = inspect.signature(self.func)

    if sig.return_annotation not in (inspect.Signature.empty, None):
        ret_annotation = sig.return_annotation
        is_tuple_alias = (
            isinstance(ret_annotation, (types.GenericAlias, typing._GenericAlias))  # type: ignore
            and ret_annotation.__origin__ is tuple  # type: ignore
        )
        self.return_type = ret_annotation.__args__ if is_tuple_alias else ret_annotation  # type: ignore
        if self.return_type is None:
            return
        if not isinstance(self.return_type, (list, tuple)):
            self.return_type = (self.return_type,)
        if any(ret is Ellipsis for ret in self.return_type):
            raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")

    # Parameter kinds that are rejected outright, with their error messages.
    unsupported_kinds = {
        inspect.Parameter.VAR_KEYWORD: "GsTaichi functions do not support variable keyword parameters (i.e., **kwargs)",
        inspect.Parameter.VAR_POSITIONAL: "GsTaichi functions do not support variable positional parameters (i.e., *args)",
        inspect.Parameter.KEYWORD_ONLY: "GsTaichi functions do not support keyword parameters",
    }

    for i, (arg_name, param) in enumerate(sig.parameters.items()):
        message = unsupported_kinds.get(param.kind)
        if message is not None:
            raise GsTaichiSyntaxError(message)
        if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
            raise GsTaichiSyntaxError('GsTaichi functions only support "positional or keyword" parameters')

        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            if i == 0 and self.classfunc:
                # The unannotated first parameter of a class function is treated as a template.
                annotation = template()
            # TODO: pyfunc also need type annotation check when real function is enabled,
            # but that has to happen at runtime when we know which scope it's called from.
            elif not self.pyfunc and self.is_real_function:
                raise GsTaichiSyntaxError(
                    f"GsTaichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
                )
        elif not (
            isinstance(annotation, (ndarray_type.NdarrayType, MatrixType, StructType))
            or id(annotation) in primitive_types.type_ids
            or type(annotation) == gstaichi.types.annotations.Template
            or isinstance(annotation, template)
            or annotation == gstaichi.types.annotations.Template
            or isinstance(annotation, primitive_types.RefType)
            or (isinstance(annotation, type) and dataclasses.is_dataclass(annotation))
        ):
            raise GsTaichiSyntaxError(f"Invalid type annotation (argument {i}) of GsTaichi function: {annotation}")

        self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
        self.orig_arguments.append(ArgMetadata(annotation, param.name, param.default))
574
+
575
+
576
+ def _get_global_vars(_func: Callable) -> dict[str, Any]:
577
+ # Discussions: https://github.com/taichi-dev/gstaichi/issues/282
578
+ global_vars = _func.__globals__.copy()
579
+
580
+ freevar_names = _func.__code__.co_freevars
581
+ closure = _func.__closure__
582
+ if closure:
583
+ freevar_values = list(map(lambda x: x.cell_contents, closure))
584
+ for name, value in zip(freevar_names, freevar_values):
585
+ global_vars[name] = value
586
+
587
+ return global_vars
588
+
589
+
590
@dataclasses.dataclass
class SrcLlCacheObservations:
    """Flags recording which steps of the source-level (fast) kernel cache were reached.

    Every flag starts as ``False``. ``Kernel`` sets a flag to ``True`` when the corresponding
    step happens for that kernel.
    """

    # A cache key was computed from the kernel source and arguments.
    cache_key_generated: bool = dataclasses.field(default=False)
    # The computed cache key was validated.
    cache_validated: bool = dataclasses.field(default=False)
    # Compiled kernel data was loaded from the cache.
    cache_loaded: bool = dataclasses.field(default=False)
    # Freshly compiled kernel data was stored in the cache.
    cache_stored: bool = dataclasses.field(default=False)
596
+
597
+
598
@dataclasses.dataclass
class FeLlCacheObservations:
    """Records whether compiling a kernel through ``prog.compile_kernel`` hit its cache."""

    # Set to True by ``Kernel`` when the compile result reports a cache hit.
    cache_hit: bool = dataclasses.field(default=False)
601
+
602
+
603
class Kernel:
    """Python-side representation of one GsTaichi kernel in one autodiff mode.

    The kernel reads argument and return metadata from the wrapped function's signature. It
    materializes one ``KernelCxx`` per template-instantiation key (``materialize``), and on
    every call it marshals Python arguments into a launch context and launches the compiled
    kernel (``launch_kernel``).
    """

    # Incremented for every Kernel created. The per-instance value is embedded in the
    # materialized kernel name (see ``materialize``).
    counter = 0

    def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _classkernel=False) -> None:
        """Wrap ``_func`` as a kernel compiled in ``autodiff_mode``.

        Args:
            _func: The Python function decorated with ``@ti.kernel``.
            autodiff_mode: One of NONE, VALIDATION, FORWARD or REVERSE.
            _classkernel: True when ``_func`` was defined in a class body. Its unannotated
                first parameter is then treated as a template.
        """
        self.func = _func
        self.kernel_counter = Kernel.counter
        Kernel.counter += 1
        assert autodiff_mode in (
            AutodiffMode.NONE,
            AutodiffMode.VALIDATION,
            AutodiffMode.FORWARD,
            AutodiffMode.REVERSE,
        )
        self.autodiff_mode = autodiff_mode
        self.grad: "Kernel | None" = None
        self.arg_metas: list[ArgMetadata] = []
        self.return_type = None
        self.classkernel = _classkernel
        self.extract_arguments()
        # Positions of template arguments; these do not become runtime kernel arguments.
        self.template_slot_locations = [
            i
            for i, arg in enumerate(self.arg_metas)
            if arg.annotation == template or isinstance(arg.annotation, template)
        ]
        self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
        impl.get_runtime().kernels.append(self)
        self.reset()
        self.kernel_cpp = None
        # A materialized kernel is a KernelCxx object which may or may not have
        # been compiled. It generally has been converted at least as far as AST
        # and front-end IR, but not necessarily any further.
        self.materialized_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
        self.has_print = False
        self.gstaichi_callable: GsTaichiCallable | None = None
        self.visited_functions: set[FunctionSourceInfo] = set()
        self.kernel_function_info: FunctionSourceInfo | None = None
        self.compiled_kernel_data_by_key: dict[CompiledKernelKeyType, CompiledKernelData] = {}
        self._last_compiled_kernel_data: CompiledKernelData | None = None  # for dev/debug
        # Source-level cache key. It is (re)computed by materialize() and read by launch_kernel().
        # It is initialized here so the attribute always exists.
        self.fast_checksum = None

        self.src_ll_cache_observations: SrcLlCacheObservations = SrcLlCacheObservations()
        self.fe_ll_cache_observations: FeLlCacheObservations = FeLlCacheObservations()

    def ast_builder(self) -> ASTBuilder:
        """Return the AST builder of the KernelCxx most recently handed to the AST generator."""
        assert self.kernel_cpp is not None
        return self.kernel_cpp.ast_builder()

    def reset(self) -> None:
        """Re-bind to the current runtime and forget all materialized kernels."""
        self.runtime = impl.get_runtime()
        self.materialized_kernels = {}

    def extract_arguments(self) -> None:
        """Read return-type and parameter metadata from the signature of ``self.func``.

        A return annotation other than ``None`` is stored in ``self.return_type`` as a tuple.
        ``tuple[...]`` annotations are expanded into their element types. Each parameter is
        validated and appended to ``self.arg_metas``. A bare ``NdarrayType`` class annotation
        is replaced by an instance of it.

        Raises:
            GsTaichiSyntaxError: on an ``Ellipsis`` return type, unsupported parameter kinds,
                default values, missing annotations, or unsupported annotations.
        """
        sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect.Signature.empty, None):
            self.return_type = sig.return_annotation
            if (
                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))  # type: ignore
                and self.return_type.__origin__ is tuple
            ):
                self.return_type = self.return_type.__args__
            if not isinstance(self.return_type, (list, tuple)):
                self.return_type = (self.return_type,)
            for return_type in self.return_type:
                if return_type is Ellipsis:
                    raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
        params = dict(sig.parameters)
        arg_names = params.keys()
        for i, arg_name in enumerate(arg_names):
            param = params[arg_name]
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise GsTaichiSyntaxError(
                    "GsTaichi kernels do not support variable keyword parameters (i.e., **kwargs)"
                )
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise GsTaichiSyntaxError(
                    "GsTaichi kernels do not support variable positional parameters (i.e., *args)"
                )
            if param.default is not inspect.Parameter.empty:
                raise GsTaichiSyntaxError("GsTaichi kernels do not support default values for arguments")
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                raise GsTaichiSyntaxError("GsTaichi kernels do not support keyword parameters")
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                raise GsTaichiSyntaxError('GsTaichi kernels only support "positional or keyword" parameters')
            annotation = param.annotation
            if param.annotation is inspect.Parameter.empty:
                if i == 0 and self.classkernel:  # The |self| parameter
                    annotation = template()
                else:
                    raise GsTaichiSyntaxError("GsTaichi kernels parameters must be type annotated")
            else:
                if isinstance(
                    annotation,
                    (
                        template,
                        ndarray_type.NdarrayType,
                        texture_type.TextureType,
                        texture_type.RWTextureType,
                    ),
                ):
                    pass
                elif annotation is ndarray_type.NdarrayType:
                    # convert from ti.types.NDArray into ti.types.NDArray()
                    annotation = annotation()
                elif id(annotation) in primitive_types.type_ids:
                    pass
                elif isinstance(annotation, sparse_matrix_builder):
                    pass
                elif isinstance(annotation, MatrixType):
                    pass
                elif isinstance(annotation, StructType):
                    pass
                elif annotation == template:
                    pass
                elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
                    pass
                else:
                    raise GsTaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
            self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))

    def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features=None):
        """Make sure ``self.materialized_kernels[key]`` exists.

        When ``key`` is None, ``(self.func, 0, self.autodiff_mode)`` is used. If the source-level
        cache is enabled and the callable is pure, a fast cache key is computed from the kernel
        source and ``args``. On a validated key, previously compiled data is loaded into
        ``self.compiled_kernel_data_by_key[key]``. The AST generator defined below is passed to
        ``prog.create_kernel``, and the result is stored under ``key``.
        """
        if key is None:
            key = (self.func, 0, self.autodiff_mode)
        self.runtime.materialize()
        self.fast_checksum = None

        if key in self.materialized_kernels:
            return

        if self.runtime.src_ll_cache and self.gstaichi_callable and self.gstaichi_callable.is_pure:
            kernel_source_info, _src = get_source_info_and_src(self.func)
            self.fast_checksum = src_hasher.create_cache_key(kernel_source_info, args)
            if self.fast_checksum:
                self.src_ll_cache_observations.cache_key_generated = True
            if self.fast_checksum and src_hasher.validate_cache_key(self.fast_checksum):
                self.src_ll_cache_observations.cache_validated = True
                prog = impl.get_runtime().prog
                self.compiled_kernel_data_by_key[key] = prog.load_fast_cache(
                    self.fast_checksum,
                    self.func.__name__,
                    prog.config(),
                    prog.get_device_caps(),
                )
                if self.compiled_kernel_data_by_key[key]:
                    self.src_ll_cache_observations.cache_loaded = True
        elif self.gstaichi_callable and not self.gstaichi_callable.is_pure and self.runtime.print_non_pure:
            # The bit in caps should not be modified without updating corresponding test
            # freetext can be freely modified.
            # As for why we are using `print` rather than eg logger.info, it is because
            # this is only printed when ti.init(print_non_pure=..) is True. And it is
            # confusing to set that to True, and see nothing printed.
            print(f"[NOT_PURE] Debug information: not pure: {self.func.__name__}")

        kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}"
        _logging.trace(f"Materializing kernel {kernel_name} in {self.autodiff_mode}...")

        tree, ctx = _get_tree_and_ctx(
            self,
            args=args,
            excluded_parameters=self.template_slot_locations,
            arg_features=arg_features,
            current_kernel=self,
        )

        if self.autodiff_mode != AutodiffMode.NONE:
            KernelSimplicityASTChecker(self.func).visit(tree)

        # Do not change the name of 'gstaichi_ast_generator'
        # The warning system needs this identifier to remove unnecessary messages
        def gstaichi_ast_generator(kernel_cxx: KernelCxx):
            nonlocal tree
            if self.runtime.inside_kernel:
                raise GsTaichiSyntaxError(
                    "Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
                    "Please check if you have direct/indirect invocation of kernels within kernels. "
                    "Note that some methods provided by the GsTaichi standard library may invoke kernels, "
                    "and please move their invocations to Python-scope."
                )
            self.kernel_cpp = kernel_cxx
            self.runtime.inside_kernel = True
            self.runtime._current_kernel = self
            assert self.runtime._compiling_callable is None
            self.runtime._compiling_callable = kernel_cxx
            try:
                ctx.ast_builder = kernel_cxx.ast_builder()

                def ast_to_dict(node: ast.AST | list | primitive_types._python_primitive_types):
                    # Recursively convert an AST into JSON-serializable dicts/lists.
                    if isinstance(node, ast.AST):
                        fields = {k: ast_to_dict(v) for k, v in ast.iter_fields(node)}
                        return {
                            "type": node.__class__.__name__,
                            "fields": fields,
                            "lineno": getattr(node, "lineno", None),
                            "col_offset": getattr(node, "col_offset", None),
                        }
                    if isinstance(node, list):
                        return [ast_to_dict(x) for x in node]
                    return node  # Basic types (str, int, None, etc.)

                # Debug aid: dump the kernel AST (text and JSON) plus the dump timings.
                if os.environ.get("TI_DUMP_AST", "") == "1":
                    target_dir = pathlib.Path("/tmp/ast")
                    target_dir.mkdir(parents=True, exist_ok=True)

                    start = time.time()
                    ast_str = ast.dump(tree, indent=2)
                    output_file = target_dir / f"{kernel_name}_ast.txt"
                    output_file.write_text(ast_str)
                    elapsed_txt = time.time() - start

                    start = time.time()
                    json_str = json.dumps(ast_to_dict(tree), indent=2)
                    output_file = target_dir / f"{kernel_name}_ast.json"
                    output_file.write_text(json_str)
                    elapsed_json = time.time() - start

                    output_file = target_dir / f"{kernel_name}_gen_time.json"
                    output_file.write_text(
                        json.dumps({"elapsed_txt": elapsed_txt, "elapsed_json": elapsed_json}, indent=2)
                    )
                struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
                tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
                # When compiled data was already loaded from the fast cache, only the function
                # definition needs to be parsed.
                ctx.only_parse_function_def = self.compiled_kernel_data_by_key.get(key) is not None
                transform_tree(tree, ctx)
                if not ctx.is_real_function:
                    if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
                        raise GsTaichiSyntaxError("Kernel has a return type but does not have a return statement")
            finally:
                self.runtime.inside_kernel = False
                self.runtime._current_kernel = None
                self.runtime._compiling_callable = None

        gstaichi_kernel = impl.get_runtime().prog.create_kernel(gstaichi_ast_generator, kernel_name, self.autodiff_mode)
        assert key not in self.materialized_kernels
        self.materialized_kernels[key] = gstaichi_kernel

    def launch_kernel(self, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args) -> Any:
        """Marshal ``args`` into a launch context, compile if needed, launch, and collect results.

        Args:
            t_kernel: The materialized kernel to launch.
            compiled_kernel_data: Previously compiled data. When falsy, the kernel is compiled
                through ``prog.compile_kernel``; the result may also be stored in the source-level
                cache.
            *args: One Python value per entry of ``self.arg_metas``.

        Returns:
            None without a return type, the single value for one return type, otherwise a
            list with one value per return type.

        Raises:
            GsTaichiRuntimeError / GsTaichiRuntimeTypeError / ValueError: on arguments that
                cannot be passed to the kernel, or when too many argument slots are needed.
        """
        assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided"

        # Temporary contiguous copies that must not be garbage-collected before the launch.
        tmps = []
        # Run after the launch to copy results from temporaries back into caller-owned arrays.
        callbacks = []

        actual_argument_slot = 0
        launch_ctx = t_kernel.make_launch_context()
        max_arg_num = 512
        exceed_max_arg_num = False

        def set_arg_ndarray(indices: tuple[int, ...], v: gstaichi.lang._ndarray.Ndarray) -> None:
            v_primal = v.arr
            v_grad = v.grad.arr if v.grad else None
            if v_grad is None:
                launch_ctx.set_arg_ndarray(indices, v_primal)  # type: ignore , solvable probably, just not today
            else:
                launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad)  # type: ignore

        def set_arg_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
            launch_ctx.set_arg_texture(indices, v.tex)

        def set_arg_rw_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
            launch_ctx.set_arg_rw_texture(indices, v.tex)

        def set_arg_ext_array(indices: tuple[int, ...], v: Any, needed: ndarray_type.NdarrayType) -> None:
            # v is things like torch Tensor and numpy array
            # Not adding type for this, since adds additional dependencies
            #
            # Element shapes are already specialized in GsTaichi codegen.
            # The shape information for element dims are no longer needed.
            # Therefore we strip the element shapes from the shape vector,
            # so that it only holds "real" array shapes.
            is_soa = needed.layout == Layout.SOA
            array_shape = v.shape
            if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
                warnings.warn("Ndarray index might be out of int32 boundary but int64 indexing is not supported yet.")
            if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
                element_dim = 0
            else:
                element_dim = needed.dtype.ndim
                # Only sliced here (element_dim taken from a non-primitive dtype) so that a
                # scalar dtype never hits the degenerate slice v.shape[:-0] == ().
                array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
            if isinstance(v, np.ndarray):
                # numpy
                if v.flags.c_contiguous:
                    launch_ctx.set_arg_external_array_with_shape(indices, int(v.ctypes.data), v.nbytes, array_shape, 0)
                elif v.flags.f_contiguous:
                    # TODO: A better way that avoids copying is saving strides info.
                    tmp = np.ascontiguousarray(v)
                    # Purpose: DO NOT GC |tmp|!
                    tmps.append(tmp)

                    def callback(original, updated):
                        np.copyto(original, np.asfortranarray(updated))

                    callbacks.append(functools.partial(callback, v, tmp))
                    launch_ctx.set_arg_external_array_with_shape(
                        indices, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
                    )
                else:
                    raise ValueError(
                        "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
                        "before passing it into gstaichi kernel."
                    )
            elif has_pytorch():
                import torch  # pylint: disable=C0415

                if isinstance(v, torch.Tensor):
                    if not v.is_contiguous():
                        raise ValueError(
                            "Non contiguous tensors are not supported, please call tensor.contiguous() before "
                            "passing it into gstaichi kernel."
                        )
                    gstaichi_arch = self.runtime.prog.config().arch

                    def get_call_back(u, v):
                        def call_back():
                            u.copy_(v)

                        return call_back

                    # FIXME: only allocate when launching grad kernel
                    if v.requires_grad and v.grad is None:
                        v.grad = torch.zeros_like(v)

                    if v.requires_grad:
                        if not isinstance(v.grad, torch.Tensor):
                            raise ValueError(
                                f"Expecting torch.Tensor for gradient tensor, but getting {v.grad.__class__.__name__} instead"
                            )
                        if not v.grad.is_contiguous():
                            raise ValueError(
                                "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into gstaichi kernel."
                            )

                    tmp = v
                    if (str(v.device) != "cpu") and not (
                        str(v.device).startswith("cuda") and gstaichi_arch == _ti_core.Arch.cuda
                    ):
                        # Getting a torch CUDA tensor on GsTaichi non-cuda arch:
                        # We just replace it with a CPU tensor and by the end of kernel execution we'll use the
                        # callback to copy the values back to the original CUDA tensor.
                        host_v = v.to(device="cpu", copy=True)
                        tmp = host_v
                        callbacks.append(get_call_back(v, host_v))

                    # NOTE(review): when |tmp| is a host copy, the gradient pointer below still refers
                    # to v.grad on its original device; confirm this is intended on non-CUDA archs.
                    launch_ctx.set_arg_external_array_with_shape(
                        indices,
                        int(tmp.data_ptr()),
                        tmp.element_size() * tmp.nelement(),
                        array_shape,
                        int(v.grad.data_ptr()) if v.grad is not None else 0,
                    )
                else:
                    raise GsTaichiRuntimeTypeError(
                        f"Argument of type {type(v)} cannot be converted into required type {needed}"
                    )
            else:
                raise GsTaichiRuntimeTypeError(f"Argument of type {type(v)} cannot be converted into required type {needed}")

        def set_arg_matrix(indices: tuple[int, ...], v, needed) -> None:
            # Flatten |v| element-wise (casting to the matrix dtype's Python type), rebuild it
            # as |needed|, and let the matrix type write itself into the launch context.
            def cast_float(x: float | np.floating | np.integer | int) -> float:
                if not isinstance(x, (int, float, np.integer, np.floating)):
                    raise GsTaichiRuntimeTypeError(
                        f"Argument {type(x)} cannot be converted into required type {needed.dtype}"
                    )
                return float(x)

            def cast_int(x: int | np.integer) -> int:
                if not isinstance(x, (int, np.integer)):
                    raise GsTaichiRuntimeTypeError(
                        f"Argument {type(x)} cannot be converted into required type {needed.dtype}"
                    )
                return int(x)

            cast_func = None
            if needed.dtype in primitive_types.real_types:
                cast_func = cast_float
            elif needed.dtype in primitive_types.integer_types:
                cast_func = cast_int
            else:
                raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")

            if needed.ndim == 2:
                v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
            else:
                v = [cast_func(v[i]) for i in range(needed.n)]
            v = needed(*v)
            needed.set_kernel_struct_args(v, launch_ctx, indices)

        def set_arg_sparse_matrix_builder(indices: tuple[int, ...], v) -> None:
            # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
            launch_ctx.set_arg_uint(indices, v._get_ndarray_addr())

        def recursive_set_args(needed_arg_type: Type, provided_arg_type: Type, v: Any, indices: tuple[int, ...]) -> int:
            """
            Returns the number of kernel args set
            e.g. templates don't set kernel args, so returns 0
            a single ndarray is 1 kernel arg, so returns 1
            a struct of 3 ndarrays would set 3 kernel args, so return 3
            note: len(indices) > 1 only happens with argpack (which we are removing support for)
            Once |max_arg_num| slots are used, nothing more is set and exceed_max_arg_num is flagged.
            """
            nonlocal actual_argument_slot, exceed_max_arg_num
            if actual_argument_slot >= max_arg_num:
                exceed_max_arg_num = True
                return 0
            actual_argument_slot += 1
            # Note: do not use sth like "needed == f32". That would be slow.
            if id(needed_arg_type) in primitive_types.real_type_ids:
                if not isinstance(v, (float, int, np.floating, np.integer)):
                    raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                launch_ctx.set_arg_float(indices, float(v))
                return 1
            if id(needed_arg_type) in primitive_types.integer_type_ids:
                if not isinstance(v, (int, np.integer)):
                    raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                if is_signed(cook_dtype(needed_arg_type)):
                    launch_ctx.set_arg_int(indices, int(v))
                else:
                    launch_ctx.set_arg_uint(indices, int(v))
                return 1
            if isinstance(needed_arg_type, sparse_matrix_builder):
                set_arg_sparse_matrix_builder(indices, v)
                return 1
            if dataclasses.is_dataclass(needed_arg_type):
                if provided_arg_type != needed_arg_type:
                    raise GsTaichiRuntimeError(f"needed {needed_arg_type} != provided {provided_arg_type}")
                # Each dataclass field is flattened into its own consecutive kernel argument(s).
                idx = 0
                for field in dataclasses.fields(needed_arg_type):
                    assert not isinstance(field.type, str)
                    field_value = getattr(v, field.name)
                    idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
                return idx
            if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, gstaichi.lang._ndarray.Ndarray):
                set_arg_ndarray(indices, v)
                return 1
            if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, gstaichi.lang._texture.Texture):
                set_arg_texture(indices, v)
                return 1
            if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
                v, gstaichi.lang._texture.Texture
            ):
                set_arg_rw_texture(indices, v)
                return 1
            if isinstance(needed_arg_type, ndarray_type.NdarrayType):
                set_arg_ext_array(indices, v, needed_arg_type)
                return 1
            if isinstance(needed_arg_type, MatrixType):
                set_arg_matrix(indices, v, needed_arg_type)
                return 1
            if isinstance(needed_arg_type, StructType):
                # Unclear how to make the following pass typing checks
                # StructType implements __instancecheck__, which should be a classmethod, but
                # is currently an instance method
                # TODO: look into this more deeply at some point
                if not isinstance(v, needed_arg_type):  # type: ignore
                    raise GsTaichiRuntimeTypeError(
                        f"Argument {provided_arg_type} cannot be converted into required type {needed_arg_type}"
                    )
                needed_arg_type.set_kernel_struct_args(v, launch_ctx, indices)
                return 1
            if needed_arg_type == template or isinstance(needed_arg_type, template):
                return 0
            raise ValueError(f"Argument type mismatch. Expecting {needed_arg_type}, got {type(v)}.")

        # Template arguments occupy no runtime slot, so the runtime index is the running
        # position minus the number of templates seen so far.
        template_num = 0
        i_out = 0
        for i_in, val in enumerate(args):
            needed_ = self.arg_metas[i_in].annotation
            if needed_ == template or isinstance(needed_, template):
                template_num += 1
                i_out += 1
                continue
            i_out += recursive_set_args(needed_, type(val), val, (i_out - template_num,))

        if exceed_max_arg_num:
            raise GsTaichiRuntimeError(
                f"The number of elements in kernel arguments is too big! Do not exceed {max_arg_num} on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
            )

        try:
            prog = impl.get_runtime().prog
            if not compiled_kernel_data:
                compile_result: CompileResult = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
                compiled_kernel_data = compile_result.compiled_kernel_data
                if compile_result.cache_hit:
                    self.fe_ll_cache_observations.cache_hit = True
                # Freshly compiled: populate the source-level cache when a key was generated.
                if self.fast_checksum:
                    src_hasher.store(self.fast_checksum, self.visited_functions)
                    prog.store_fast_cache(
                        self.fast_checksum,
                        self.kernel_cpp,
                        prog.config(),
                        prog.get_device_caps(),
                        compiled_kernel_data,
                    )
                    self.src_ll_cache_observations.cache_stored = True
            self._last_compiled_kernel_data = compiled_kernel_data
            prog.launch_kernel(compiled_kernel_data, launch_ctx)
        except Exception as e:
            e = handle_exception_from_cpp(e)
            if impl.get_runtime().print_full_traceback:
                raise e
            raise e from None

        ret = None
        ret_dt = self.return_type
        has_ret = ret_dt is not None

        if has_ret or self.has_print:
            runtime_ops.sync()

        if has_ret:
            ret = []
            for i, ret_type in enumerate(ret_dt):
                ret.append(self.construct_kernel_ret(launch_ctx, ret_type, (i,)))
            if len(ret_dt) == 1:
                ret = ret[0]
        # Copy results held in temporaries back into the caller's arrays.
        if callbacks:
            for c in callbacks:
                c()

        return ret

    def construct_kernel_ret(self, launch_ctx: KernelLaunchContext, ret_type: Any, index: tuple[int, ...] = ()):
        """Read the return value at ``index`` from ``launch_ctx`` as ``ret_type``.

        Raises:
            GsTaichiRuntimeTypeError: if ``ret_type`` is not a compound, integer or real type.
        """
        if isinstance(ret_type, CompoundType):
            return ret_type.from_kernel_struct_ret(launch_ctx, index)
        if ret_type in primitive_types.integer_types:
            if is_signed(cook_dtype(ret_type)):
                return launch_ctx.get_struct_ret_int(index)
            return launch_ctx.get_struct_ret_uint(index)
        if ret_type in primitive_types.real_types:
            return launch_ctx.get_struct_ret_float(index)
        raise GsTaichiRuntimeTypeError(f"Invalid return type on index={index}")

    def ensure_compiled(self, *args: Any) -> tuple[Callable, int, AutodiffMode]:
        """Materialize the kernel instance matching ``args`` and return its key.

        The key is ``(self.func, instance_id, self.autodiff_mode)``, where ``instance_id``
        comes from the template mapper. Exceptions raised by the lookup are re-raised as the
        same type with the kernel function prepended to the message.
        """
        try:
            instance_id, arg_features = self.mapper.lookup(args)
        except Exception as e:
            raise type(e)(f"exception while trying to ensure compiled {self.func}:\n{e}") from e
        key = (self.func, instance_id, self.autodiff_mode)
        self.materialize(key=key, args=args, arg_features=arg_features)
        return key

    # For small kernels (< 3us), the performance can be pretty sensitive to overhead in __call__
    # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU)
    @_shell_pop_print
    def __call__(self, *args, **kwargs) -> Any:
        """Run the kernel on ``args``/``kwargs``, registering it with active autodiff managers."""
        args = _process_args(self, is_func=False, args=args, kwargs=kwargs)

        # Transform the primal kernel to forward mode grad kernel
        # then recover to primal when exiting the forward mode manager
        if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
            # TODO: if we would like to compute 2nd-order derivatives by forward-on-reverse in a nested context manager fashion,
            # i.e., a `Tape` nested in the `FwdMode`, we can transform the kernels with `mode_original == AutodiffMode.REVERSE` only,
            # to avoid duplicate computation for 1st-order derivatives
            self.runtime.fwd_mode_manager.insert(self)

        # Both the class kernels and the plain-function kernels are unified now.
        # In both cases, |self.grad| is another Kernel instance that computes the
        # gradient. For class kernels, args[0] is always the kernel owner.

        # No need to capture grad kernels because they are already bound with their primal kernels
        if (
            self.autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION)
            and self.runtime.target_tape
            and not self.runtime.grad_replaced
        ):
            self.runtime.target_tape.insert(self, args)

        if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg().opt_level == 0:
            _logging.warn("""opt_level = 1 is enforced to enable gradient computation.""")
            impl.current_cfg().opt_level = 1
        key = self.ensure_compiled(*args)
        kernel_cpp = self.materialized_kernels[key]
        compiled_kernel_data = self.compiled_kernel_data_by_key.get(key, None)
        return self.launch_kernel(kernel_cpp, compiled_kernel_data, *args)
1178
+
1179
+
1180
+ # For a GsTaichi class definition like below:
1181
+ #
1182
+ # @ti.data_oriented
1183
+ # class X:
1184
+ # @ti.kernel
1185
+ # def foo(self):
1186
+ # ...
1187
+ #
1188
+ # When ti.kernel runs, the stackframe's |code_context| of Python 3.8(+) is
1189
+ # different from that of Python 3.7 and below. In 3.8+, it is 'class X:',
1190
+ # whereas in <=3.7, it is '@ti.data_oriented'. More interestingly, if the class
1191
+ # inherits, i.e. class X(object):, then in both versions, |code_context| is
1192
+ # 'class X(object):'...
1193
+ _KERNEL_CLASS_STACKFRAME_STMT_RES = [
1194
+ re.compile(r"@(\w+\.)?data_oriented"),
1195
+ re.compile(r"class "),
1196
+ ]
1197
+
1198
+
1199
def _inside_class(level_of_class_stackframe: int) -> bool:
    """Heuristically decide whether a kernel is being defined inside a class body.

    Looks at the frame ``level_of_class_stackframe`` levels above this function. The first
    line of its current source context is stripped and matched against
    ``_KERNEL_CLASS_STACKFRAME_STMT_RES``.

    Args:
        level_of_class_stackframe: Depth passed to ``sys._getframe``, counted from this
            function's own frame.

    Returns:
        True if a pattern matches. False if none matches, no source context is available,
        or the inspection fails (for example, the stack is not deep enough).
    """
    try:
        maybe_class_frame = sys._getframe(level_of_class_stackframe)
        statement_list = inspect.getframeinfo(maybe_class_frame).code_context
        if not statement_list:
            return False
        first_statement = statement_list[0].strip()
        return any(pat.match(first_statement) for pat in _KERNEL_CLASS_STACKFRAME_STMT_RES)
    except Exception:
        # Best-effort detection: any failure means "not inside a class". Unlike the former
        # bare `except:`, KeyboardInterrupt/SystemExit are no longer swallowed.
        return False
1212
+
1213
+
1214
def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False) -> GsTaichiCallable:
    """Build the primal/adjoint ``Kernel`` pair for ``_func`` and wrap it in a ``GsTaichiCallable``.

    Args:
        _func: The Python function being decorated.
        level_of_class_stackframe: Stack depth, counted from this function's own frame, of the
            frame whose current source line ``_inside_class`` inspects to decide whether
            ``_func`` is defined inside a class.
        verbose: When True, print whether ``_func`` was detected as a class kernel.

    Returns:
        The wrapper callable. Its ``_primal`` and ``_adjoint`` attributes hold the two kernels.
    """
    # Can decorators determine if a function is being defined inside a class?
    # https://stackoverflow.com/a/8793684/12003165
    # _inside_class adds one frame of its own, hence the +1. This call must stay directly in
    # this function's body for the depth to be right.
    is_classkernel = _inside_class(level_of_class_stackframe + 1)
    if verbose:
        print(f"kernel={_func.__name__} is_classkernel={is_classkernel}")

    primal = Kernel(_func, autodiff_mode=AutodiffMode.NONE, _classkernel=is_classkernel)
    adjoint = Kernel(_func, autodiff_mode=AutodiffMode.REVERSE, _classkernel=is_classkernel)
    # The tape reaches the gradient kernel through primal.grad.
    primal.grad = adjoint

    wrapped: GsTaichiCallable
    if not is_classkernel:

        @functools.wraps(_func)
        def wrapped_func(*args, **kwargs):
            try:
                return primal(*args, **kwargs)
            except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
                if impl.get_runtime().print_full_traceback:
                    raise e
                # Re-raise as the same exception type with a leading newline, hiding the chain.
                raise type(e)("\n" + str(e)) from None

        wrapped = GsTaichiCallable(_func, wrapped_func)
        wrapped.grad = adjoint
    else:
        # Class kernels receive their real primal/adjoint callables from
        # _BoundedDifferentiableMethod when accessed through an instance, because the owner
        # is only known then (see also data_oriented). Reaching this wrapper therefore means
        # the class was not decorated with @ti.data_oriented.
        @functools.wraps(_func)
        def wrapped_classkernel(*args, **kwargs):
            owner_cls = type(args[0])
            assert not hasattr(owner_cls, "_data_oriented")
            raise GsTaichiSyntaxError(f"Please decorate class {owner_cls.__name__} with @ti.data_oriented")

        wrapped = GsTaichiCallable(_func, wrapped_classkernel)

    wrapped._is_wrapped_kernel = True
    wrapped._is_classkernel = is_classkernel
    wrapped._primal = primal
    wrapped._adjoint = adjoint
    primal.gstaichi_callable = wrapped
    return wrapped
1264
+
1265
+
1266
def kernel(fn: Callable):
    """Mark a Python function as a GsTaichi kernel.

    The decorated function is written in Python and JIT-compiled by GsTaichi
    into native CPU/GPU instructions (e.g. a series of CUDA kernels). Its
    top-level ``for`` loops are parallelized automatically and distributed to
    either a CPU thread pool or massively parallel GPUs.

    The matching gradient kernel is generated automatically by the AutoDiff
    system and is reachable through the ``grad`` attribute of the result.

    See also https://docs.taichi-lang.org/docs/syntax#kernel.

    Args:
        fn (Callable): the Python function to be decorated.

    Returns:
        Callable: The decorated function.

    Example::

        >>> x = ti.field(ti.i32, shape=(4, 8))
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     # Assigns all the elements of `x` in parallel.
        >>>     for i, j in x:
        >>>         x[i, j] = i + j
    """
    # ``_kernel_impl`` walks up the call stack by a fixed number of frames to
    # detect whether ``fn`` is being defined inside a class body. It is therefore
    # called directly from here; adding a wrapper frame would presumably shift
    # the depth it relies on.
    class_frame_depth = 3
    return _kernel_impl(fn, level_of_class_stackframe=class_frame_depth)
1295
+
1296
+
1297
class _BoundedDifferentiableMethod:
    """A class kernel (primal + adjoint) bound to the instance it was accessed through.

    Instances are created by the ``__getattribute__`` hook that ``data_oriented``
    installs, whenever a class kernel is looked up on an object. Calling the
    object launches the primal kernel, and :meth:`grad` launches the adjoint
    kernel. For kernels declared as ``staticmethod``, the owning instance is NOT
    passed as the first argument, for either the primal or the adjoint kernel.
    """

    def __init__(self, kernel_owner: Any, wrapped_kernel_func: GsTaichiCallable | BoundGsTaichiCallable):
        """Bind ``wrapped_kernel_func``'s primal/adjoint kernels to ``kernel_owner``.

        Raises:
            GsTaichiSyntaxError: if the owner's class was not decorated with
                ``@ti.data_oriented``.
        """
        clsobj = type(kernel_owner)
        if not getattr(clsobj, "_data_oriented", False):
            raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
        self._kernel_owner = kernel_owner
        self._primal = wrapped_kernel_func._primal
        self._adjoint = wrapped_kernel_func._adjoint
        # Snapshot of the flag set by the data_oriented hook right before construction.
        self._is_staticmethod = wrapped_kernel_func._is_staticmethod
        # Filled in by the data_oriented hook after construction.
        self.__name__: str | None = None

    def __call__(self, *args, **kwargs):
        """Launch the primal kernel, prepending the owner unless it is a staticmethod."""
        try:
            assert self._primal is not None
            if self._is_staticmethod:
                return self._primal(*args, **kwargs)
            return self._primal(self._kernel_owner, *args, **kwargs)

        except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
            if impl.get_runtime().print_full_traceback:
                raise e
            # Re-raise without the internal call chain to keep user tracebacks short.
            raise type(e)("\n" + str(e)) from None

    def grad(self, *args, **kwargs) -> Any:
        """Launch the adjoint (gradient) kernel with the same binding rules as ``__call__``.

        Returns:
            Whatever the adjoint kernel returns.
        """
        assert self._adjoint is not None
        # Must mirror __call__: a static kernel takes no owner argument, so
        # passing one would shift every user argument by one position.
        if self._is_staticmethod:
            return self._adjoint(*args, **kwargs)
        return self._adjoint(self._kernel_owner, *args, **kwargs)
1323
+
1324
+
1325
def data_oriented(cls):
    """Marks a class as GsTaichi compatible.

    To allow for modularized code, GsTaichi provides this decorator so that
    GsTaichi kernels can be defined inside a class.

    The decorator replaces ``cls.__getattribute__`` with a hook and sets
    ``cls._data_oriented = True``. The hook behaves as follows:

    - When an attribute resolves to a class kernel, it returns a callable bound
      to the accessing instance, so ``obj.kernel(...)`` and
      ``obj.kernel.grad(...)`` work.
    - Kernels declared as ``staticmethod`` are invoked without the instance.
    - A kernel used as a ``property`` getter is launched on attribute access.

    See also https://docs.taichi-lang.org/docs/odop

    Example::

        >>> @ti.data_oriented
        >>> class TiArray:
        >>>     def __init__(self, n):
        >>>         self.x = ti.field(ti.f32, shape=n)
        >>>
        >>>     @ti.kernel
        >>>     def inc(self):
        >>>         for i in self.x:
        >>>             self.x[i] += 1.0
        >>>
        >>> a = TiArray(32)
        >>> a.inc()

    Args:
        cls (Class): the class to be decorated

    Returns:
        The decorated class.
    """

    def _getattr(self, item):
        # Find the raw (un-bound) class attribute along the *instance's* MRO.
        # Looking only in cls.__dict__ misclassifies properties and staticmethods
        # that are inherited from a base class or overridden in a subclass. In
        # that case a static kernel would be called with the instance as an extra
        # first argument.
        method = None
        for klass in type(self).__mro__:
            namespace = klass.__dict__
            if item in namespace:
                method = namespace[item]
                break

        # A property without a getter is left to the default lookup. That lookup
        # raises the proper AttributeError, instead of this hook calling None.
        is_property = method.__class__ == property and method.fget is not None
        is_staticmethod = method.__class__ == staticmethod

        if is_property:
            x = method.fget
        else:
            x = super(cls, self).__getattribute__(item)

        if hasattr(x, "_is_wrapped_kernel"):
            # Unwrap bound methods to reach the GsTaichiCallable itself.
            wrapped = x.__func__ if inspect.ismethod(x) else x
            assert isinstance(wrapped, (BoundGsTaichiCallable, GsTaichiCallable))
            wrapped._is_staticmethod = is_staticmethod
            if wrapped._is_classkernel:
                ret = _BoundedDifferentiableMethod(self, wrapped)
                ret.__name__ = wrapped.__name__  # type: ignore
                # A kernel used as a property getter runs on attribute access.
                return ret() if is_property else ret

        return x(self) if is_property else x

    cls.__getattribute__ = _getattr
    cls._data_oriented = True

    return cls
1384
+
1385
+
1386
# Public names re-exported by ``from <module> import *``.
__all__ = [
    "data_oriented",
    "func",
    "kernel",
    "pyfunc",
    "real_func",
]