gstaichi 2.1.1rc3__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1341 @@
1
+ import ast
2
+ import dataclasses
3
+ import functools
4
+ import inspect
5
+ import json
6
+ import operator
7
+ import os
8
+ import pathlib
9
+ import re
10
+ import sys
11
+ import textwrap
12
+ import time
13
+ import types
14
+ import typing
15
+ import warnings
16
+ from typing import Any, Callable, Type
17
+
18
+ import numpy as np
19
+
20
+ import gstaichi.lang
21
+ import gstaichi.lang._ndarray
22
+ import gstaichi.lang._texture
23
+ import gstaichi.types.annotations
24
+ from gstaichi import _logging
25
+ from gstaichi._lib import core as _ti_core
26
+ from gstaichi._lib.core.gstaichi_python import (
27
+ ASTBuilder,
28
+ CompiledKernelData,
29
+ FunctionKey,
30
+ KernelCxx,
31
+ KernelLaunchContext,
32
+ )
33
+ from gstaichi.lang import _kernel_impl_dataclass, impl, ops, runtime_ops
34
+ from gstaichi.lang._fast_caching import src_hasher
35
+ from gstaichi.lang._template_mapper import TemplateMapper
36
+ from gstaichi.lang._wrap_inspect import FunctionSourceInfo, get_source_info_and_src
37
+ from gstaichi.lang.any_array import AnyArray
38
+ from gstaichi.lang.ast import (
39
+ ASTTransformerContext,
40
+ KernelSimplicityASTChecker,
41
+ transform_tree,
42
+ )
43
+ from gstaichi.lang.ast.ast_transformer_utils import ReturnStatus
44
+ from gstaichi.lang.exception import (
45
+ GsTaichiCompilationError,
46
+ GsTaichiRuntimeError,
47
+ GsTaichiRuntimeTypeError,
48
+ GsTaichiSyntaxError,
49
+ GsTaichiTypeError,
50
+ handle_exception_from_cpp,
51
+ )
52
+ from gstaichi.lang.expr import Expr
53
+ from gstaichi.lang.kernel_arguments import ArgMetadata
54
+ from gstaichi.lang.matrix import MatrixType
55
+ from gstaichi.lang.shell import _shell_pop_print
56
+ from gstaichi.lang.struct import StructType
57
+ from gstaichi.lang.util import cook_dtype, has_pytorch
58
+ from gstaichi.types import (
59
+ ndarray_type,
60
+ primitive_types,
61
+ sparse_matrix_builder,
62
+ template,
63
+ texture_type,
64
+ )
65
+ from gstaichi.types.compound_types import CompoundType
66
+ from gstaichi.types.enums import AutodiffMode, Layout
67
+ from gstaichi.types.utils import is_signed
68
+
69
# Type alias for the key of a compiled kernel: (callable, int, AutodiffMode).
# NOTE(review): the meaning of the int component is not visible in this chunk — presumably an
# instance/template id; confirm against the code that builds these keys.
CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]
70
+
71
+
72
class GsTaichiCallable:
    """
    GsTaichiCallable is used to enable wrapping a bindable function with a class.

    Design requirements for GsTaichiCallable:
    - wrap/contain a reference to a class Func instance, and allow (the GsTaichiCallable) to be passed around
      like a normal function pointer
    - expose attributes of the wrapped class Func, such as `_is_real_function`, `_primal`, etc
    - allow for (now limited) strong typing, and enable type checkers, such as pyright/mypy
    - currently GsTaichiCallable is a shared type used for all functions marked with @ti.func, @ti.kernel,
      python functions (?)
    - note: the current type-checking implementation does not distinguish between different type flavors of
      GsTaichiCallable, with different values of `_is_real_function`, `_primal`, etc
    - handle not only class-less functions, but also class-instance methods (where determining the `self`
      reference is a challenge)

    Let's take the following example:

        def test_ptr_class_func():
            @ti.data_oriented
            class MyClass:
                def __init__(self):
                    self.a = ti.field(dtype=ti.f32, shape=(3))

                def add2numbers_py(self, x, y):
                    return x + y

                @ti.func
                def add2numbers_func(self, x, y):
                    return x + y

                @ti.kernel
                def func(self):
                    a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)
                    a[0] = add_py(2, 3)
                    a[1] = add_func(3, 7)

    (taken from test_ptr_assign.py).

    When the @ti.func decorator is parsed, the function `add2numbers_func` exists, but there is not yet any
    `self`:
    - it is not possible for the method to be bound to a `self` instance
    - however, the @ti.func decorator runs the kernel_impl.py::func function --- it is at this point
      that GsTaichi's original code creates a class Func instance (that wraps add2numbers_func)
      and immediately we create a GsTaichiCallable instance that wraps the Func instance.
    - effectively, we have two layers of wrapping: GsTaichiCallable -> Func -> function pointer
      (actual function definition)
    - later on, when we access self.add2numbers_func, here:

        a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)

      ... we want to get the bound method, `self.add2numbers_func`.
    - an actual python function reference, such as MyClass.add2numbers_py, automatically binds to self
      when accessed through self in this way (however, add2numbers_func is actually a GsTaichiCallable
      instance, wrapping a Func instance, wrapping the python function -- as returned by the
      kernel_impl.py::func function, run by @ti.func)
    - however, in order to be able to add strongly typed attributes to the wrapped python function, we need
      to wrap the wrapped python function in a class
    - the wrapped python function, wrapped in a GsTaichiCallable class (which is callable, and will
      execute the underlying double-wrapped python function), will NOT automatically bind
    - when we invoke GsTaichiCallable, the wrapped function is invoked. The wrapped function is unbound, and
      so `self` is not automatically passed in as an argument, and things break

    To address this we implement the descriptor method `__get__` on our function wrapper, i.e.
    GsTaichiCallable, and have `__get__` return a `BoundGsTaichiCallable` object. `__get__` performs the
    binding for us: it binds the returned `BoundGsTaichiCallable` object to the `self` object by passing
    the instance as an argument into `BoundGsTaichiCallable.__init__`.

    `BoundGsTaichiCallable` can then be used as a normal bound function - even though it's just an object
    instance - via its `__call__` method. Effectively, at the time of actually invoking the underlying python
    function, we have 3 layers of wrapper instances:
        BoundGsTaichiCallable -> GsTaichiCallable -> Func -> python function reference/definition
    """

    def __init__(self, fn: Callable, wrapper: Callable) -> None:
        """
        Args:
            fn: the original Python function. Its metadata (``__name__``, ``__doc__``, ``__wrapped__``,
                ...) is copied onto this object by ``functools.update_wrapper``.
            wrapper: the object actually invoked by ``__call__`` (``func`` and ``pyfunc`` in this
                module pass a ``Func`` instance wrapping ``fn``).
        """
        self.fn: Callable = fn
        self.wrapper: Callable = wrapper
        # Flavor flags; some are set by the decorators in this module after construction
        # (e.g. `func` / `pyfunc` set `_is_gstaichi_function` and `_is_real_function`).
        self._is_real_function: bool = False
        self._is_gstaichi_function: bool = False
        self._is_wrapped_kernel: bool = False
        self._is_classkernel: bool = False
        self._primal: Kernel | None = None
        self._adjoint: Kernel | None = None
        self.grad: Kernel | None = None
        self._is_staticmethod: bool = False
        self.is_pure = False
        functools.update_wrapper(self, fn)

    def __call__(self, *args, **kwargs):
        """Forward the call unchanged to the wrapped callable (no `self` binding here)."""
        return self.wrapper.__call__(*args, **kwargs)

    def __get__(self, instance, owner):
        """Descriptor hook: class access returns this object, instance access returns a bound wrapper."""
        if instance is None:
            return self
        return BoundGsTaichiCallable(instance, self)
166
+
167
+
168
class BoundGsTaichiCallable:
    """A GsTaichiCallable bound to a specific instance; produced by ``GsTaichiCallable.__get__``.

    Calling it invokes the wrapped callable with the bound instance prepended to the positional
    arguments. Attribute reads that are not found on this object are forwarded to the underlying
    GsTaichiCallable. All attribute writes other than to this object's own three attributes
    (``wrapper``, ``instance``, ``gstaichi_callable``) are forwarded as well. As a result, flags such
    as ``_is_real_function`` read and write through to the unbound object.
    """

    def __init__(self, instance: Any, gstaichi_callable: "GsTaichiCallable"):
        self.wrapper = gstaichi_callable.wrapper
        self.instance = instance
        self.gstaichi_callable = gstaichi_callable

    def __call__(self, *args, **kwargs):
        return self.wrapper(self.instance, *args, **kwargs)

    def __getattr__(self, k: str) -> Any:
        # Only reached when normal lookup fails. Read the target with object.__getattribute__ so that
        # an instance created without __init__ (copy.copy / pickle / __new__, which probe hooks such as
        # __setstate__ on the bare object) raises AttributeError instead of recursing into
        # __getattr__("gstaichi_callable") forever.
        try:
            target = object.__getattribute__(self, "gstaichi_callable")
        except AttributeError:
            raise AttributeError(k) from None
        return getattr(target, k)

    def __setattr__(self, k: str, v: Any) -> None:
        # Note: these have to match the name of any attributes on this class.
        if k in ("wrapper", "instance", "gstaichi_callable"):
            object.__setattr__(self, k, v)
        else:
            setattr(self.gstaichi_callable, k, v)
187
+
188
+
189
def func(fn: Callable, is_real_function: bool = False) -> GsTaichiCallable:
    """Mark a function as callable from GsTaichi scope.

    The decorated Python function becomes a GsTaichi function, which GsTaichi
    JIT-compiles into native instructions when it is used inside a kernel.

    Args:
        fn (Callable): The Python function to decorate.
        is_real_function (bool): Whether ``fn`` is compiled as a GsTaichi "real" function.

    Returns:
        GsTaichiCallable: A callable wrapping a ``Func`` built from ``fn``.

    Example::

        >>> @ti.func
        >>> def foo(x):
        >>>     return x + 2
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     print(foo(40))  # 42
    """
    # `_inside_class` inspects the call stack and must be called directly from this frame.
    # `real_func` reaches here through one extra frame, which the `+ is_real_function` offset
    # appears to account for.
    wrapped_func = Func(
        fn,
        _classfunc=_inside_class(level_of_class_stackframe=3 + is_real_function),
        is_real_function=is_real_function,
    )
    result = GsTaichiCallable(fn, wrapped_func)
    result._is_gstaichi_function = True
    result._is_real_function = is_real_function
    return result
219
+
220
+
221
def real_func(fn: Callable) -> GsTaichiCallable:
    """Mark a function as a GsTaichi *real* function.

    Equivalent to ``func(fn, is_real_function=True)``. Unlike a plain ``@ti.func``, whose body is
    transformed into the calling kernel at each call, a real function is compiled into a separate
    function per template instantiation and called from the kernel. Real functions are rejected
    inside autodiff kernels.

    Note: ``func`` inspects the call stack with a frame offset that grows by one when
    ``is_real_function`` is True, presumably to account for this wrapper's frame. Keep calling
    ``func`` directly from here.
    """
    return func(fn, is_real_function=True)
223
+
224
+
225
def pyfunc(fn: Callable) -> GsTaichiCallable:
    """Mark a function as callable from both GsTaichi scope and Python scope.

    Inside GsTaichi scope the function is JIT-compiled into native
    instructions; from Python scope it simply runs as the original Python
    function.

    See also :func:`~gstaichi.lang.kernel_impl.func`.

    Args:
        fn (Callable): The Python function to decorate.

    Returns:
        GsTaichiCallable: A callable wrapping a ``Func`` (with ``_pyfunc=True``) built from ``fn``.
    """
    # Stack inspection: `_inside_class` must be invoked directly from this frame.
    wrapped_func = Func(fn, _classfunc=_inside_class(level_of_class_stackframe=3), _pyfunc=True)
    result = GsTaichiCallable(fn, wrapped_func)
    result._is_gstaichi_function = True
    result._is_real_function = False
    return result
246
+
247
+
248
def _populate_global_vars_for_templates(
    template_slot_locations: list[int],
    argument_metas: list[ArgMetadata],
    global_vars: dict[str, Any],
    fn: Callable,
    py_args: tuple[Any, ...],
):
    """
    Write template argument values (and dataclass-expanded values) into ``global_vars``.

    Globals are (ab)used as the storage for the Python objects passed as template
    arguments: each template slot is stored under its parameter name. In addition,
    every parameter of ``fn`` annotated with a dataclass has its value expanded into
    ``global_vars`` by ``_kernel_impl_dataclass.populate_global_vars_from_dataclass``.
    ``global_vars`` is mutated in place; nothing is returned.
    """
    global_vars.update((argument_metas[slot].name, py_args[slot]) for slot in template_slot_locations)

    signature_params = inspect.signature(fn).parameters
    for position, (param_name, param) in enumerate(signature_params.items()):
        annotation = param.annotation
        if not dataclasses.is_dataclass(annotation):
            continue
        _kernel_impl_dataclass.populate_global_vars_from_dataclass(
            param_name,
            annotation,
            py_args[position],
            global_vars=global_vars,
        )
274
+
275
+
276
def _get_tree_and_ctx(
    self: "Func | Kernel",
    args: tuple[Any, ...],
    excluded_parameters=(),
    is_kernel: bool = True,
    arg_features=None,
    ast_builder: "ASTBuilder | None" = None,
    is_real_function: bool = False,
    current_kernel: "Kernel | None" = None,
) -> tuple[ast.Module, ASTTransformerContext]:
    """Parse the source of ``self.func`` and build the AST-transformer context for it.

    Steps:
    - fetch the function's source lines and source info, normalise each line with
      ``textwrap.fill`` (tabs expanded to 4 columns; ``width=9999`` presumably just avoids
      wrapping), dedent, and parse the result;
    - clear the decorator list of the parsed function definition;
    - build the globals dict via ``_get_global_vars``. For kernels and real functions, inject
      template argument values (and dataclass-expanded values) into it;
    - record ``function_source_info`` on a kernel. When ``current_kernel`` is passed, it is stored
      as ``kernel_function_info``. In every case it is added to ``visited_functions`` of either the
      given kernel or the runtime's current kernel.

    Args:
        self: the Func or Kernel whose ``func`` is being transformed.
        args: the Python argument values; also passed to the context as ``argument_data``.
        excluded_parameters, is_kernel, arg_features, ast_builder, is_real_function: forwarded to
            ``ASTTransformerContext``.
        current_kernel: the Kernel being compiled, when compiling a kernel itself. The Func call
            sites in this file leave it as None, so the runtime's current kernel is used.

    Returns:
        The parsed module and the ``ASTTransformerContext`` for it.
    """
    function_source_info, src = get_source_info_and_src(self.func)
    # Normalise whitespace per line (tabs -> 4 columns) before dedenting and parsing.
    src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    func_body = tree.body[0]
    func_body.decorator_list = []  # type: ignore , kick that can down the road...

    global_vars = _get_global_vars(self.func)

    # Template values only need injecting when this AST is compiled on its own (kernel or real function).
    if is_kernel or is_real_function:
        _populate_global_vars_for_templates(
            template_slot_locations=self.template_slot_locations,
            argument_metas=self.arg_metas,
            global_vars=global_vars,
            fn=self.func,
            py_args=args,
        )

    if current_kernel is not None:  # Kernel
        current_kernel.kernel_function_info = function_source_info
    if current_kernel is None:
        # Called for a Func: attribute the source to the kernel currently being compiled.
        current_kernel = impl.get_runtime()._current_kernel
    assert current_kernel is not None
    current_kernel.visited_functions.add(function_source_info)

    return tree, ASTTransformerContext(
        excluded_parameters=excluded_parameters,
        is_kernel=is_kernel,
        func=self,
        arg_features=arg_features,
        global_vars=global_vars,
        argument_data=args,
        src=src,
        start_lineno=function_source_info.start_lineno,
        end_lineno=function_source_info.end_lineno,
        file=function_source_info.filepath,
        ast_builder=ast_builder,
        is_real_function=is_real_function,
    )
325
+
326
+
327
def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
    """Bind positional and keyword call arguments to the declared parameters.

    Positional arguments fill parameters left to right. Keyword arguments fill the
    first parameter with a matching name. Unfilled parameters fall back to their
    declared default.

    Args:
        self: the Func or Kernel being called; its ``arg_metas`` describe the parameters.
        is_func: if True, ``self.arg_metas`` is first replaced by
            ``_kernel_impl_dataclass.expand_func_arguments(self.arg_metas)``.
        args: positional call arguments.
        kwargs: keyword call arguments.

    Returns:
        One value per parameter, in declaration order.

    Raises:
        GsTaichiSyntaxError: too many positional arguments, a keyword that duplicates a
            positional argument, an unknown keyword, or a parameter left without a value.
    """
    if is_func:
        self.arg_metas = _kernel_impl_dataclass.expand_func_arguments(self.arg_metas)

    metas = self.arg_metas
    bound: list[Any] = [meta.default for meta in metas]
    n_positional = len(args)

    if n_positional > len(bound):
        arg_str = ", ".join(str(value) for value in args)
        expected_str = ", ".join(f"{meta.name} : {meta.annotation}" for meta in metas)
        raise GsTaichiSyntaxError(f"Too many arguments. Expected ({expected_str}), got ({arg_str}).")

    # Overlay the positional values onto the defaults (lengths match, so the list size is unchanged).
    bound[:n_positional] = args

    # The first parameter carrying a given name wins, mirroring a left-to-right search.
    position_of: dict[str, int] = {}
    for idx, meta in enumerate(metas):
        position_of.setdefault(meta.name, idx)

    for key, value in kwargs.items():
        idx = position_of.get(key)
        if idx is None:
            raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")
        if idx < n_positional:
            raise GsTaichiSyntaxError(f"Multiple values for argument '{key}'.")
        bound[idx] = value

    for idx, value in enumerate(bound):
        if value is not inspect.Parameter.empty:
            continue
        meta = metas[idx]
        if meta.annotation is inspect._empty:
            raise GsTaichiSyntaxError(f"Parameter `{meta.name}` missing.")
        raise GsTaichiSyntaxError(f"Parameter `{meta.name} : {meta.annotation}` missing.")

    return tuple(bound)
363
+
364
+
365
class Func:
    """Python-side representation of a function decorated with ``@ti.func``, ``@ti.pyfunc``
    or ``@ti.real_func``.

    A non-real Func is transformed into the calling kernel's AST each time it is called from
    GsTaichi scope. A real function is instead compiled once per template instantiation
    (``do_compile``) into a separate program function, and each call site inserts a call to that
    function (``func_call_rvalue``).
    """

    # Incremented for every Func created; provides a distinct ``func_id``.
    function_counter = 0

    def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False) -> None:
        """
        Args:
            _func: the decorated Python function.
            _classfunc: whether the function was defined inside a class body. If so, an
                un-annotated first parameter (``self``) is treated as a template.
            _pyfunc: whether the function may also be called from Python scope.
            is_real_function: whether the function is compiled as a real function.
        """
        self.func = _func
        self.func_id = Func.function_counter
        Func.function_counter += 1
        # instance_id -> compile callback (real functions only).
        self.compiled = {}
        self.classfunc = _classfunc
        self.pyfunc = _pyfunc
        self.is_real_function = is_real_function
        self.arg_metas: list[ArgMetadata] = []
        self.orig_arguments: list[ArgMetadata] = []
        self.return_type: tuple[Type, ...] | None = None
        self.extract_arguments()
        # Indices of template parameters. They key the TemplateMapper, and their Python values are
        # injected into globals when the function is transformed.
        self.template_slot_locations: list[int] = []
        for i, arg in enumerate(self.arg_metas):
            if arg.annotation == template or isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
        self.gstaichi_functions = {}  # The |Function| class in C++
        self.has_print = False

    def __call__(self: "Func", *args, **kwargs) -> Any:
        """Invoke the function.

        - From Python scope: only allowed for pyfuncs, and runs the Python function directly.
        - From a kernel, real function: compile this template instantiation on first use, then
          insert a call to it.
        - From a kernel, otherwise: transform the function body into the current kernel's AST.

        Raises:
            GsTaichiSyntaxError: a non-pyfunc called from Python scope, a real function used in an
                autodiff kernel, or a function with a return annotation that never returns a value.
        """
        args = _process_args(self, is_func=True, args=args, kwargs=kwargs)

        if not impl.inside_kernel():
            if not self.pyfunc:
                raise GsTaichiSyntaxError("GsTaichi functions cannot be called from Python-scope.")
            return self.func(*args)

        current_kernel = impl.get_runtime().current_kernel
        if self.is_real_function:
            if current_kernel.autodiff_mode != AutodiffMode.NONE:
                raise GsTaichiSyntaxError("Real function in gradient kernels unsupported.")
            instance_id, arg_features = self.mapper.lookup(args)
            key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
            if key.instance_id not in self.compiled:
                self.do_compile(key=key, args=args, arg_features=arg_features)
            return self.func_call_rvalue(key=key, args=args)

        # Non-real function from here on: its body is transformed directly into the current kernel.
        tree, ctx = _get_tree_and_ctx(
            self,
            is_kernel=False,
            args=args,
            ast_builder=current_kernel.ast_builder(),
            is_real_function=self.is_real_function,
        )

        struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)

        tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
        ret = transform_tree(tree, ctx)
        if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
            raise GsTaichiSyntaxError("Function has a return type but does not have a return statement")
        return ret

    def func_call_rvalue(self, key: FunctionKey, args: tuple[Any, ...]) -> Any:
        """Insert a call to the compiled real function ``key`` and return its result expression(s).

        Returns:
            None when there is no return annotation; a single value for one return type;
            otherwise a tuple with one value per declared return type.

        Raises:
            GsTaichiTypeError: an ndarray parameter received a non-ndarray value, or a return
                type is not primitive/struct/matrix.
        """
        # Skip the template args, e.g., |self|
        assert self.is_real_function
        non_template_args = []
        dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        for i, kernel_arg in enumerate(self.arg_metas):
            anno = kernel_arg.annotation
            # Use the same template test as __init__ (which also accepts the bare `template`
            # class), so class-annotated template args are not passed as runtime arguments.
            if anno == template or isinstance(anno, template):
                continue
            if id(anno) in primitive_types.type_ids:
                non_template_args.append(ops.cast(args[i], anno))
            elif isinstance(anno, primitive_types.RefType):
                non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
            elif isinstance(anno, ndarray_type.NdarrayType):
                if not isinstance(args[i], AnyArray):
                    raise GsTaichiTypeError(
                        f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
                    )
                non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
            else:
                non_template_args.append(args[i])
        non_template_args = impl.make_expr_group(non_template_args)
        compiling_callable = impl.get_runtime().compiling_callable
        assert compiling_callable is not None
        func_call = compiling_callable.ast_builder().insert_func_call(
            self.gstaichi_functions[key.instance_id], non_template_args, dbg_info
        )
        if self.return_type is None:
            return None
        func_call = Expr(func_call)
        ret = []

        for i, return_type in enumerate(self.return_type):
            if id(return_type) in primitive_types.type_ids:
                ret.append(
                    Expr(
                        _ti_core.make_get_element_expr(
                            func_call.ptr, (i,), _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
                        )
                    )
                )
            elif isinstance(return_type, (StructType, MatrixType)):
                ret.append(return_type.from_gstaichi_object(func_call, (i,)))
            else:
                raise GsTaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
        if len(ret) == 1:
            return ret[0]
        return tuple(ret)

    def do_compile(self, key: FunctionKey, args: tuple[Any, ...], arg_features: tuple[Any, ...]) -> None:
        """Create the program function for ``key`` and register a body callback that compiles it."""
        tree, ctx = _get_tree_and_ctx(
            self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
        )
        fn = impl.get_runtime().prog.create_function(key)

        def func_body():
            # Temporarily make `fn` the callable being compiled. Always restore the previous one,
            # even if the transform raises, so the runtime is not left pointing at `fn`.
            old_callable = impl.get_runtime().compiling_callable
            impl.get_runtime()._compiling_callable = fn
            try:
                ctx.ast_builder = fn.ast_builder()
                transform_tree(tree, ctx)
            finally:
                impl.get_runtime()._compiling_callable = old_callable

        self.gstaichi_functions[key.instance_id] = fn
        self.compiled[key.instance_id] = func_body
        self.gstaichi_functions[key.instance_id].set_function_body(func_body)

    def extract_arguments(self) -> None:
        """Populate ``return_type``, ``arg_metas`` and ``orig_arguments`` from the signature.

        Raises:
            GsTaichiSyntaxError: for ``...`` in the return annotation; for ``*args``, ``**kwargs``,
                keyword-only or positional-only parameters; for a missing annotation on a
                (non-pyfunc) real function parameter; or for an unsupported annotation.
        """
        sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect.Signature.empty, None):
            self.return_type = sig.return_annotation
            # `-> tuple[a, b]` declares several return values.
            if (
                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))  # type: ignore
                and self.return_type.__origin__ is tuple  # type: ignore
            ):
                self.return_type = self.return_type.__args__  # type: ignore
            if not isinstance(self.return_type, (list, tuple)):
                self.return_type = (self.return_type,)
            if any(return_type is Ellipsis for return_type in self.return_type):
                raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
        params = sig.parameters
        for i, arg_name in enumerate(params.keys()):
            param = params[arg_name]
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise GsTaichiSyntaxError(
                    "GsTaichi functions do not support variable keyword parameters (i.e., **kwargs)"
                )
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise GsTaichiSyntaxError(
                    "GsTaichi functions do not support variable positional parameters (i.e., *args)"
                )
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                raise GsTaichiSyntaxError("GsTaichi functions do not support keyword parameters")
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                raise GsTaichiSyntaxError('GsTaichi functions only support "positional or keyword" parameters')
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                if i == 0 and self.classfunc:
                    annotation = template()
                # TODO: pyfunc also need type annotation check when real function is enabled,
                # but that has to happen at runtime when we know which scope it's called from.
                elif not self.pyfunc and self.is_real_function:
                    raise GsTaichiSyntaxError(
                        f"GsTaichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
                    )
            else:
                is_supported = (
                    isinstance(annotation, (ndarray_type.NdarrayType, MatrixType, StructType))
                    or id(annotation) in primitive_types.type_ids
                    or type(annotation) == gstaichi.types.annotations.Template
                    or isinstance(annotation, template)
                    or annotation == gstaichi.types.annotations.Template
                    or isinstance(annotation, primitive_types.RefType)
                    or (isinstance(annotation, type) and dataclasses.is_dataclass(annotation))
                )
                if not is_supported:
                    raise GsTaichiSyntaxError(
                        f"Invalid type annotation (argument {i}) of GsTaichi function: {annotation}"
                    )
            self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
            self.orig_arguments.append(ArgMetadata(annotation, param.name, param.default))
552
+
553
+
554
+ def _get_global_vars(_func: Callable) -> dict[str, Any]:
555
+ # Discussions: https://github.com/taichi-dev/gstaichi/issues/282
556
+ global_vars = _func.__globals__.copy()
557
+
558
+ freevar_names = _func.__code__.co_freevars
559
+ closure = _func.__closure__
560
+ if closure:
561
+ freevar_values = list(map(lambda x: x.cell_contents, closure))
562
+ for name, value in zip(freevar_names, freevar_values):
563
+ global_vars[name] = value
564
+
565
+ return global_vars
566
+
567
+
568
+ @dataclasses.dataclass
569
+ class SrcLlCacheObservations:
570
+ cache_key_generated: bool = False
571
+ cache_validated: bool = False
572
+ cache_loaded: bool = False
573
+ cache_stored: bool = False
574
+
575
+
576
+ class Kernel:
577
+ counter = 0
578
+
579
+ def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _classkernel=False) -> None:
580
+ self.func = _func
581
+ self.kernel_counter = Kernel.counter
582
+ Kernel.counter += 1
583
+ assert autodiff_mode in (
584
+ AutodiffMode.NONE,
585
+ AutodiffMode.VALIDATION,
586
+ AutodiffMode.FORWARD,
587
+ AutodiffMode.REVERSE,
588
+ )
589
+ self.autodiff_mode = autodiff_mode
590
+ self.grad: "Kernel | None" = None
591
+ self.arg_metas: list[ArgMetadata] = []
592
+ self.return_type = None
593
+ self.classkernel = _classkernel
594
+ self.extract_arguments()
595
+ self.template_slot_locations = []
596
+ for i, arg in enumerate(self.arg_metas):
597
+ if arg.annotation == template or isinstance(arg.annotation, template):
598
+ self.template_slot_locations.append(i)
599
+ self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
600
+ impl.get_runtime().kernels.append(self)
601
+ self.reset()
602
+ self.kernel_cpp = None
603
+ # A materialized kernel is a KernelCxx object which may or may not have
604
+ # been compiled. It generally has been converted at least as far as AST
605
+ # and front-end IR, but not necessarily any further.
606
+ self.materialized_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
607
+ self.has_print = False
608
+ self.gstaichi_callable: GsTaichiCallable | None = None
609
+ self.visited_functions: set[FunctionSourceInfo] = set()
610
+ self.kernel_function_info: FunctionSourceInfo | None = None
611
+ self.compiled_kernel_data_by_key: dict[CompiledKernelKeyType, CompiledKernelData] = {}
612
+
613
+ self.src_ll_cache_observations: SrcLlCacheObservations = SrcLlCacheObservations()
614
+
615
+ def ast_builder(self) -> ASTBuilder:
616
+ assert self.kernel_cpp is not None
617
+ return self.kernel_cpp.ast_builder()
618
+
619
+ def reset(self) -> None:
620
+ self.runtime = impl.get_runtime()
621
+ self.materialized_kernels = {}
622
+
623
+ def extract_arguments(self) -> None:
624
+ sig = inspect.signature(self.func)
625
+ if sig.return_annotation not in (inspect._empty, None):
626
+ self.return_type = sig.return_annotation
627
+ if (
628
+ isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias)) # type: ignore
629
+ and self.return_type.__origin__ is tuple
630
+ ):
631
+ self.return_type = self.return_type.__args__
632
+ if not isinstance(self.return_type, (list, tuple)):
633
+ self.return_type = (self.return_type,)
634
+ for return_type in self.return_type:
635
+ if return_type is Ellipsis:
636
+ raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
637
+ params = dict(sig.parameters)
638
+ arg_names = params.keys()
639
+ for i, arg_name in enumerate(arg_names):
640
+ param = params[arg_name]
641
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
642
+ raise GsTaichiSyntaxError(
643
+ "GsTaichi kernels do not support variable keyword parameters (i.e., **kwargs)"
644
+ )
645
+ if param.kind == inspect.Parameter.VAR_POSITIONAL:
646
+ raise GsTaichiSyntaxError(
647
+ "GsTaichi kernels do not support variable positional parameters (i.e., *args)"
648
+ )
649
+ if param.default is not inspect.Parameter.empty:
650
+ raise GsTaichiSyntaxError("GsTaichi kernels do not support default values for arguments")
651
+ if param.kind == inspect.Parameter.KEYWORD_ONLY:
652
+ raise GsTaichiSyntaxError("GsTaichi kernels do not support keyword parameters")
653
+ if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
654
+ raise GsTaichiSyntaxError('GsTaichi kernels only support "positional or keyword" parameters')
655
+ annotation = param.annotation
656
+ if param.annotation is inspect.Parameter.empty:
657
+ if i == 0 and self.classkernel: # The |self| parameter
658
+ annotation = template()
659
+ else:
660
+ raise GsTaichiSyntaxError("GsTaichi kernels parameters must be type annotated")
661
+ else:
662
+ if isinstance(
663
+ annotation,
664
+ (
665
+ template,
666
+ ndarray_type.NdarrayType,
667
+ texture_type.TextureType,
668
+ texture_type.RWTextureType,
669
+ ),
670
+ ):
671
+ pass
672
+ elif annotation is ndarray_type.NdarrayType:
673
+ # convert from ti.types.NDArray into ti.types.NDArray()
674
+ annotation = annotation()
675
+ elif id(annotation) in primitive_types.type_ids:
676
+ pass
677
+ elif isinstance(annotation, sparse_matrix_builder):
678
+ pass
679
+ elif isinstance(annotation, MatrixType):
680
+ pass
681
+ elif isinstance(annotation, StructType):
682
+ pass
683
+ elif annotation == template:
684
+ pass
685
+ elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
686
+ pass
687
+ else:
688
+ raise GsTaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
689
+ self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
690
+
691
+ def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features=None):
692
+ if key is None:
693
+ key = (self.func, 0, self.autodiff_mode)
694
+ self.runtime.materialize()
695
+ self.fast_checksum = None
696
+
697
+ if key in self.materialized_kernels:
698
+ return
699
+
700
+ if self.runtime.src_ll_cache and self.gstaichi_callable and self.gstaichi_callable.is_pure:
701
+ kernel_source_info, _src = get_source_info_and_src(self.func)
702
+ self.fast_checksum = src_hasher.create_cache_key(kernel_source_info, args)
703
+ if self.fast_checksum:
704
+ self.src_ll_cache_observations.cache_key_generated = True
705
+ if self.fast_checksum and src_hasher.validate_cache_key(self.fast_checksum):
706
+ self.src_ll_cache_observations.cache_validated = True
707
+ prog = impl.get_runtime().prog
708
+ self.compiled_kernel_data_by_key[key] = prog.load_fast_cache(
709
+ self.fast_checksum,
710
+ self.func.__name__,
711
+ prog.config(),
712
+ prog.get_device_caps(),
713
+ )
714
+ if self.compiled_kernel_data_by_key[key]:
715
+ self.src_ll_cache_observations.cache_loaded = True
716
+
717
+ kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}"
718
+ _logging.trace(f"Materializing kernel {kernel_name} in {self.autodiff_mode}...")
719
+
720
+ tree, ctx = _get_tree_and_ctx(
721
+ self,
722
+ args=args,
723
+ excluded_parameters=self.template_slot_locations,
724
+ arg_features=arg_features,
725
+ current_kernel=self,
726
+ )
727
+
728
+ if self.autodiff_mode != AutodiffMode.NONE:
729
+ KernelSimplicityASTChecker(self.func).visit(tree)
730
+
731
+ # Do not change the name of 'gstaichi_ast_generator'
732
+ # The warning system needs this identifier to remove unnecessary messages
733
+ def gstaichi_ast_generator(kernel_cxx: KernelCxx):
734
+ nonlocal tree
735
+ if self.runtime.inside_kernel:
736
+ raise GsTaichiSyntaxError(
737
+ "Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
738
+ "Please check if you have direct/indirect invocation of kernels within kernels. "
739
+ "Note that some methods provided by the GsTaichi standard library may invoke kernels, "
740
+ "and please move their invocations to Python-scope."
741
+ )
742
+ self.kernel_cpp = kernel_cxx
743
+ self.runtime.inside_kernel = True
744
+ self.runtime._current_kernel = self
745
+ assert self.runtime._compiling_callable is None
746
+ self.runtime._compiling_callable = kernel_cxx
747
+ try:
748
+ ctx.ast_builder = kernel_cxx.ast_builder()
749
+
750
+ def ast_to_dict(node: ast.AST | list | primitive_types._python_primitive_types):
751
+ if isinstance(node, ast.AST):
752
+ fields = {k: ast_to_dict(v) for k, v in ast.iter_fields(node)}
753
+ return {
754
+ "type": node.__class__.__name__,
755
+ "fields": fields,
756
+ "lineno": getattr(node, "lineno", None),
757
+ "col_offset": getattr(node, "col_offset", None),
758
+ }
759
+ if isinstance(node, list):
760
+ return [ast_to_dict(x) for x in node]
761
+ return node # Basic types (str, int, None, etc.)
762
+
763
+ if os.environ.get("TI_DUMP_AST", "") == "1":
764
+ target_dir = pathlib.Path("/tmp/ast")
765
+ target_dir.mkdir(parents=True, exist_ok=True)
766
+
767
+ start = time.time()
768
+ ast_str = ast.dump(tree, indent=2)
769
+ output_file = target_dir / f"{kernel_name}_ast.txt"
770
+ output_file.write_text(ast_str)
771
+ elapsed_txt = time.time() - start
772
+
773
+ start = time.time()
774
+ json_str = json.dumps(ast_to_dict(tree), indent=2)
775
+ output_file = target_dir / f"{kernel_name}_ast.json"
776
+ output_file.write_text(json_str)
777
+ elapsed_json = time.time() - start
778
+
779
+ output_file = target_dir / f"{kernel_name}_gen_time.json"
780
+ output_file.write_text(
781
+ json.dumps({"elapsed_txt": elapsed_txt, "elapsed_json": elapsed_json}, indent=2)
782
+ )
783
+ struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
784
+ tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
785
+ ctx.only_parse_function_def = self.compiled_kernel_data_by_key.get(key) is not None
786
+ transform_tree(tree, ctx)
787
+ if not ctx.is_real_function:
788
+ if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
789
+ raise GsTaichiSyntaxError("Kernel has a return type but does not have a return statement")
790
+ finally:
791
+ self.runtime.inside_kernel = False
792
+ self.runtime._current_kernel = None
793
+ self.runtime._compiling_callable = None
794
+
795
+ gstaichi_kernel = impl.get_runtime().prog.create_kernel(gstaichi_ast_generator, kernel_name, self.autodiff_mode)
796
+ assert key not in self.materialized_kernels
797
+ self.materialized_kernels[key] = gstaichi_kernel
798
+
799
+ def launch_kernel(self, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args) -> Any:
800
+ assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided"
801
+
802
+ tmps = []
803
+ callbacks = []
804
+
805
+ actual_argument_slot = 0
806
+ launch_ctx = t_kernel.make_launch_context()
807
+ max_arg_num = 512
808
+ exceed_max_arg_num = False
809
+
810
+ def set_arg_ndarray(indices: tuple[int, ...], v: gstaichi.lang._ndarray.Ndarray) -> None:
811
+ v_primal = v.arr
812
+ v_grad = v.grad.arr if v.grad else None
813
+ if v_grad is None:
814
+ launch_ctx.set_arg_ndarray(indices, v_primal) # type: ignore , solvable probably, just not today
815
+ else:
816
+ launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad) # type: ignore
817
+
818
+ def set_arg_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
819
+ launch_ctx.set_arg_texture(indices, v.tex)
820
+
821
+ def set_arg_rw_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
822
+ launch_ctx.set_arg_rw_texture(indices, v.tex)
823
+
824
+ def set_arg_ext_array(indices: tuple[int, ...], v: Any, needed: ndarray_type.NdarrayType) -> None:
825
+ # v is things like torch Tensor and numpy array
826
+ # Not adding type for this, since adds additional dependencies
827
+ #
828
+ # Element shapes are already specialized in GsTaichi codegen.
829
+ # The shape information for element dims are no longer needed.
830
+ # Therefore we strip the element shapes from the shape vector,
831
+ # so that it only holds "real" array shapes.
832
+ is_soa = needed.layout == Layout.SOA
833
+ array_shape = v.shape
834
+ if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
835
+ warnings.warn("Ndarray index might be out of int32 boundary but int64 indexing is not supported yet.")
836
+ if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
837
+ element_dim = 0
838
+ else:
839
+ element_dim = needed.dtype.ndim
840
+ array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
841
+ if isinstance(v, np.ndarray):
842
+ # numpy
843
+ if v.flags.c_contiguous:
844
+ launch_ctx.set_arg_external_array_with_shape(indices, int(v.ctypes.data), v.nbytes, array_shape, 0)
845
+ elif v.flags.f_contiguous:
846
+ # TODO: A better way that avoids copying is saving strides info.
847
+ tmp = np.ascontiguousarray(v)
848
+ # Purpose: DO NOT GC |tmp|!
849
+ tmps.append(tmp)
850
+
851
+ def callback(original, updated):
852
+ np.copyto(original, np.asfortranarray(updated))
853
+
854
+ callbacks.append(functools.partial(callback, v, tmp))
855
+ launch_ctx.set_arg_external_array_with_shape(
856
+ indices, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
857
+ )
858
+ else:
859
+ raise ValueError(
860
+ "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
861
+ "before passing it into gstaichi kernel."
862
+ )
863
+ elif has_pytorch():
864
+ import torch # pylint: disable=C0415
865
+
866
+ if isinstance(v, torch.Tensor):
867
+ if not v.is_contiguous():
868
+ raise ValueError(
869
+ "Non contiguous tensors are not supported, please call tensor.contiguous() before "
870
+ "passing it into gstaichi kernel."
871
+ )
872
+ gstaichi_arch = self.runtime.prog.config().arch
873
+
874
+ def get_call_back(u, v):
875
+ def call_back():
876
+ u.copy_(v)
877
+
878
+ return call_back
879
+
880
+ # FIXME: only allocate when launching grad kernel
881
+ if v.requires_grad and v.grad is None:
882
+ v.grad = torch.zeros_like(v)
883
+
884
+ if v.requires_grad:
885
+ if not isinstance(v.grad, torch.Tensor):
886
+ raise ValueError(
887
+ f"Expecting torch.Tensor for gradient tensor, but getting {v.grad.__class__.__name__} instead"
888
+ )
889
+ if not v.grad.is_contiguous():
890
+ raise ValueError(
891
+ "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into gstaichi kernel."
892
+ )
893
+
894
+ tmp = v
895
+ if (str(v.device) != "cpu") and not (
896
+ str(v.device).startswith("cuda") and gstaichi_arch == _ti_core.Arch.cuda
897
+ ):
898
+ # Getting a torch CUDA tensor on GsTaichi non-cuda arch:
899
+ # We just replace it with a CPU tensor and by the end of kernel execution we'll use the
900
+ # callback to copy the values back to the original CUDA tensor.
901
+ host_v = v.to(device="cpu", copy=True)
902
+ tmp = host_v
903
+ callbacks.append(get_call_back(v, host_v))
904
+
905
+ launch_ctx.set_arg_external_array_with_shape(
906
+ indices,
907
+ int(tmp.data_ptr()),
908
+ tmp.element_size() * tmp.nelement(),
909
+ array_shape,
910
+ int(v.grad.data_ptr()) if v.grad is not None else 0,
911
+ )
912
+ else:
913
+ raise GsTaichiRuntimeTypeError(
914
+ f"Argument of type {type(v)} cannot be converted into required type {needed}"
915
+ )
916
+ else:
917
+ raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
918
+
919
+ def set_arg_matrix(indices: tuple[int, ...], v, needed) -> None:
920
+ def cast_float(x: float | np.floating | np.integer | int) -> float:
921
+ if not isinstance(x, (int, float, np.integer, np.floating)):
922
+ raise GsTaichiRuntimeTypeError(
923
+ f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
924
+ )
925
+ return float(x)
926
+
927
+ def cast_int(x: int | np.integer) -> int:
928
+ if not isinstance(x, (int, np.integer)):
929
+ raise GsTaichiRuntimeTypeError(
930
+ f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
931
+ )
932
+ return int(x)
933
+
934
+ cast_func = None
935
+ if needed.dtype in primitive_types.real_types:
936
+ cast_func = cast_float
937
+ elif needed.dtype in primitive_types.integer_types:
938
+ cast_func = cast_int
939
+ else:
940
+ raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")
941
+
942
+ if needed.ndim == 2:
943
+ v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
944
+ else:
945
+ v = [cast_func(v[i]) for i in range(needed.n)]
946
+ v = needed(*v)
947
+ needed.set_kernel_struct_args(v, launch_ctx, indices)
948
+
949
+ def set_arg_sparse_matrix_builder(indices: tuple[int, ...], v) -> None:
950
+ # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
951
+ launch_ctx.set_arg_uint(indices, v._get_ndarray_addr())
952
+
953
+ set_later_list = []
954
+
955
+ def recursive_set_args(needed_arg_type: Type, provided_arg_type: Type, v: Any, indices: tuple[int, ...]) -> int:
956
+ """
957
+ Returns the number of kernel args set
958
+ e.g. templates don't set kernel args, so returns 0
959
+ a single ndarray is 1 kernel arg, so returns 1
960
+ a struct of 3 ndarrays would set 3 kernel args, so return 3
961
+ note: len(indices) > 1 only happens with argpack (which we are removing support for)
962
+ """
963
+ nonlocal actual_argument_slot, exceed_max_arg_num, set_later_list
964
+ if actual_argument_slot >= max_arg_num:
965
+ exceed_max_arg_num = True
966
+ return 0
967
+ actual_argument_slot += 1
968
+ # Note: do not use sth like "needed == f32". That would be slow.
969
+ if id(needed_arg_type) in primitive_types.real_type_ids:
970
+ if not isinstance(v, (float, int, np.floating, np.integer)):
971
+ raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
972
+ launch_ctx.set_arg_float(indices, float(v))
973
+ return 1
974
+ if id(needed_arg_type) in primitive_types.integer_type_ids:
975
+ if not isinstance(v, (int, np.integer)):
976
+ raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
977
+ if is_signed(cook_dtype(needed_arg_type)):
978
+ launch_ctx.set_arg_int(indices, int(v))
979
+ else:
980
+ launch_ctx.set_arg_uint(indices, int(v))
981
+ return 1
982
+ if isinstance(needed_arg_type, sparse_matrix_builder):
983
+ set_arg_sparse_matrix_builder(indices, v)
984
+ return 1
985
+ if dataclasses.is_dataclass(needed_arg_type):
986
+ assert provided_arg_type == needed_arg_type
987
+ idx = 0
988
+ for j, field in enumerate(dataclasses.fields(needed_arg_type)):
989
+ assert not isinstance(field.type, str)
990
+ field_value = getattr(v, field.name)
991
+ idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
992
+ return idx
993
+ if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, gstaichi.lang._ndarray.Ndarray):
994
+ set_arg_ndarray(indices, v)
995
+ return 1
996
+ if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, gstaichi.lang._texture.Texture):
997
+ set_arg_texture(indices, v)
998
+ return 1
999
+ if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
1000
+ v, gstaichi.lang._texture.Texture
1001
+ ):
1002
+ set_arg_rw_texture(indices, v)
1003
+ return 1
1004
+ if isinstance(needed_arg_type, ndarray_type.NdarrayType):
1005
+ set_arg_ext_array(indices, v, needed_arg_type)
1006
+ return 1
1007
+ if isinstance(needed_arg_type, MatrixType):
1008
+ set_arg_matrix(indices, v, needed_arg_type)
1009
+ return 1
1010
+ if isinstance(needed_arg_type, StructType):
1011
+ # Unclear how to make the following pass typing checks
1012
+ # StructType implements __instancecheck__, which should be a classmethod, but
1013
+ # is currently an instance method
1014
+ # TODO: look into this more deeply at some point
1015
+ if not isinstance(v, needed_arg_type): # type: ignore
1016
+ raise GsTaichiRuntimeTypeError(
1017
+ f"Argument {provided_arg_type} cannot be converted into required type {needed_arg_type}"
1018
+ )
1019
+ needed_arg_type.set_kernel_struct_args(v, launch_ctx, indices)
1020
+ return 1
1021
+ if needed_arg_type == template or isinstance(needed_arg_type, template):
1022
+ return 0
1023
+ raise ValueError(f"Argument type mismatch. Expecting {needed_arg_type}, got {type(v)}.")
1024
+
1025
+ template_num = 0
1026
+ i_out = 0
1027
+ for i_in, val in enumerate(args):
1028
+ needed_ = self.arg_metas[i_in].annotation
1029
+ if needed_ == template or isinstance(needed_, template):
1030
+ template_num += 1
1031
+ i_out += 1
1032
+ continue
1033
+ i_out += recursive_set_args(needed_, type(val), val, (i_out - template_num,))
1034
+
1035
+ for i, (set_arg_func, params) in enumerate(set_later_list):
1036
+ set_arg_func((len(args) - template_num + i,), *params)
1037
+
1038
+ if exceed_max_arg_num:
1039
+ raise GsTaichiRuntimeError(
1040
+ f"The number of elements in kernel arguments is too big! Do not exceed {max_arg_num} on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
1041
+ )
1042
+
1043
+ try:
1044
+ prog = impl.get_runtime().prog
1045
+ if not compiled_kernel_data:
1046
+ compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
1047
+ if self.fast_checksum:
1048
+ src_hasher.store(self.fast_checksum, self.visited_functions)
1049
+ prog.store_fast_cache(
1050
+ self.fast_checksum,
1051
+ self.kernel_cpp,
1052
+ prog.config(),
1053
+ prog.get_device_caps(),
1054
+ compiled_kernel_data,
1055
+ )
1056
+ self.src_ll_cache_observations.cache_stored = True
1057
+ prog.launch_kernel(compiled_kernel_data, launch_ctx)
1058
+ except Exception as e:
1059
+ e = handle_exception_from_cpp(e)
1060
+ if impl.get_runtime().print_full_traceback:
1061
+ raise e
1062
+ raise e from None
1063
+
1064
+ ret = None
1065
+ ret_dt = self.return_type
1066
+ has_ret = ret_dt is not None
1067
+
1068
+ if has_ret or self.has_print:
1069
+ runtime_ops.sync()
1070
+
1071
+ if has_ret:
1072
+ ret = []
1073
+ for i, ret_type in enumerate(ret_dt):
1074
+ ret.append(self.construct_kernel_ret(launch_ctx, ret_type, (i,)))
1075
+ if len(ret_dt) == 1:
1076
+ ret = ret[0]
1077
+ if callbacks:
1078
+ for c in callbacks:
1079
+ c()
1080
+
1081
+ return ret
1082
+
1083
+ def construct_kernel_ret(self, launch_ctx: KernelLaunchContext, ret_type: Any, index: tuple[int, ...] = ()):
1084
+ if isinstance(ret_type, CompoundType):
1085
+ return ret_type.from_kernel_struct_ret(launch_ctx, index)
1086
+ if ret_type in primitive_types.integer_types:
1087
+ if is_signed(cook_dtype(ret_type)):
1088
+ return launch_ctx.get_struct_ret_int(index)
1089
+ return launch_ctx.get_struct_ret_uint(index)
1090
+ if ret_type in primitive_types.real_types:
1091
+ return launch_ctx.get_struct_ret_float(index)
1092
+ raise GsTaichiRuntimeTypeError(f"Invalid return type on index={index}")
1093
+
1094
+ def ensure_compiled(self, *args: tuple[Any, ...]) -> tuple[Callable, int, AutodiffMode]:
1095
+ instance_id, arg_features = self.mapper.lookup(args)
1096
+ key = (self.func, instance_id, self.autodiff_mode)
1097
+ self.materialize(key=key, args=args, arg_features=arg_features)
1098
+ return key
1099
+
1100
+ # For small kernels (< 3us), the performance can be pretty sensitive to overhead in __call__
1101
+ # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU)
1102
+ @_shell_pop_print
1103
+ def __call__(self, *args, **kwargs) -> Any:
1104
+ args = _process_args(self, is_func=False, args=args, kwargs=kwargs)
1105
+
1106
+ # Transform the primal kernel to forward mode grad kernel
1107
+ # then recover to primal when exiting the forward mode manager
1108
+ if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
1109
+ # TODO: if we would like to compute 2nd-order derivatives by forward-on-reverse in a nested context manager fashion,
1110
+ # i.e., a `Tape` nested in the `FwdMode`, we can transform the kernels with `mode_original == AutodiffMode.REVERSE` only,
1111
+ # to avoid duplicate computation for 1st-order derivatives
1112
+ self.runtime.fwd_mode_manager.insert(self)
1113
+
1114
+ # Both the class kernels and the plain-function kernels are unified now.
1115
+ # In both cases, |self.grad| is another Kernel instance that computes the
1116
+ # gradient. For class kernels, args[0] is always the kernel owner.
1117
+
1118
+ # No need to capture grad kernels because they are already bound with their primal kernels
1119
+ if (
1120
+ self.autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION)
1121
+ and self.runtime.target_tape
1122
+ and not self.runtime.grad_replaced
1123
+ ):
1124
+ self.runtime.target_tape.insert(self, args)
1125
+
1126
+ if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg().opt_level == 0:
1127
+ _logging.warn("""opt_level = 1 is enforced to enable gradient computation.""")
1128
+ impl.current_cfg().opt_level = 1
1129
+ key = self.ensure_compiled(*args)
1130
+ kernel_cpp = self.materialized_kernels[key]
1131
+ compiled_kernel_data = self.compiled_kernel_data_by_key.get(key, None)
1132
+ return self.launch_kernel(kernel_cpp, compiled_kernel_data, *args)
1133
+
1134
+
1135
+ # For a GsTaichi class definition like below:
1136
+ #
1137
+ # @ti.data_oriented
1138
+ # class X:
1139
+ # @ti.kernel
1140
+ # def foo(self):
1141
+ # ...
1142
+ #
1143
+ # When ti.kernel runs, the stackframe's |code_context| of Python 3.8(+) is
1144
+ # different from that of Python 3.7 and below. In 3.8+, it is 'class X:',
1145
+ # whereas in <=3.7, it is '@ti.data_oriented'. More interestingly, if the class
1146
+ # inherits, i.e. class X(object):, then in both versions, |code_context| is
1147
+ # 'class X(object):'...
1148
+ _KERNEL_CLASS_STACKFRAME_STMT_RES = [
1149
+ re.compile(r"@(\w+\.)?data_oriented"),
1150
+ re.compile(r"class "),
1151
+ ]
1152
+
1153
+
1154
+ def _inside_class(level_of_class_stackframe: int) -> bool:
1155
+ try:
1156
+ maybe_class_frame = sys._getframe(level_of_class_stackframe)
1157
+ statement_list = inspect.getframeinfo(maybe_class_frame)[3]
1158
+ if statement_list is None:
1159
+ return False
1160
+ first_statment = statement_list[0].strip()
1161
+ for pat in _KERNEL_CLASS_STACKFRAME_STMT_RES:
1162
+ if pat.match(first_statment):
1163
+ return True
1164
+ except:
1165
+ pass
1166
+ return False
1167
+
1168
+
1169
+ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False) -> GsTaichiCallable:
1170
+ # Can decorators determine if a function is being defined inside a class?
1171
+ # https://stackoverflow.com/a/8793684/12003165
1172
+ is_classkernel = _inside_class(level_of_class_stackframe + 1)
1173
+
1174
+ if verbose:
1175
+ print(f"kernel={_func.__name__} is_classkernel={is_classkernel}")
1176
+ primal = Kernel(_func, autodiff_mode=AutodiffMode.NONE, _classkernel=is_classkernel)
1177
+ adjoint = Kernel(_func, autodiff_mode=AutodiffMode.REVERSE, _classkernel=is_classkernel)
1178
+ # Having |primal| contains |grad| makes the tape work.
1179
+ primal.grad = adjoint
1180
+
1181
+ wrapped: GsTaichiCallable
1182
+ if is_classkernel:
1183
+ # For class kernels, their primal/adjoint callables are constructed
1184
+ # when the kernel is accessed via the instance inside
1185
+ # _BoundedDifferentiableMethod.
1186
+ # This is because we need to bind the kernel or |grad| to the instance
1187
+ # owning the kernel, which is not known until the kernel is accessed.
1188
+ #
1189
+ # See also: _BoundedDifferentiableMethod, data_oriented.
1190
+ @functools.wraps(_func)
1191
+ def wrapped_classkernel(*args, **kwargs):
1192
+ # If we reach here (we should never), it means the class is not decorated
1193
+ # with @ti.data_oriented, otherwise getattr would have intercepted the call.
1194
+ clsobj = type(args[0])
1195
+ assert not hasattr(clsobj, "_data_oriented")
1196
+ raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
1197
+
1198
+ wrapped = GsTaichiCallable(_func, wrapped_classkernel)
1199
+ else:
1200
+
1201
+ @functools.wraps(_func)
1202
+ def wrapped_func(*args, **kwargs):
1203
+ try:
1204
+ return primal(*args, **kwargs)
1205
+ except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
1206
+ if impl.get_runtime().print_full_traceback:
1207
+ raise e
1208
+ raise type(e)("\n" + str(e)) from None
1209
+
1210
+ wrapped = GsTaichiCallable(_func, wrapped_func)
1211
+ wrapped.grad = adjoint
1212
+
1213
+ wrapped._is_wrapped_kernel = True
1214
+ wrapped._is_classkernel = is_classkernel
1215
+ wrapped._primal = primal
1216
+ wrapped._adjoint = adjoint
1217
+ primal.gstaichi_callable = wrapped
1218
+ return wrapped
1219
+
1220
+
1221
+ def kernel(fn: Callable):
1222
+ """Marks a function as a GsTaichi kernel.
1223
+
1224
+ A GsTaichi kernel is a function written in Python, and gets JIT compiled by
1225
+ GsTaichi into native CPU/GPU instructions (e.g. a series of CUDA kernels).
1226
+ The top-level ``for`` loops are automatically parallelized, and distributed
1227
+ to either a CPU thread pool or massively parallel GPUs.
1228
+
1229
+ Kernel's gradient kernel would be generated automatically by the AutoDiff system.
1230
+
1231
+ See also https://docs.taichi-lang.org/docs/syntax#kernel.
1232
+
1233
+ Args:
1234
+ fn (Callable): the Python function to be decorated
1235
+
1236
+ Returns:
1237
+ Callable: The decorated function
1238
+
1239
+ Example::
1240
+
1241
+ >>> x = ti.field(ti.i32, shape=(4, 8))
1242
+ >>>
1243
+ >>> @ti.kernel
1244
+ >>> def run():
1245
+ >>> # Assigns all the elements of `x` in parallel.
1246
+ >>> for i in x:
1247
+ >>> x[i] = i
1248
+ """
1249
+ return _kernel_impl(fn, level_of_class_stackframe=3)
1250
+
1251
+
1252
class _BoundedDifferentiableMethod:
    """A class kernel (and its gradient kernel) bound to the instance that owns it.

    Calling the object runs the primal kernel; ``grad`` runs the adjoint kernel.
    Unless the kernel was declared as a ``staticmethod``, the owning instance is
    passed as the first positional argument to both.
    """

    def __init__(self, kernel_owner: Any, wrapped_kernel_func: GsTaichiCallable | BoundGsTaichiCallable):
        """Bind ``wrapped_kernel_func`` to ``kernel_owner``.

        Args:
            kernel_owner: the instance through which the kernel was accessed.
            wrapped_kernel_func: the wrapped kernel; its ``_primal``,
                ``_adjoint`` and ``_is_staticmethod`` attributes are captured.

        Raises:
            GsTaichiSyntaxError: if the owner's class was not decorated with
                ``@ti.data_oriented``.
        """
        clsobj = type(kernel_owner)
        if not getattr(clsobj, "_data_oriented", False):
            raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
        self._kernel_owner = kernel_owner
        self._primal = wrapped_kernel_func._primal
        self._adjoint = wrapped_kernel_func._adjoint
        self._is_staticmethod = wrapped_kernel_func._is_staticmethod
        # Set by the caller after construction (data_oriented copies the wrapped kernel's name).
        self.__name__: str | None = None

    def _bind_args(self, args: tuple) -> tuple:
        """Return ``args`` with the owner prepended, unless the kernel is a staticmethod."""
        if self._is_staticmethod:
            return args
        return (self._kernel_owner, *args)

    def __call__(self, *args, **kwargs):
        """Run the primal kernel.

        Compilation/runtime errors are re-raised with the traceback suppressed,
        unless the runtime's ``print_full_traceback`` flag is set.
        """
        try:
            assert self._primal is not None
            return self._primal(*self._bind_args(args), **kwargs)

        except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
            if impl.get_runtime().print_full_traceback:
                raise e
            raise type(e)("\n" + str(e)) from None

    def grad(self, *args, **kwargs) -> Any:
        """Run the adjoint (gradient) kernel and return its result.

        Arguments are bound exactly as in ``__call__``: staticmethod kernels do
        not receive the owning instance.
        """
        assert self._adjoint is not None
        return self._adjoint(*self._bind_args(args), **kwargs)
1278
+
1279
+
1280
def data_oriented(cls):
    """Make a class able to host GsTaichi kernels as methods.

    GsTaichi kernels may be defined inside a class decorated with this
    function. The decorator installs a ``__getattribute__`` hook that binds
    class kernels to the instance they are accessed through. It also marks
    the class with ``_data_oriented = True``.

    See also https://docs.taichi-lang.org/docs/odop

    Example::

        >>> @ti.data_oriented
        >>> class TiArray:
        >>>     def __init__(self, n):
        >>>         self.x = ti.field(ti.f32, shape=n)
        >>>
        >>>     @ti.kernel
        >>>     def inc(self):
        >>>         for i in self.x:
        >>>             self.x[i] += 1.0
        >>>
        >>> a = TiArray(32)
        >>> a.inc()

    Args:
        cls (Class): the class to be decorated

    Returns:
        The decorated class.
    """

    def _getattr(self, item):
        # Only the decorated class's own namespace is inspected to detect
        # properties and staticmethods declared on it.
        declared = cls.__dict__.get(item, None)
        is_property = declared.__class__ == property
        is_staticmethod = declared.__class__ == staticmethod

        # For properties, work with the raw getter so a kernel used as a getter
        # can be bound to `self` below; otherwise use normal attribute lookup.
        target = declared.fget if is_property else super(cls, self).__getattribute__(item)

        if hasattr(target, "_is_wrapped_kernel"):
            wrapped = target.__func__ if inspect.ismethod(target) else target
            assert isinstance(wrapped, (BoundGsTaichiCallable, GsTaichiCallable))
            wrapped._is_staticmethod = is_staticmethod
            if wrapped._is_classkernel:
                bound = _BoundedDifferentiableMethod(self, wrapped)
                bound.__name__ = wrapped.__name__  # type: ignore
                return bound() if is_property else bound

        return target(self) if is_property else target

    cls.__getattribute__ = _getattr
    cls._data_oriented = True

    return cls
1339
+
1340
+
1341
# Public names exported by this module.
__all__ = [
    "data_oriented",
    "func",
    "kernel",
    "pyfunc",
    "real_func",
]