gstaichi 0.1.23.dev0__cp310-cp310-win_amd64.whl → 0.1.25.dev0__cp310-cp310-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. gstaichi/CHANGELOG.md +9 -0
  2. {taichi → gstaichi}/__init__.py +9 -13
  3. {taichi → gstaichi}/_funcs.py +8 -8
  4. {taichi → gstaichi}/_kernels.py +19 -19
  5. gstaichi/_lib/__init__.py +3 -0
  6. taichi/_lib/core/taichi_python.cp310-win_amd64.pyd → gstaichi/_lib/core/gstaichi_python.cp310-win_amd64.pyd +0 -0
  7. taichi/_lib/core/taichi_python.pyi → gstaichi/_lib/core/gstaichi_python.pyi +382 -522
  8. {taichi → gstaichi}/_lib/runtime/runtime_cuda.bc +0 -0
  9. {taichi → gstaichi}/_lib/runtime/runtime_x64.bc +0 -0
  10. {taichi → gstaichi}/_lib/utils.py +15 -15
  11. {taichi → gstaichi}/_logging.py +1 -1
  12. {taichi → gstaichi}/_main.py +24 -31
  13. gstaichi/_snode/__init__.py +5 -0
  14. {taichi → gstaichi}/_snode/fields_builder.py +27 -29
  15. {taichi → gstaichi}/_snode/snode_tree.py +5 -5
  16. gstaichi/_test_tools/__init__.py +0 -0
  17. gstaichi/_test_tools/load_kernel_string.py +30 -0
  18. gstaichi/_version.py +1 -0
  19. {taichi → gstaichi}/_version_check.py +8 -5
  20. gstaichi/ad/__init__.py +3 -0
  21. {taichi → gstaichi}/ad/_ad.py +26 -26
  22. {taichi → gstaichi}/algorithms/_algorithms.py +7 -7
  23. {taichi → gstaichi}/examples/minimal.py +1 -1
  24. {taichi → gstaichi}/experimental.py +1 -1
  25. gstaichi/lang/__init__.py +50 -0
  26. {taichi → gstaichi}/lang/_ndarray.py +30 -26
  27. {taichi → gstaichi}/lang/_ndrange.py +8 -8
  28. gstaichi/lang/_template_mapper.py +199 -0
  29. {taichi → gstaichi}/lang/_texture.py +19 -19
  30. {taichi → gstaichi}/lang/_wrap_inspect.py +7 -7
  31. {taichi → gstaichi}/lang/any_array.py +13 -13
  32. {taichi → gstaichi}/lang/argpack.py +29 -29
  33. gstaichi/lang/ast/__init__.py +5 -0
  34. {taichi → gstaichi}/lang/ast/ast_transformer.py +94 -582
  35. {taichi → gstaichi}/lang/ast/ast_transformer_utils.py +54 -41
  36. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  37. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  38. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  39. {taichi → gstaichi}/lang/ast/checkers.py +5 -5
  40. gstaichi/lang/ast/transform.py +9 -0
  41. {taichi → gstaichi}/lang/common_ops.py +12 -12
  42. gstaichi/lang/exception.py +80 -0
  43. {taichi → gstaichi}/lang/expr.py +22 -22
  44. {taichi → gstaichi}/lang/field.py +29 -27
  45. {taichi → gstaichi}/lang/impl.py +116 -121
  46. {taichi → gstaichi}/lang/kernel_arguments.py +16 -16
  47. {taichi → gstaichi}/lang/kernel_impl.py +330 -363
  48. {taichi → gstaichi}/lang/matrix.py +119 -115
  49. {taichi → gstaichi}/lang/matrix_ops.py +6 -6
  50. {taichi → gstaichi}/lang/matrix_ops_utils.py +4 -4
  51. {taichi → gstaichi}/lang/mesh.py +22 -22
  52. {taichi → gstaichi}/lang/misc.py +39 -68
  53. {taichi → gstaichi}/lang/ops.py +146 -141
  54. {taichi → gstaichi}/lang/runtime_ops.py +2 -2
  55. {taichi → gstaichi}/lang/shell.py +3 -3
  56. {taichi → gstaichi}/lang/simt/__init__.py +1 -1
  57. {taichi → gstaichi}/lang/simt/block.py +7 -7
  58. {taichi → gstaichi}/lang/simt/grid.py +1 -1
  59. {taichi → gstaichi}/lang/simt/subgroup.py +1 -1
  60. {taichi → gstaichi}/lang/simt/warp.py +1 -1
  61. {taichi → gstaichi}/lang/snode.py +46 -44
  62. {taichi → gstaichi}/lang/source_builder.py +13 -13
  63. {taichi → gstaichi}/lang/struct.py +33 -33
  64. {taichi → gstaichi}/lang/util.py +24 -24
  65. gstaichi/linalg/__init__.py +8 -0
  66. {taichi → gstaichi}/linalg/matrixfree_cg.py +14 -14
  67. {taichi → gstaichi}/linalg/sparse_cg.py +10 -10
  68. {taichi → gstaichi}/linalg/sparse_matrix.py +23 -23
  69. {taichi → gstaichi}/linalg/sparse_solver.py +21 -21
  70. {taichi → gstaichi}/math/__init__.py +1 -1
  71. {taichi → gstaichi}/math/_complex.py +21 -20
  72. {taichi → gstaichi}/math/mathimpl.py +56 -56
  73. gstaichi/profiler/__init__.py +6 -0
  74. {taichi → gstaichi}/profiler/kernel_metrics.py +11 -11
  75. {taichi → gstaichi}/profiler/kernel_profiler.py +30 -36
  76. {taichi → gstaichi}/profiler/memory_profiler.py +1 -1
  77. {taichi → gstaichi}/profiler/scoped_profiler.py +2 -2
  78. {taichi → gstaichi}/sparse/_sparse_grid.py +7 -7
  79. {taichi → gstaichi}/tools/__init__.py +4 -4
  80. {taichi → gstaichi}/tools/diagnose.py +10 -17
  81. gstaichi/types/__init__.py +19 -0
  82. {taichi → gstaichi}/types/annotations.py +1 -1
  83. {taichi → gstaichi}/types/compound_types.py +8 -8
  84. {taichi → gstaichi}/types/enums.py +1 -1
  85. {taichi → gstaichi}/types/ndarray_type.py +7 -7
  86. {taichi → gstaichi}/types/primitive_types.py +17 -14
  87. {taichi → gstaichi}/types/quant.py +9 -9
  88. {taichi → gstaichi}/types/texture_type.py +5 -5
  89. {taichi → gstaichi}/types/utils.py +1 -1
  90. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
  91. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
  92. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-link.lib +0 -0
  93. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
  94. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
  95. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
  96. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
  97. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/lib/SPIRV-Tools.lib +0 -0
  98. {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/METADATA +13 -16
  99. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  100. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  101. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
  102. gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3.h +0 -6389
  103. gstaichi-0.1.23.dev0.data/data/include/GLFW/glfw3native.h +0 -594
  104. gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Config.cmake +0 -3
  105. gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +0 -65
  106. gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +0 -19
  107. gstaichi-0.1.23.dev0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +0 -107
  108. gstaichi-0.1.23.dev0.data/data/lib/glfw3.lib +0 -0
  109. gstaichi-0.1.23.dev0.dist-info/RECORD +0 -198
  110. gstaichi-0.1.23.dev0.dist-info/entry_points.txt +0 -2
  111. gstaichi-0.1.23.dev0.dist-info/top_level.txt +0 -1
  112. taichi/CHANGELOG.md +0 -20
  113. taichi/_lib/__init__.py +0 -3
  114. taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
  115. taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +0 -1401
  116. taichi/_lib/c_api/include/taichi/taichi.h +0 -29
  117. taichi/_lib/c_api/include/taichi/taichi_core.h +0 -1111
  118. taichi/_lib/c_api/include/taichi/taichi_cpu.h +0 -29
  119. taichi/_lib/c_api/include/taichi/taichi_cuda.h +0 -36
  120. taichi/_lib/c_api/include/taichi/taichi_platform.h +0 -55
  121. taichi/_lib/c_api/include/taichi/taichi_unity.h +0 -64
  122. taichi/_lib/c_api/include/taichi/taichi_vulkan.h +0 -151
  123. taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
  124. taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
  125. taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
  126. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +0 -29
  127. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +0 -65
  128. taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +0 -121
  129. taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  130. taichi/_snode/__init__.py +0 -5
  131. taichi/_ti_module/__init__.py +0 -3
  132. taichi/_ti_module/cppgen.py +0 -309
  133. taichi/_ti_module/module.py +0 -145
  134. taichi/_version.py +0 -1
  135. taichi/ad/__init__.py +0 -3
  136. taichi/aot/__init__.py +0 -12
  137. taichi/aot/_export.py +0 -28
  138. taichi/aot/conventions/__init__.py +0 -3
  139. taichi/aot/conventions/gfxruntime140/__init__.py +0 -38
  140. taichi/aot/conventions/gfxruntime140/dr.py +0 -244
  141. taichi/aot/conventions/gfxruntime140/sr.py +0 -613
  142. taichi/aot/module.py +0 -253
  143. taichi/aot/utils.py +0 -151
  144. taichi/graph/__init__.py +0 -3
  145. taichi/graph/_graph.py +0 -292
  146. taichi/lang/__init__.py +0 -50
  147. taichi/lang/ast/__init__.py +0 -5
  148. taichi/lang/ast/transform.py +0 -9
  149. taichi/lang/exception.py +0 -80
  150. taichi/linalg/__init__.py +0 -8
  151. taichi/profiler/__init__.py +0 -6
  152. taichi/shaders/Circles_vk.frag +0 -29
  153. taichi/shaders/Circles_vk.vert +0 -45
  154. taichi/shaders/Circles_vk_frag.spv +0 -0
  155. taichi/shaders/Circles_vk_vert.spv +0 -0
  156. taichi/shaders/Lines_vk.frag +0 -9
  157. taichi/shaders/Lines_vk.vert +0 -11
  158. taichi/shaders/Lines_vk_frag.spv +0 -0
  159. taichi/shaders/Lines_vk_vert.spv +0 -0
  160. taichi/shaders/Mesh_vk.frag +0 -71
  161. taichi/shaders/Mesh_vk.vert +0 -68
  162. taichi/shaders/Mesh_vk_frag.spv +0 -0
  163. taichi/shaders/Mesh_vk_vert.spv +0 -0
  164. taichi/shaders/Particles_vk.frag +0 -95
  165. taichi/shaders/Particles_vk.vert +0 -73
  166. taichi/shaders/Particles_vk_frag.spv +0 -0
  167. taichi/shaders/Particles_vk_vert.spv +0 -0
  168. taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
  169. taichi/shaders/SceneLines_vk.frag +0 -9
  170. taichi/shaders/SceneLines_vk.vert +0 -12
  171. taichi/shaders/SceneLines_vk_frag.spv +0 -0
  172. taichi/shaders/SceneLines_vk_vert.spv +0 -0
  173. taichi/shaders/SetImage_vk.frag +0 -21
  174. taichi/shaders/SetImage_vk.vert +0 -15
  175. taichi/shaders/SetImage_vk_frag.spv +0 -0
  176. taichi/shaders/SetImage_vk_vert.spv +0 -0
  177. taichi/shaders/Triangles_vk.frag +0 -16
  178. taichi/shaders/Triangles_vk.vert +0 -29
  179. taichi/shaders/Triangles_vk_frag.spv +0 -0
  180. taichi/shaders/Triangles_vk_vert.spv +0 -0
  181. taichi/shaders/lines2quad_vk_comp.spv +0 -0
  182. taichi/types/__init__.py +0 -19
  183. {taichi → gstaichi}/__main__.py +0 -0
  184. {taichi → gstaichi}/_lib/core/__init__.py +0 -0
  185. {taichi → gstaichi}/_lib/core/py.typed +0 -0
  186. {taichi/_lib/c_api → gstaichi/_lib}/runtime/slim_libdevice.10.bc +0 -0
  187. {taichi → gstaichi}/algorithms/__init__.py +0 -0
  188. {taichi → gstaichi}/assets/.git +0 -0
  189. {taichi → gstaichi}/assets/Go-Regular.ttf +0 -0
  190. {taichi → gstaichi}/assets/static/imgs/ti_gallery.png +0 -0
  191. {taichi → gstaichi}/lang/ast/symbol_resolver.py +0 -0
  192. {taichi → gstaichi}/sparse/__init__.py +0 -0
  193. {taichi → gstaichi}/tools/np2ply.py +0 -0
  194. {taichi → gstaichi}/tools/vtk.py +0 -0
  195. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
  196. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
  197. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
  198. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
  199. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
  200. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
  201. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
  202. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
  203. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
  204. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
  205. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
  206. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
  207. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
  208. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
  209. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
  210. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
  211. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  212. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
  213. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/instrument.hpp +0 -0
  214. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.h +0 -0
  215. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  216. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/linker.hpp +0 -0
  217. {gstaichi-0.1.23.dev0.data → gstaichi-0.1.25.dev0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  218. {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/WHEEL +0 -0
  219. {gstaichi-0.1.23.dev0.dist-info → gstaichi-0.1.25.dev0.dist-info}/licenses/LICENSE +0 -0
@@ -1,5 +1,3 @@
1
- # type: ignore
2
-
3
1
  import ast
4
2
  import dataclasses
5
3
  import functools
@@ -15,60 +13,181 @@ import time
15
13
  import types
16
14
  import typing
17
15
  import warnings
18
- import weakref
19
- from typing import Any, Callable, Type, Union
16
+ from typing import Any, Callable, Type
20
17
 
21
18
  import numpy as np
22
19
 
23
- import taichi.lang
24
- import taichi.lang._ndarray
25
- import taichi.lang._texture
26
- import taichi.lang.expr
27
- import taichi.lang.snode
28
- import taichi.types.annotations
29
- from taichi import _logging
30
- from taichi._lib import core as _ti_core
31
- from taichi._lib.core.taichi_python import ASTBuilder
32
- from taichi.lang import impl, ops, runtime_ops
33
- from taichi.lang._wrap_inspect import getsourcefile, getsourcelines
34
- from taichi.lang.any_array import AnyArray
35
- from taichi.lang.argpack import ArgPack, ArgPackType
36
- from taichi.lang.ast import (
20
+ import gstaichi.lang
21
+ import gstaichi.lang._ndarray
22
+ import gstaichi.lang._texture
23
+ import gstaichi.types.annotations
24
+ from gstaichi import _logging
25
+ from gstaichi._lib import core as _ti_core
26
+ from gstaichi._lib.core.gstaichi_python import (
27
+ ASTBuilder,
28
+ FunctionKey,
29
+ KernelCxx,
30
+ KernelLaunchContext,
31
+ )
32
+ from gstaichi.lang import impl, ops, runtime_ops
33
+ from gstaichi.lang._template_mapper import GsTaichiCallableTemplateMapper
34
+ from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
35
+ from gstaichi.lang.any_array import AnyArray
36
+ from gstaichi.lang.argpack import ArgPack, ArgPackType
37
+ from gstaichi.lang.ast import (
37
38
  ASTTransformerContext,
38
39
  KernelSimplicityASTChecker,
39
40
  transform_tree,
40
41
  )
41
- from taichi.lang.ast.ast_transformer_utils import ReturnStatus
42
- from taichi.lang.exception import (
43
- TaichiCompilationError,
44
- TaichiRuntimeError,
45
- TaichiRuntimeTypeError,
46
- TaichiSyntaxError,
47
- TaichiTypeError,
42
+ from gstaichi.lang.ast.ast_transformer_utils import ReturnStatus
43
+ from gstaichi.lang.exception import (
44
+ GsTaichiCompilationError,
45
+ GsTaichiRuntimeError,
46
+ GsTaichiRuntimeTypeError,
47
+ GsTaichiSyntaxError,
48
+ GsTaichiTypeError,
48
49
  handle_exception_from_cpp,
49
50
  )
50
- from taichi.lang.expr import Expr
51
- from taichi.lang.kernel_arguments import KernelArgument
52
- from taichi.lang.matrix import MatrixType
53
- from taichi.lang.shell import _shell_pop_print
54
- from taichi.lang.struct import StructType
55
- from taichi.lang.util import cook_dtype, has_paddle, has_pytorch, to_taichi_type
56
- from taichi.types import (
51
+ from gstaichi.lang.expr import Expr
52
+ from gstaichi.lang.kernel_arguments import KernelArgument
53
+ from gstaichi.lang.matrix import MatrixType
54
+ from gstaichi.lang.shell import _shell_pop_print
55
+ from gstaichi.lang.struct import StructType
56
+ from gstaichi.lang.util import cook_dtype, has_paddle, has_pytorch
57
+ from gstaichi.types import (
57
58
  ndarray_type,
58
59
  primitive_types,
59
60
  sparse_matrix_builder,
60
61
  template,
61
62
  texture_type,
62
63
  )
63
- from taichi.types.compound_types import CompoundType
64
- from taichi.types.enums import AutodiffMode, Layout
65
- from taichi.types.utils import is_signed
64
+ from gstaichi.types.compound_types import CompoundType
65
+ from gstaichi.types.enums import AutodiffMode, Layout
66
+ from gstaichi.types.utils import is_signed
67
+
68
+ CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]
66
69
 
67
70
 
68
- def func(fn: Callable, is_real_function: bool = False):
69
- """Marks a function as callable in Taichi-scope.
71
+ class GsTaichiCallable:
72
+ """
73
+ BoundGsTaichiCallable is used to enable wrapping a bindable function with a class.
74
+
75
+ Design requirements for GsTaichiCallable:
76
+ - wrap/contain a reference to a class Func instance, and allow (the GsTaichiCallable) being passed around
77
+ like normal function pointer
78
+ - expose attributes of the wrapped class Func, such as `_if_real_function`, `_primal`, etc
79
+ - allow for (now limited) strong typing, and enable type checkers, such as pyright/mypy
80
+ - currently GsTaichiCallable is a shared type used for all functions marked with @ti.func, @ti.kernel,
81
+ python functions (?)
82
+ - note: current type-checking implementation does not distinguish between different type flavors of
83
+ GsTaichiCallable, with different values of `_if_real_function`, `_primal`, etc
84
+ - handle not only class-less functions, but also class-instance methods (where determining the `self`
85
+ reference is a challenge)
86
+
87
+ Let's take the following example:
88
+
89
+ def test_ptr_class_func():
90
+ @ti.data_oriented
91
+ class MyClass:
92
+ def __init__(self):
93
+ self.a = ti.field(dtype=ti.f32, shape=(3))
94
+
95
+ def add2numbers_py(self, x, y):
96
+ return x + y
97
+
98
+ @ti.func
99
+ def add2numbers_func(self, x, y):
100
+ return x + y
101
+
102
+ @ti.kernel
103
+ def func(self):
104
+ a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)
105
+ a[0] = add_py(2, 3)
106
+ a[1] = add_func(3, 7)
107
+
108
+ (taken from test_ptr_assign.py).
109
+
110
+ When the @ti.func decorator is parsed, the function `add2numbers_func` exists, but there is not yet any `self`
111
+ - it is not possible for the method to be bound, to a `self` instance
112
+ - however, the @ti.func annotation, runs the kernel_imp.py::func function --- it is at this point
113
+ that GsTaichi's original code creates a class Func instance (that wraps the add2numbers_func)
114
+ and immediately we create a GsTaichiCallable instance that wraps the Func instance.
115
+ - effectively, we have two layers of wrapping GsTaichiCallable->Func->function pointer
116
+ (actual function definition)
117
+ - later on, when we call self.add2numbers_py, here:
118
+
119
+ a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)
120
+
121
+ ... we want to call the bound method, `self.add2numbers_py`.
122
+ - an actual python function reference, created by doing somevar = MyClass.add2numbers, can automatically
123
+ binds to self, when called from self in this way (however, add2numbers_py is actually a class
124
+ Func instance, wrapping python function reference -- now also all wrapped by a GsTaichiCallable
125
+ instance -- returned by the kernel_impl.py::func function, run by @ti.func)
126
+ - however, in order to be able to add strongly typed attributes to the wrapped python function, we need
127
+ to wrap the wrapped python function in a class
128
+ - the wrapped python function, wrapped in a GsTaichiCallable class (which is callable, and will
129
+ execute the underlying double-wrapped python function), will NOT automatically bind
130
+ - when we invoke GsTaichiCallable, the wrapped function is invoked. The wrapped function is unbound, and
131
+ so `self` is not automatically passed in, as an argument, and things break
132
+
133
+ To address this we need to use the `__get__` method, in our function wrapper, ie GsTaichiCallable,
134
+ and have the `__get__` method return the `BoundGsTaichiCallable` object. The `__get__` method handles
135
+ running the binding for us, and effectively binds `BoundFunc` object to `self` object, by passing
136
+ in the instance, as an argument into `BoundGsTaichiCallable.__init__`.
137
+
138
+ `BoundFunc` can then be used as a normal bound func - even though it's just an object instance -
139
+ using its `__call__` method. Effectively, at the time of actually invoking the underlying python
140
+ function, we have 3 layers of wrapper instances:
141
+ BoundGsTaichiCallabe -> GsTaichiCallable -> Func -> python function reference/definition
142
+ """
143
+
144
+ def __init__(self, fn: Callable, wrapper: Callable) -> None:
145
+ self.fn: Callable = fn
146
+ self.wrapper: Callable = wrapper
147
+ self._is_real_function: bool = False
148
+ self._is_gstaichi_function: bool = False
149
+ self._is_wrapped_kernel: bool = False
150
+ self._is_classkernel: bool = False
151
+ self._primal: Kernel | None = None
152
+ self._adjoint: Kernel | None = None
153
+ self.grad: Kernel | None = None
154
+ self._is_staticmethod: bool = False
155
+ functools.update_wrapper(self, fn)
156
+
157
+ def __call__(self, *args, **kwargs):
158
+ return self.wrapper.__call__(*args, **kwargs)
159
+
160
+ def __get__(self, instance, owner):
161
+ if instance is None:
162
+ return self
163
+ return BoundGsTaichiCallable(instance, self)
164
+
165
+
166
+ class BoundGsTaichiCallable:
167
+ def __init__(self, instance: Any, gstaichi_callable: "GsTaichiCallable"):
168
+ self.wrapper = gstaichi_callable.wrapper
169
+ self.instance = instance
170
+ self.gstaichi_callable = gstaichi_callable
171
+
172
+ def __call__(self, *args, **kwargs):
173
+ return self.wrapper(self.instance, *args, **kwargs)
174
+
175
+ def __getattr__(self, k: str) -> Any:
176
+ res = getattr(self.gstaichi_callable, k)
177
+ return res
178
+
179
+ def __setattr__(self, k: str, v: Any) -> None:
180
+ # Note: these have to match the name of any attributes on this class.
181
+ if k in ("wrapper", "instance", "gstaichi_callable"):
182
+ object.__setattr__(self, k, v)
183
+ else:
184
+ setattr(self.gstaichi_callable, k, v)
70
185
 
71
- This decorator transforms a Python function into a Taichi one. Taichi
186
+
187
+ def func(fn: Callable, is_real_function: bool = False) -> GsTaichiCallable:
188
+ """Marks a function as callable in GsTaichi-scope.
189
+
190
+ This decorator transforms a Python function into a GsTaichi one. GsTaichi
72
191
  will JIT compile it into native instructions.
73
192
 
74
193
  Args:
@@ -91,29 +210,24 @@ def func(fn: Callable, is_real_function: bool = False):
91
210
  is_classfunc = _inside_class(level_of_class_stackframe=3 + is_real_function)
92
211
 
93
212
  fun = Func(fn, _classfunc=is_classfunc, is_real_function=is_real_function)
94
-
95
- @functools.wraps(fn)
96
- def decorated(*args, **kwargs):
97
- return fun.__call__(*args, **kwargs)
98
-
99
- decorated._is_taichi_function = True
100
- decorated._is_real_function = is_real_function
101
- decorated.func = fun
102
- return decorated
213
+ gstaichi_callable = GsTaichiCallable(fn, fun)
214
+ gstaichi_callable._is_gstaichi_function = True
215
+ gstaichi_callable._is_real_function = is_real_function
216
+ return gstaichi_callable
103
217
 
104
218
 
105
- def real_func(fn: Callable):
219
+ def real_func(fn: Callable) -> GsTaichiCallable:
106
220
  return func(fn, is_real_function=True)
107
221
 
108
222
 
109
- def pyfunc(fn: Callable):
110
- """Marks a function as callable in both Taichi and Python scopes.
223
+ def pyfunc(fn: Callable) -> GsTaichiCallable:
224
+ """Marks a function as callable in both GsTaichi and Python scopes.
111
225
 
112
- When called inside the Taichi scope, Taichi will JIT compile it into
226
+ When called inside the GsTaichi scope, GsTaichi will JIT compile it into
113
227
  native instructions. Otherwise it will be invoked directly as a
114
228
  Python function.
115
229
 
116
- See also :func:`~taichi.lang.kernel_impl.func`.
230
+ See also :func:`~gstaichi.lang.kernel_impl.func`.
117
231
 
118
232
  Args:
119
233
  fn (Callable): The Python function to be decorated
@@ -123,33 +237,28 @@ def pyfunc(fn: Callable):
123
237
  """
124
238
  is_classfunc = _inside_class(level_of_class_stackframe=3)
125
239
  fun = Func(fn, _classfunc=is_classfunc, _pyfunc=True)
126
-
127
- @functools.wraps(fn)
128
- def decorated(*args, **kwargs):
129
- return fun.__call__(*args, **kwargs)
130
-
131
- decorated._is_taichi_function = True
132
- decorated._is_real_function = False
133
- decorated.func = fun
134
- return decorated
240
+ gstaichi_callable = GsTaichiCallable(fn, fun)
241
+ gstaichi_callable._is_gstaichi_function = True
242
+ gstaichi_callable._is_real_function = False
243
+ return gstaichi_callable
135
244
 
136
245
 
137
246
  def _get_tree_and_ctx(
138
247
  self: "Func | Kernel",
248
+ args: tuple[Any, ...],
139
249
  excluded_parameters=(),
140
250
  is_kernel: bool = True,
141
251
  arg_features=None,
142
- args=None,
143
252
  ast_builder: ASTBuilder | None = None,
144
253
  is_real_function: bool = False,
145
- ):
254
+ ) -> tuple[ast.Module, ASTTransformerContext]:
146
255
  file = getsourcefile(self.func)
147
256
  src, start_lineno = getsourcelines(self.func)
148
257
  src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
149
258
  tree = ast.parse(textwrap.dedent("\n".join(src)))
150
259
 
151
260
  func_body = tree.body[0]
152
- func_body.decorator_list = []
261
+ func_body.decorator_list = [] # type: ignore , kick that can down the road...
153
262
 
154
263
  global_vars = _get_global_vars(self.func)
155
264
 
@@ -196,17 +305,18 @@ def expand_func_arguments(arguments: list[KernelArgument]) -> list[KernelArgumen
196
305
  return new_arguments
197
306
 
198
307
 
199
- def _process_args(self: "Func | Kernel", is_func: bool, args, kwargs):
308
+ def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
200
309
  if is_func:
201
310
  self.arguments = expand_func_arguments(self.arguments)
202
311
  fused_args = [argument.default for argument in self.arguments]
312
+ ret: list[Any] = [argument.default for argument in self.arguments]
203
313
  len_args = len(args)
204
314
 
205
315
  if len_args > len(fused_args):
206
316
  arg_str = ", ".join([str(arg) for arg in args])
207
317
  expected_str = ", ".join([f"{arg.name} : {arg.annotation}" for arg in self.arguments])
208
318
  msg = f"Too many arguments. Expected ({expected_str}), got ({arg_str})."
209
- raise TaichiSyntaxError(msg)
319
+ raise GsTaichiSyntaxError(msg)
210
320
 
211
321
  for i, arg in enumerate(args):
212
322
  fused_args[i] = arg
@@ -216,19 +326,19 @@ def _process_args(self: "Func | Kernel", is_func: bool, args, kwargs):
216
326
  for i, arg in enumerate(self.arguments):
217
327
  if key == arg.name:
218
328
  if i < len_args:
219
- raise TaichiSyntaxError(f"Multiple values for argument '{key}'.")
329
+ raise GsTaichiSyntaxError(f"Multiple values for argument '{key}'.")
220
330
  fused_args[i] = value
221
331
  found = True
222
332
  break
223
333
  if not found:
224
- raise TaichiSyntaxError(f"Unexpected argument '{key}'.")
334
+ raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")
225
335
 
226
336
  for i, arg in enumerate(fused_args):
227
337
  if arg is inspect.Parameter.empty:
228
338
  if self.arguments[i].annotation is inspect._empty:
229
- raise TaichiSyntaxError(f"Parameter `{self.arguments[i].name}` missing.")
339
+ raise GsTaichiSyntaxError(f"Parameter `{self.arguments[i].name}` missing.")
230
340
  else:
231
- raise TaichiSyntaxError(
341
+ raise GsTaichiSyntaxError(
232
342
  f"Parameter `{self.arguments[i].name} : {self.arguments[i].annotation}` missing."
233
343
  )
234
344
 
@@ -237,7 +347,7 @@ def _process_args(self: "Func | Kernel", is_func: bool, args, kwargs):
237
347
 
238
348
  def unpack_ndarray_struct(tree: ast.Module, struct_locals: set[str]) -> ast.Module:
239
349
  class AttributeToNameTransformer(ast.NodeTransformer):
240
- def visit_Attribute(self, node: ast.AST):
350
+ def visit_Attribute(self, node: ast.Attribute):
241
351
  if isinstance(node.value, ast.Attribute):
242
352
  return node
243
353
  if not isinstance(node.value, ast.Name):
@@ -278,7 +388,7 @@ def extract_struct_locals_from_context(ctx: ASTTransformerContext):
278
388
  class Func:
279
389
  function_counter = 0
280
390
 
281
- def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False):
391
+ def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False) -> None:
282
392
  self.func = _func
283
393
  self.func_id = Func.function_counter
284
394
  Func.function_counter += 1
@@ -294,22 +404,22 @@ class Func:
294
404
  for i, arg in enumerate(self.arguments):
295
405
  if arg.annotation == template or isinstance(arg.annotation, template):
296
406
  self.template_slot_locations.append(i)
297
- self.mapper = TaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
298
- self.taichi_functions = {} # The |Function| class in C++
407
+ self.mapper = GsTaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
408
+ self.gstaichi_functions = {} # The |Function| class in C++
299
409
  self.has_print = False
300
410
 
301
- def __call__(self, *args, **kwargs):
411
+ def __call__(self, *args, **kwargs) -> Any:
302
412
  args = _process_args(self, is_func=True, args=args, kwargs=kwargs)
303
413
 
304
414
  if not impl.inside_kernel():
305
415
  if not self.pyfunc:
306
- raise TaichiSyntaxError("Taichi functions cannot be called from Python-scope.")
416
+ raise GsTaichiSyntaxError("GsTaichi functions cannot be called from Python-scope.")
307
417
  return self.func(*args)
308
418
 
309
419
  current_kernel = impl.get_runtime().current_kernel
310
420
  if self.is_real_function:
311
421
  if current_kernel.autodiff_mode != AutodiffMode.NONE:
312
- raise TaichiSyntaxError("Real function in gradient kernels unsupported.")
422
+ raise GsTaichiSyntaxError("Real function in gradient kernels unsupported.")
313
423
  instance_id, arg_features = self.mapper.lookup(args)
314
424
  key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
315
425
  if key.instance_id not in self.compiled:
@@ -328,10 +438,10 @@ class Func:
328
438
  ret = transform_tree(tree, ctx)
329
439
  if not self.is_real_function:
330
440
  if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
331
- raise TaichiSyntaxError("Function has a return type but does not have a return statement")
441
+ raise GsTaichiSyntaxError("Function has a return type but does not have a return statement")
332
442
  return ret
333
443
 
334
- def func_call_rvalue(self, key, args):
444
+ def func_call_rvalue(self, key: FunctionKey, args: tuple[Any, ...]) -> Any:
335
445
  # Skip the template args, e.g., |self|
336
446
  assert self.is_real_function
337
447
  non_template_args = []
@@ -345,7 +455,7 @@ class Func:
345
455
  non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
346
456
  elif isinstance(anno, ndarray_type.NdarrayType):
347
457
  if not isinstance(args[i], AnyArray):
348
- raise TaichiTypeError(
458
+ raise GsTaichiTypeError(
349
459
  f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
350
460
  )
351
461
  non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
@@ -355,7 +465,7 @@ class Func:
355
465
  compiling_callable = impl.get_runtime().compiling_callable
356
466
  assert compiling_callable is not None
357
467
  func_call = compiling_callable.ast_builder().insert_func_call(
358
- self.taichi_functions[key.instance_id], non_template_args, dbg_info
468
+ self.gstaichi_functions[key.instance_id], non_template_args, dbg_info
359
469
  )
360
470
  if self.return_type is None:
361
471
  return None
@@ -372,14 +482,14 @@ class Func:
372
482
  )
373
483
  )
374
484
  elif isinstance(return_type, (StructType, MatrixType)):
375
- ret.append(return_type.from_taichi_object(func_call, (i,)))
485
+ ret.append(return_type.from_gstaichi_object(func_call, (i,)))
376
486
  else:
377
- raise TaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
487
+ raise GsTaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
378
488
  if len(ret) == 1:
379
489
  return ret[0]
380
490
  return tuple(ret)
381
491
 
382
- def do_compile(self, key, args, arg_features):
492
+ def do_compile(self, key: FunctionKey, args: tuple[Any, ...], arg_features: tuple[Any, ...]) -> None:
383
493
  tree, ctx = _get_tree_and_ctx(
384
494
  self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
385
495
  )
@@ -392,36 +502,42 @@ class Func:
392
502
  transform_tree(tree, ctx)
393
503
  impl.get_runtime().compiling_callable = old_callable
394
504
 
395
- self.taichi_functions[key.instance_id] = fn
505
+ self.gstaichi_functions[key.instance_id] = fn
396
506
  self.compiled[key.instance_id] = func_body
397
- self.taichi_functions[key.instance_id].set_function_body(func_body)
507
+ self.gstaichi_functions[key.instance_id].set_function_body(func_body)
398
508
 
399
509
  def extract_arguments(self) -> None:
400
510
  sig = inspect.signature(self.func)
401
511
  if sig.return_annotation not in (inspect.Signature.empty, None):
402
512
  self.return_type = sig.return_annotation
403
513
  if (
404
- isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))
405
- and self.return_type.__origin__ is tuple
514
+ isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias)) # type: ignore
515
+ and self.return_type.__origin__ is tuple # type: ignore
406
516
  ):
407
- self.return_type = self.return_type.__args__
517
+ self.return_type = self.return_type.__args__ # type: ignore
518
+ if self.return_type is None:
519
+ return
408
520
  if not isinstance(self.return_type, (list, tuple)):
409
521
  self.return_type = (self.return_type,)
410
522
  for i, return_type in enumerate(self.return_type):
411
523
  if return_type is Ellipsis:
412
- raise TaichiSyntaxError("Ellipsis is not supported in return type annotations")
524
+ raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
413
525
  params = sig.parameters
414
526
  arg_names = params.keys()
415
527
  for i, arg_name in enumerate(arg_names):
416
528
  param = params[arg_name]
417
529
  if param.kind == inspect.Parameter.VAR_KEYWORD:
418
- raise TaichiSyntaxError("Taichi functions do not support variable keyword parameters (i.e., **kwargs)")
530
+ raise GsTaichiSyntaxError(
531
+ "GsTaichi functions do not support variable keyword parameters (i.e., **kwargs)"
532
+ )
419
533
  if param.kind == inspect.Parameter.VAR_POSITIONAL:
420
- raise TaichiSyntaxError("Taichi functions do not support variable positional parameters (i.e., *args)")
534
+ raise GsTaichiSyntaxError(
535
+ "GsTaichi functions do not support variable positional parameters (i.e., *args)"
536
+ )
421
537
  if param.kind == inspect.Parameter.KEYWORD_ONLY:
422
- raise TaichiSyntaxError("Taichi functions do not support keyword parameters")
538
+ raise GsTaichiSyntaxError("GsTaichi functions do not support keyword parameters")
423
539
  if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
424
- raise TaichiSyntaxError('Taichi functions only support "positional or keyword" parameters')
540
+ raise GsTaichiSyntaxError('GsTaichi functions only support "positional or keyword" parameters')
425
541
  annotation = param.annotation
426
542
  if annotation is inspect.Parameter.empty:
427
543
  if i == 0 and self.classfunc:
@@ -429,8 +545,8 @@ class Func:
429
545
  # TODO: pyfunc also need type annotation check when real function is enabled,
430
546
  # but that has to happen at runtime when we know which scope it's called from.
431
547
  elif not self.pyfunc and self.is_real_function:
432
- raise TaichiSyntaxError(
433
- f"Taichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
548
+ raise GsTaichiSyntaxError(
549
+ f"GsTaichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
434
550
  )
435
551
  else:
436
552
  if isinstance(annotation, ndarray_type.NdarrayType):
@@ -441,198 +557,24 @@ class Func:
441
557
  pass
442
558
  elif id(annotation) in primitive_types.type_ids:
443
559
  pass
444
- elif type(annotation) == taichi.types.annotations.Template:
560
+ elif type(annotation) == gstaichi.types.annotations.Template:
445
561
  pass
446
- elif isinstance(annotation, template) or annotation == taichi.types.annotations.Template:
562
+ elif isinstance(annotation, template) or annotation == gstaichi.types.annotations.Template:
447
563
  pass
448
564
  elif isinstance(annotation, primitive_types.RefType):
449
565
  pass
450
566
  elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
451
567
  pass
452
568
  else:
453
- raise TaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi function: {annotation}")
569
+ raise GsTaichiSyntaxError(
570
+ f"Invalid type annotation (argument {i}) of GsTaichi function: {annotation}"
571
+ )
454
572
  self.arguments.append(KernelArgument(annotation, param.name, param.default))
455
573
  self.orig_arguments.append(KernelArgument(annotation, param.name, param.default))
456
574
 
457
575
 
458
- AnnotationType = Union[
459
- template,
460
- ArgPackType,
461
- "texture_type.TextureType",
462
- "texture_type.RWTextureType",
463
- ndarray_type.NdarrayType,
464
- sparse_matrix_builder,
465
- Any,
466
- ]
467
-
468
-
469
- class TaichiCallableTemplateMapper:
470
- """
471
- This should probably be renamed to sometihng like FeatureMapper, or
472
- FeatureExtractor, since:
473
- - it's not specific to templates
474
- - it extracts what are later called 'features', for example for ndarray this includes:
475
- - element type
476
- - number dimensions
477
- - needs grad (or not)
478
- - these are returned as a heterogeneous tuple, whose contents depends on the type
479
- """
480
-
481
- def __init__(self, arguments: list[KernelArgument], template_slot_locations: list[int]) -> None:
482
- self.arguments = arguments
483
- self.num_args = len(arguments)
484
- self.template_slot_locations = template_slot_locations
485
- self.mapping = {}
486
-
487
- @staticmethod
488
- def extract_arg(arg, annotation: AnnotationType, arg_name: str):
489
- if annotation == template or isinstance(annotation, template):
490
- if isinstance(arg, taichi.lang.snode.SNode):
491
- return arg.ptr
492
- if isinstance(arg, taichi.lang.expr.Expr):
493
- return arg.ptr.get_underlying_ptr_address()
494
- if isinstance(arg, _ti_core.Expr):
495
- return arg.get_underlying_ptr_address()
496
- if isinstance(arg, tuple):
497
- return tuple(TaichiCallableTemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
498
- if isinstance(arg, taichi.lang._ndarray.Ndarray):
499
- raise TaichiRuntimeTypeError(
500
- "Ndarray shouldn't be passed in via `ti.template()`, please annotate your kernel using `ti.types.ndarray(...)` instead"
501
- )
502
-
503
- if isinstance(arg, (list, tuple, dict, set)) or hasattr(arg, "_data_oriented"):
504
- # [Composite arguments] Return weak reference to the object
505
- # Taichi kernel will cache the extracted arguments, thus we can't simply return the original argument.
506
- # Instead, a weak reference to the original value is returned to avoid memory leak.
507
-
508
- # TODO(zhanlue): replacing "tuple(args)" with "hash of argument values"
509
- # This can resolve the following issues:
510
- # 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
511
- # 2. Different argument instances with same type and same value, will get templatized into seperate kernels.
512
- return weakref.ref(arg)
513
-
514
- # [Primitive arguments] Return the value
515
- return arg
516
- if isinstance(annotation, ArgPackType):
517
- if not isinstance(arg, ArgPack):
518
- raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a argument pack, got {type(arg)}")
519
- return tuple(
520
- TaichiCallableTemplateMapper.extract_arg(arg[name], dtype, arg_name)
521
- for index, (name, dtype) in enumerate(annotation.members.items())
522
- )
523
- if dataclasses.is_dataclass(annotation):
524
- _res_l = []
525
- for field in dataclasses.fields(annotation):
526
- field_value = getattr(arg, field.name)
527
- arg_name = f"__ti_{arg_name}_{field.name}"
528
- field_extracted = TaichiCallableTemplateMapper.extract_arg(field_value, field.type, arg_name)
529
- _res_l.append(field_extracted)
530
- return tuple(_res_l)
531
- if isinstance(annotation, texture_type.TextureType):
532
- if not isinstance(arg, taichi.lang._texture.Texture):
533
- raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
534
- if arg.num_dims != annotation.num_dimensions:
535
- raise TaichiRuntimeTypeError(
536
- f"TextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
537
- )
538
- return (arg.num_dims,)
539
- if isinstance(annotation, texture_type.RWTextureType):
540
- if not isinstance(arg, taichi.lang._texture.Texture):
541
- raise TaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
542
- if arg.num_dims != annotation.num_dimensions:
543
- raise TaichiRuntimeTypeError(
544
- f"RWTextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
545
- )
546
- if arg.fmt != annotation.fmt:
547
- raise TaichiRuntimeTypeError(
548
- f"RWTextureType format mismatch for argument {arg_name}: expected {annotation.fmt}, got {arg.fmt}"
549
- )
550
- # (penguinliong) '0' is the assumed LOD level. We currently don't
551
- # support mip-mapping.
552
- return arg.num_dims, arg.fmt, 0
553
- if isinstance(annotation, ndarray_type.NdarrayType):
554
- if isinstance(arg, taichi.lang._ndarray.Ndarray):
555
- annotation.check_matched(arg.get_type(), arg_name)
556
- needs_grad = (arg.grad is not None) if annotation.needs_grad is None else annotation.needs_grad
557
- assert arg.shape is not None
558
- return arg.element_type, len(arg.shape), needs_grad, annotation.boundary
559
- if isinstance(arg, AnyArray):
560
- ty = arg.get_type()
561
- annotation.check_matched(arg.get_type(), arg_name)
562
- return ty.element_type, len(arg.shape), ty.needs_grad, annotation.boundary
563
- # external arrays
564
- shape = getattr(arg, "shape", None)
565
- if shape is None:
566
- raise TaichiRuntimeTypeError(f"Invalid type for argument {arg_name}, got {arg}")
567
- shape = tuple(shape)
568
- element_shape: tuple[int, ...] = ()
569
- dtype = to_taichi_type(arg.dtype)
570
- if isinstance(annotation.dtype, MatrixType):
571
- if annotation.ndim is not None:
572
- if len(shape) != annotation.dtype.ndim + annotation.ndim:
573
- raise ValueError(
574
- f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim} element_dim={annotation.dtype.ndim}, "
575
- f"array with {len(shape)} dimensions is provided"
576
- )
577
- else:
578
- if len(shape) < annotation.dtype.ndim:
579
- raise ValueError(
580
- f"Invalid value for argument {arg_name} - required element_dim={annotation.dtype.ndim}, "
581
- f"array with {len(shape)} dimensions is provided"
582
- )
583
- element_shape = shape[-annotation.dtype.ndim :]
584
- anno_element_shape = annotation.dtype.get_shape()
585
- if None not in anno_element_shape and element_shape != anno_element_shape:
586
- raise ValueError(
587
- f"Invalid value for argument {arg_name} - required element_shape={anno_element_shape}, "
588
- f"array with element shape of {element_shape} is provided"
589
- )
590
- elif annotation.dtype is not None:
591
- # User specified scalar dtype
592
- if annotation.dtype != dtype:
593
- raise ValueError(
594
- f"Invalid value for argument {arg_name} - required array has dtype={annotation.dtype.to_string()}, "
595
- f"array with dtype={dtype.to_string()} is provided"
596
- )
597
-
598
- if annotation.ndim is not None and len(shape) != annotation.ndim:
599
- raise ValueError(
600
- f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim}, "
601
- f"array with {len(shape)} dimensions is provided"
602
- )
603
- needs_grad = (
604
- getattr(arg, "requires_grad", False) if annotation.needs_grad is None else annotation.needs_grad
605
- )
606
- element_type = (
607
- _ti_core.get_type_factory_instance().get_tensor_type(element_shape, dtype)
608
- if len(element_shape) != 0
609
- else arg.dtype
610
- )
611
- return element_type, len(shape) - len(element_shape), needs_grad, annotation.boundary
612
- if isinstance(annotation, sparse_matrix_builder):
613
- return arg.dtype
614
- # Use '#' as a placeholder because other kinds of arguments are not involved in template instantiation
615
- return "#"
616
-
617
- def extract(self, args):
618
- extracted = []
619
- for arg, kernel_arg in zip(args, self.arguments):
620
- extracted.append(self.extract_arg(arg, kernel_arg.annotation, kernel_arg.name))
621
- return tuple(extracted)
622
-
623
- def lookup(self, args):
624
- if len(args) != self.num_args:
625
- raise TypeError(f"{self.num_args} argument(s) needed but {len(args)} provided.")
626
-
627
- key = self.extract(args)
628
- if key not in self.mapping:
629
- count = len(self.mapping)
630
- self.mapping[key] = count
631
- return self.mapping[key], key
632
-
633
-
634
- def _get_global_vars(_func):
635
- # Discussions: https://github.com/taichi-dev/taichi/issues/282
576
+ def _get_global_vars(_func: Callable) -> dict[str, Any]:
577
+ # Discussions: https://github.com/taichi-dev/gstaichi/issues/282
636
578
  global_vars = _func.__globals__.copy()
637
579
 
638
580
  freevar_names = _func.__code__.co_freevars
@@ -648,7 +590,7 @@ def _get_global_vars(_func):
648
590
  class Kernel:
649
591
  counter = 0
650
592
 
651
- def __init__(self, _func: Callable, autodiff_mode, _classkernel=False):
593
+ def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _classkernel=False) -> None:
652
594
  self.func = _func
653
595
  self.kernel_counter = Kernel.counter
654
596
  Kernel.counter += 1
@@ -668,27 +610,27 @@ class Kernel:
668
610
  for i, arg in enumerate(self.arguments):
669
611
  if arg.annotation == template or isinstance(arg.annotation, template):
670
612
  self.template_slot_locations.append(i)
671
- self.mapper = TaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
613
+ self.mapper = GsTaichiCallableTemplateMapper(self.arguments, self.template_slot_locations)
672
614
  impl.get_runtime().kernels.append(self)
673
615
  self.reset()
674
616
  self.kernel_cpp = None
675
- self.compiled_kernels = {}
617
+ self.compiled_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
676
618
  self.has_print = False
677
619
 
678
620
  def ast_builder(self) -> ASTBuilder:
679
621
  assert self.kernel_cpp is not None
680
622
  return self.kernel_cpp.ast_builder()
681
623
 
682
- def reset(self):
624
+ def reset(self) -> None:
683
625
  self.runtime = impl.get_runtime()
684
626
  self.compiled_kernels = {}
685
627
 
686
- def extract_arguments(self):
628
+ def extract_arguments(self) -> None:
687
629
  sig = inspect.signature(self.func)
688
630
  if sig.return_annotation not in (inspect._empty, None):
689
631
  self.return_type = sig.return_annotation
690
632
  if (
691
- isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))
633
+ isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias)) # type: ignore
692
634
  and self.return_type.__origin__ is tuple
693
635
  ):
694
636
  self.return_type = self.return_type.__args__
@@ -696,27 +638,31 @@ class Kernel:
696
638
  self.return_type = (self.return_type,)
697
639
  for return_type in self.return_type:
698
640
  if return_type is Ellipsis:
699
- raise TaichiSyntaxError("Ellipsis is not supported in return type annotations")
641
+ raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
700
642
  params = sig.parameters
701
643
  arg_names = params.keys()
702
644
  for i, arg_name in enumerate(arg_names):
703
645
  param = params[arg_name]
704
646
  if param.kind == inspect.Parameter.VAR_KEYWORD:
705
- raise TaichiSyntaxError("Taichi kernels do not support variable keyword parameters (i.e., **kwargs)")
647
+ raise GsTaichiSyntaxError(
648
+ "GsTaichi kernels do not support variable keyword parameters (i.e., **kwargs)"
649
+ )
706
650
  if param.kind == inspect.Parameter.VAR_POSITIONAL:
707
- raise TaichiSyntaxError("Taichi kernels do not support variable positional parameters (i.e., *args)")
651
+ raise GsTaichiSyntaxError(
652
+ "GsTaichi kernels do not support variable positional parameters (i.e., *args)"
653
+ )
708
654
  if param.default is not inspect.Parameter.empty:
709
- raise TaichiSyntaxError("Taichi kernels do not support default values for arguments")
655
+ raise GsTaichiSyntaxError("GsTaichi kernels do not support default values for arguments")
710
656
  if param.kind == inspect.Parameter.KEYWORD_ONLY:
711
- raise TaichiSyntaxError("Taichi kernels do not support keyword parameters")
657
+ raise GsTaichiSyntaxError("GsTaichi kernels do not support keyword parameters")
712
658
  if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
713
- raise TaichiSyntaxError('Taichi kernels only support "positional or keyword" parameters')
659
+ raise GsTaichiSyntaxError('GsTaichi kernels only support "positional or keyword" parameters')
714
660
  annotation = param.annotation
715
661
  if param.annotation is inspect.Parameter.empty:
716
662
  if i == 0 and self.classkernel: # The |self| parameter
717
663
  annotation = template()
718
664
  else:
719
- raise TaichiSyntaxError("Taichi kernels parameters must be type annotated")
665
+ raise GsTaichiSyntaxError("GsTaichi kernels parameters must be type annotated")
720
666
  else:
721
667
  if isinstance(
722
668
  annotation,
@@ -743,10 +689,12 @@ class Kernel:
743
689
  elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
744
690
  pass
745
691
  else:
746
- raise TaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
692
+ raise GsTaichiSyntaxError(
693
+ f"Invalid type annotation (argument {i}) of GsTaichi kernel: {annotation}"
694
+ )
747
695
  self.arguments.append(KernelArgument(annotation, param.name, param.default))
748
696
 
749
- def materialize(self, key, args: list[Any], arg_features):
697
+ def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features):
750
698
  if key is None:
751
699
  key = (self.func, 0, self.autodiff_mode)
752
700
  self.runtime.materialize()
@@ -767,15 +715,15 @@ class Kernel:
767
715
  if self.autodiff_mode != AutodiffMode.NONE:
768
716
  KernelSimplicityASTChecker(self.func).visit(tree)
769
717
 
770
- # Do not change the name of 'taichi_ast_generator'
718
+ # Do not change the name of 'gstaichi_ast_generator'
771
719
  # The warning system needs this identifier to remove unnecessary messages
772
- def taichi_ast_generator(kernel_cxx: Kernel): # not sure if this type is correct, seems doubtful
720
+ def gstaichi_ast_generator(kernel_cxx: Kernel): # not sure if this type is correct, seems doubtful
773
721
  nonlocal tree
774
722
  if self.runtime.inside_kernel:
775
- raise TaichiSyntaxError(
723
+ raise GsTaichiSyntaxError(
776
724
  "Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
777
725
  "Please check if you have direct/indirect invocation of kernels within kernels. "
778
- "Note that some methods provided by the Taichi standard library may invoke kernels, "
726
+ "Note that some methods provided by the GsTaichi standard library may invoke kernels, "
779
727
  "and please move their invocations to Python-scope."
780
728
  )
781
729
  self.kernel_cpp = kernel_cxx
@@ -786,7 +734,7 @@ class Kernel:
786
734
  try:
787
735
  ctx.ast_builder = kernel_cxx.ast_builder()
788
736
 
789
- def ast_to_dict(node):
737
+ def ast_to_dict(node: ast.AST | list | primitive_types._python_primitive_types):
790
738
  if isinstance(node, ast.AST):
791
739
  fields = {k: ast_to_dict(v) for k, v in ast.iter_fields(node)}
792
740
  return {
@@ -824,17 +772,17 @@ class Kernel:
824
772
  transform_tree(tree, ctx)
825
773
  if not ctx.is_real_function:
826
774
  if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
827
- raise TaichiSyntaxError("Kernel has a return type but does not have a return statement")
775
+ raise GsTaichiSyntaxError("Kernel has a return type but does not have a return statement")
828
776
  finally:
829
777
  self.runtime.inside_kernel = False
830
778
  self.runtime._current_kernel = None
831
779
  self.runtime.compiling_callable = None
832
780
 
833
- taichi_kernel = impl.get_runtime().prog.create_kernel(taichi_ast_generator, kernel_name, self.autodiff_mode)
781
+ gstaichi_kernel = impl.get_runtime().prog.create_kernel(gstaichi_ast_generator, kernel_name, self.autodiff_mode)
834
782
  assert key not in self.compiled_kernels
835
- self.compiled_kernels[key] = taichi_kernel
783
+ self.compiled_kernels[key] = gstaichi_kernel
836
784
 
837
- def launch_kernel(self, t_kernel, *args):
785
+ def launch_kernel(self, t_kernel: KernelCxx, *args) -> Any:
838
786
  assert len(args) == len(self.arguments), f"{len(self.arguments)} arguments needed but {len(args)} provided"
839
787
 
840
788
  tmps = []
@@ -842,25 +790,28 @@ class Kernel:
842
790
 
843
791
  actual_argument_slot = 0
844
792
  launch_ctx = t_kernel.make_launch_context()
845
- max_arg_num = 64
793
+ max_arg_num = 512
846
794
  exceed_max_arg_num = False
847
795
 
848
- def set_arg_ndarray(indices, v):
796
+ def set_arg_ndarray(indices: tuple[int, ...], v: gstaichi.lang._ndarray.Ndarray) -> None:
849
797
  v_primal = v.arr
850
798
  v_grad = v.grad.arr if v.grad else None
851
799
  if v_grad is None:
852
- launch_ctx.set_arg_ndarray(indices, v_primal)
800
+ launch_ctx.set_arg_ndarray(indices, v_primal) # type: ignore , solvable probably, just not today
853
801
  else:
854
- launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad)
802
+ launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad) # type: ignore
855
803
 
856
- def set_arg_texture(indices, v):
804
+ def set_arg_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
857
805
  launch_ctx.set_arg_texture(indices, v.tex)
858
806
 
859
- def set_arg_rw_texture(indices, v):
807
+ def set_arg_rw_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
860
808
  launch_ctx.set_arg_rw_texture(indices, v.tex)
861
809
 
862
- def set_arg_ext_array(indices, v, needed):
863
- # Element shapes are already specialized in Taichi codegen.
810
+ def set_arg_ext_array(indices: tuple[int, ...], v: Any, needed: ndarray_type.NdarrayType) -> None:
811
+ # v is things like torch Tensor and numpy array
812
+ # Not adding type for this, since adds additional dependencies
813
+ #
814
+ # Element shapes are already specialized in GsTaichi codegen.
864
815
  # The shape information for element dims are no longer needed.
865
816
  # Therefore we strip the element shapes from the shape vector,
866
817
  # so that it only holds "real" array shapes.
@@ -893,7 +844,7 @@ class Kernel:
893
844
  else:
894
845
  raise ValueError(
895
846
  "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
896
- "before passing it into taichi kernel."
847
+ "before passing it into gstaichi kernel."
897
848
  )
898
849
  elif has_pytorch():
899
850
  import torch # pylint: disable=C0415
@@ -902,9 +853,9 @@ class Kernel:
902
853
  if not v.is_contiguous():
903
854
  raise ValueError(
904
855
  "Non contiguous tensors are not supported, please call tensor.contiguous() before "
905
- "passing it into taichi kernel."
856
+ "passing it into gstaichi kernel."
906
857
  )
907
- taichi_arch = self.runtime.prog.config().arch
858
+ gstaichi_arch = self.runtime.prog.config().arch
908
859
 
909
860
  def get_call_back(u, v):
910
861
  def call_back():
@@ -923,14 +874,14 @@ class Kernel:
923
874
  )
924
875
  if not v.grad.is_contiguous():
925
876
  raise ValueError(
926
- "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into taichi kernel."
877
+ "Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into gstaichi kernel."
927
878
  )
928
879
 
929
880
  tmp = v
930
881
  if (str(v.device) != "cpu") and not (
931
- str(v.device).startswith("cuda") and taichi_arch == _ti_core.Arch.cuda
882
+ str(v.device).startswith("cuda") and gstaichi_arch == _ti_core.Arch.cuda
932
883
  ):
933
- # Getting a torch CUDA tensor on Taichi non-cuda arch:
884
+ # Getting a torch CUDA tensor on GsTaichi non-cuda arch:
934
885
  # We just replace it with a CPU tensor and by the end of kernel execution we'll use the
935
886
  # callback to copy the values back to the original CUDA tensor.
936
887
  host_v = v.to(device="cpu", copy=True)
@@ -945,8 +896,12 @@ class Kernel:
945
896
  int(v.grad.data_ptr()) if v.grad is not None else 0,
946
897
  )
947
898
  else:
948
- raise TaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {type(v)}")
899
+ raise GsTaichiRuntimeTypeError(
900
+ f"Argument {needed} cannot be converted into required type {type(v)}"
901
+ )
949
902
  elif has_paddle():
903
+ # Do we want to continue to support paddle? :thinking_face:
904
+ # #maybeprunable
950
905
  import paddle # pylint: disable=C0415 # type: ignore
951
906
 
952
907
  if isinstance(v, paddle.Tensor):
@@ -958,41 +913,41 @@ class Kernel:
958
913
  return call_back
959
914
 
960
915
  tmp = v.value().get_tensor()
961
- taichi_arch = self.runtime.prog.config().arch
916
+ gstaichi_arch = self.runtime.prog.config().arch
962
917
  if v.place.is_gpu_place():
963
- if taichi_arch != _ti_core.Arch.cuda:
964
- # Paddle cuda tensor on Taichi non-cuda arch
918
+ if gstaichi_arch != _ti_core.Arch.cuda:
919
+ # Paddle cuda tensor on GsTaichi non-cuda arch
965
920
  host_v = v.cpu()
966
921
  tmp = host_v.value().get_tensor()
967
922
  callbacks.append(get_call_back(v, host_v))
968
923
  elif v.place.is_cpu_place():
969
- if taichi_arch == _ti_core.Arch.cuda:
970
- # Paddle cpu tensor on Taichi cuda arch
924
+ if gstaichi_arch == _ti_core.Arch.cuda:
925
+ # Paddle cpu tensor on GsTaichi cuda arch
971
926
  gpu_v = v.cuda()
972
927
  tmp = gpu_v.value().get_tensor()
973
928
  callbacks.append(get_call_back(v, gpu_v))
974
929
  else:
975
930
  # Paddle do support many other backends like XPU, NPU, MLU, IPU
976
- raise TaichiRuntimeTypeError(f"Taichi do not support backend {v.place} that Paddle support")
931
+ raise GsTaichiRuntimeTypeError(f"GsTaichi do not support backend {v.place} that Paddle support")
977
932
  launch_ctx.set_arg_external_array_with_shape(
978
933
  indices, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
979
934
  )
980
935
  else:
981
- raise TaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
936
+ raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
982
937
  else:
983
- raise TaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
938
+ raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
984
939
 
985
- def set_arg_matrix(indices, v, needed):
986
- def cast_float(x):
940
+ def set_arg_matrix(indices: tuple[int, ...], v, needed) -> None:
941
+ def cast_float(x: float | np.floating | np.integer | int) -> float:
987
942
  if not isinstance(x, (int, float, np.integer, np.floating)):
988
- raise TaichiRuntimeTypeError(
943
+ raise GsTaichiRuntimeTypeError(
989
944
  f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
990
945
  )
991
946
  return float(x)
992
947
 
993
- def cast_int(x):
948
+ def cast_int(x: int | np.integer) -> int:
994
949
  if not isinstance(x, (int, np.integer)):
995
- raise TaichiRuntimeTypeError(
950
+ raise GsTaichiRuntimeTypeError(
996
951
  f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
997
952
  )
998
953
  return int(x)
@@ -1012,13 +967,13 @@ class Kernel:
1012
967
  v = needed(*v)
1013
968
  needed.set_kernel_struct_args(v, launch_ctx, indices)
1014
969
 
1015
- def set_arg_sparse_matrix_builder(indices, v):
970
+ def set_arg_sparse_matrix_builder(indices: tuple[int, ...], v) -> None:
1016
971
  # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
1017
972
  launch_ctx.set_arg_uint(indices, v._get_ndarray_addr())
1018
973
 
1019
974
  set_later_list = []
1020
975
 
1021
- def recursive_set_args(needed_arg_type, provided_arg_type, v, indices):
976
+ def recursive_set_args(needed_arg_type: Type, provided_arg_type: Type, v: Any, indices: tuple[int, ...]) -> int:
1022
977
  """
1023
978
  Returns the number of kernel args set
1024
979
  e.g. templates don't set kernel args, so returns 0
@@ -1033,7 +988,7 @@ class Kernel:
1033
988
  actual_argument_slot += 1
1034
989
  if isinstance(needed_arg_type, ArgPackType):
1035
990
  if not isinstance(v, ArgPack):
1036
- raise TaichiRuntimeTypeError.get(indices, str(needed_arg_type), str(provided_arg_type))
991
+ raise GsTaichiRuntimeTypeError.get(indices, str(needed_arg_type), str(provided_arg_type))
1037
992
  idx_new = 0
1038
993
  for j, (name, anno) in enumerate(needed_arg_type.members.items()):
1039
994
  idx_new += recursive_set_args(anno, type(v[name]), v[name], indices + (idx_new,))
@@ -1042,14 +997,14 @@ class Kernel:
1042
997
  # Note: do not use sth like "needed == f32". That would be slow.
1043
998
  if id(needed_arg_type) in primitive_types.real_type_ids:
1044
999
  if not isinstance(v, (float, int, np.floating, np.integer)):
1045
- raise TaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
1000
+ raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
1046
1001
  if in_argpack:
1047
1002
  return 1
1048
1003
  launch_ctx.set_arg_float(indices, float(v))
1049
1004
  return 1
1050
1005
  if id(needed_arg_type) in primitive_types.integer_type_ids:
1051
1006
  if not isinstance(v, (int, np.integer)):
1052
- raise TaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
1007
+ raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
1053
1008
  if in_argpack:
1054
1009
  return 1
1055
1010
  if is_signed(cook_dtype(needed_arg_type)):
@@ -1071,19 +1026,21 @@ class Kernel:
1071
1026
  field_value = getattr(v, field.name)
1072
1027
  idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
1073
1028
  return idx
1074
- if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray):
1029
+ if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, gstaichi.lang._ndarray.Ndarray):
1075
1030
  if in_argpack:
1076
1031
  set_later_list.append((set_arg_ndarray, (v,)))
1077
1032
  return 0
1078
1033
  set_arg_ndarray(indices, v)
1079
1034
  return 1
1080
- if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture):
1035
+ if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, gstaichi.lang._texture.Texture):
1081
1036
  if in_argpack:
1082
1037
  set_later_list.append((set_arg_texture, (v,)))
1083
1038
  return 0
1084
1039
  set_arg_texture(indices, v)
1085
1040
  return 1
1086
- if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture):
1041
+ if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
1042
+ v, gstaichi.lang._texture.Texture
1043
+ ):
1087
1044
  if in_argpack:
1088
1045
  set_later_list.append((set_arg_rw_texture, (v,)))
1089
1046
  return 0
@@ -1103,8 +1060,12 @@ class Kernel:
1103
1060
  if isinstance(needed_arg_type, StructType):
1104
1061
  if in_argpack:
1105
1062
  return 1
1106
- if not isinstance(v, needed_arg_type):
1107
- raise TaichiRuntimeTypeError(
1063
+ # Unclear how to make the following pass typing checks
1064
+ # StructType implements __instancecheck__, which should be a classmethod, but
1065
+ # is currently an instance method
1066
+ # TODO: look into this more deeply at some point
1067
+ if not isinstance(v, needed_arg_type): # type: ignore
1068
+ raise GsTaichiRuntimeTypeError(
1108
1069
  f"Argument {provided_arg_type} cannot be converted into required type {needed_arg_type}"
1109
1070
  )
1110
1071
  needed_arg_type.set_kernel_struct_args(v, launch_ctx, indices)
@@ -1127,7 +1088,7 @@ class Kernel:
1127
1088
  set_arg_func((len(args) - template_num + i,), *params)
1128
1089
 
1129
1090
  if exceed_max_arg_num:
1130
- raise TaichiRuntimeError(
1091
+ raise GsTaichiRuntimeError(
1131
1092
  f"The number of elements in kernel arguments is too big! Do not exceed {max_arg_num} on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
1132
1093
  )
1133
1094
 
@@ -1162,7 +1123,7 @@ class Kernel:
1162
1123
 
1163
1124
  return ret
1164
1125
 
1165
- def construct_kernel_ret(self, launch_ctx, ret_type, index=()):
1126
+ def construct_kernel_ret(self, launch_ctx: KernelLaunchContext, ret_type: Any, index: tuple[int, ...] = ()):
1166
1127
  if isinstance(ret_type, CompoundType):
1167
1128
  return ret_type.from_kernel_struct_ret(launch_ctx, index)
1168
1129
  if ret_type in primitive_types.integer_types:
@@ -1171,9 +1132,9 @@ class Kernel:
1171
1132
  return launch_ctx.get_struct_ret_uint(index)
1172
1133
  if ret_type in primitive_types.real_types:
1173
1134
  return launch_ctx.get_struct_ret_float(index)
1174
- raise TaichiRuntimeTypeError(f"Invalid return type on index={index}")
1135
+ raise GsTaichiRuntimeTypeError(f"Invalid return type on index={index}")
1175
1136
 
1176
- def ensure_compiled(self, *args):
1137
+ def ensure_compiled(self, *args: tuple[Any, ...]) -> tuple[Callable, int, AutodiffMode]:
1177
1138
  instance_id, arg_features = self.mapper.lookup(args)
1178
1139
  key = (self.func, instance_id, self.autodiff_mode)
1179
1140
  self.materialize(key=key, args=args, arg_features=arg_features)
@@ -1182,7 +1143,7 @@ class Kernel:
1182
1143
  # For small kernels (< 3us), the performance can be pretty sensitive to overhead in __call__
1183
1144
  # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU)
1184
1145
  @_shell_pop_print
1185
- def __call__(self, *args, **kwargs):
1146
+ def __call__(self, *args, **kwargs) -> Any:
1186
1147
  args = _process_args(self, is_func=False, args=args, kwargs=kwargs)
1187
1148
 
1188
1149
  # Transform the primal kernel to forward mode grad kernel
@@ -1213,7 +1174,7 @@ class Kernel:
1213
1174
  return self.launch_kernel(kernel_cpp, *args)
1214
1175
 
1215
1176
 
1216
- # For a Taichi class definition like below:
1177
+ # For a GsTaichi class definition like below:
1217
1178
  #
1218
1179
  # @ti.data_oriented
1219
1180
  # class X:
@@ -1232,7 +1193,7 @@ _KERNEL_CLASS_STACKFRAME_STMT_RES = [
1232
1193
  ]
1233
1194
 
1234
1195
 
1235
- def _inside_class(level_of_class_stackframe):
1196
+ def _inside_class(level_of_class_stackframe: int) -> bool:
1236
1197
  try:
1237
1198
  maybe_class_frame = sys._getframe(level_of_class_stackframe)
1238
1199
  statement_list = inspect.getframeinfo(maybe_class_frame)[3]
@@ -1247,7 +1208,7 @@ def _inside_class(level_of_class_stackframe):
1247
1208
  return False
1248
1209
 
1249
1210
 
1250
- def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False):
1211
+ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False) -> GsTaichiCallable:
1251
1212
  # Can decorators determine if a function is being defined inside a class?
1252
1213
  # https://stackoverflow.com/a/8793684/12003165
1253
1214
  is_classkernel = _inside_class(level_of_class_stackframe + 1)
@@ -1259,6 +1220,7 @@ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool
1259
1220
  # Having |primal| contains |grad| makes the tape work.
1260
1221
  primal.grad = adjoint
1261
1222
 
1223
+ wrapped: GsTaichiCallable
1262
1224
  if is_classkernel:
1263
1225
  # For class kernels, their primal/adjoint callables are constructed
1264
1226
  # when the kernel is accessed via the instance inside
@@ -1268,24 +1230,26 @@ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool
1268
1230
  #
1269
1231
  # See also: _BoundedDifferentiableMethod, data_oriented.
1270
1232
  @functools.wraps(_func)
1271
- def wrapped(*args, **kwargs):
1233
+ def wrapped_classkernel(*args, **kwargs):
1272
1234
  # If we reach here (we should never), it means the class is not decorated
1273
1235
  # with @ti.data_oriented, otherwise getattr would have intercepted the call.
1274
1236
  clsobj = type(args[0])
1275
1237
  assert not hasattr(clsobj, "_data_oriented")
1276
- raise TaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
1238
+ raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
1277
1239
 
1240
+ wrapped = GsTaichiCallable(_func, wrapped_classkernel)
1278
1241
  else:
1279
1242
 
1280
1243
  @functools.wraps(_func)
1281
- def wrapped(*args, **kwargs):
1244
+ def wrapped_func(*args, **kwargs):
1282
1245
  try:
1283
1246
  return primal(*args, **kwargs)
1284
- except (TaichiCompilationError, TaichiRuntimeError) as e:
1247
+ except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
1285
1248
  if impl.get_runtime().print_full_traceback:
1286
1249
  raise e
1287
1250
  raise type(e)("\n" + str(e)) from None
1288
1251
 
1252
+ wrapped = GsTaichiCallable(_func, wrapped_func)
1289
1253
  wrapped.grad = adjoint
1290
1254
 
1291
1255
  wrapped._is_wrapped_kernel = True
@@ -1296,10 +1260,10 @@ def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool
1296
1260
 
1297
1261
 
1298
1262
  def kernel(fn: Callable):
1299
- """Marks a function as a Taichi kernel.
1263
+ """Marks a function as a GsTaichi kernel.
1300
1264
 
1301
- A Taichi kernel is a function written in Python, and gets JIT compiled by
1302
- Taichi into native CPU/GPU instructions (e.g. a series of CUDA kernels).
1265
+ A GsTaichi kernel is a function written in Python, and gets JIT compiled by
1266
+ GsTaichi into native CPU/GPU instructions (e.g. a series of CUDA kernels).
1303
1267
  The top-level ``for`` loops are automatically parallelized, and distributed
1304
1268
  to either a CPU thread pool or massively parallel GPUs.
1305
1269
 
@@ -1327,10 +1291,10 @@ def kernel(fn: Callable):
1327
1291
 
1328
1292
 
1329
1293
  class _BoundedDifferentiableMethod:
1330
- def __init__(self, kernel_owner, wrapped_kernel_func):
1294
+ def __init__(self, kernel_owner: Any, wrapped_kernel_func: GsTaichiCallable | BoundGsTaichiCallable):
1331
1295
  clsobj = type(kernel_owner)
1332
1296
  if not getattr(clsobj, "_data_oriented", False):
1333
- raise TaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
1297
+ raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
1334
1298
  self._kernel_owner = kernel_owner
1335
1299
  self._primal = wrapped_kernel_func._primal
1336
1300
  self._adjoint = wrapped_kernel_func._adjoint
@@ -1339,23 +1303,26 @@ class _BoundedDifferentiableMethod:
1339
1303
 
1340
1304
  def __call__(self, *args, **kwargs):
1341
1305
  try:
1306
+ assert self._primal is not None
1342
1307
  if self._is_staticmethod:
1343
1308
  return self._primal(*args, **kwargs)
1344
1309
  return self._primal(self._kernel_owner, *args, **kwargs)
1345
- except (TaichiCompilationError, TaichiRuntimeError) as e:
1310
+
1311
+ except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
1346
1312
  if impl.get_runtime().print_full_traceback:
1347
1313
  raise e
1348
1314
  raise type(e)("\n" + str(e)) from None
1349
1315
 
1350
- def grad(self, *args, **kwargs):
1316
+ def grad(self, *args, **kwargs) -> Kernel:
1317
+ assert self._adjoint is not None
1351
1318
  return self._adjoint(self._kernel_owner, *args, **kwargs)
1352
1319
 
1353
1320
 
1354
1321
  def data_oriented(cls):
1355
- """Marks a class as Taichi compatible.
1322
+ """Marks a class as GsTaichi compatible.
1356
1323
 
1357
- To allow for modularized code, Taichi provides this decorator so that
1358
- Taichi kernels can be defined inside a class.
1324
+ To allow for modularized code, GsTaichi provides this decorator so that
1325
+ GsTaichi kernels can be defined inside a class.
1359
1326
 
1360
1327
  See also https://docs.taichi-lang.org/docs/odop
1361
1328
 
@@ -1394,11 +1361,11 @@ def data_oriented(cls):
1394
1361
  wrapped = x.__func__
1395
1362
  else:
1396
1363
  wrapped = x
1364
+ assert isinstance(wrapped, (BoundGsTaichiCallable, GsTaichiCallable))
1397
1365
  wrapped._is_staticmethod = is_staticmethod
1398
- assert inspect.isfunction(wrapped)
1399
1366
  if wrapped._is_classkernel:
1400
1367
  ret = _BoundedDifferentiableMethod(self, wrapped)
1401
- ret.__name__ = wrapped.__name__
1368
+ ret.__name__ = wrapped.__name__ # type: ignore
1402
1369
  if is_property:
1403
1370
  return ret()
1404
1371
  return ret