gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/kernel_impl.py
@@ -0,0 +1,1341 @@
import ast
import dataclasses
import functools
import inspect
import json
import operator
import os
import pathlib
import re
import sys
import textwrap
import time
import types
import typing
import warnings
from typing import Any, Callable, Type

import numpy as np

import gstaichi.lang
import gstaichi.lang._ndarray
import gstaichi.lang._texture
import gstaichi.types.annotations
from gstaichi import _logging
from gstaichi._lib import core as _ti_core
from gstaichi._lib.core.gstaichi_python import (
    ASTBuilder,
    CompiledKernelData,
    FunctionKey,
    KernelCxx,
    KernelLaunchContext,
)
from gstaichi.lang import _kernel_impl_dataclass, impl, ops, runtime_ops
from gstaichi.lang._fast_caching import src_hasher
from gstaichi.lang._template_mapper import TemplateMapper
from gstaichi.lang._wrap_inspect import FunctionSourceInfo, get_source_info_and_src
from gstaichi.lang.any_array import AnyArray
from gstaichi.lang.ast import (
    ASTTransformerContext,
    KernelSimplicityASTChecker,
    transform_tree,
)
from gstaichi.lang.ast.ast_transformer_utils import ReturnStatus
from gstaichi.lang.exception import (
    GsTaichiCompilationError,
    GsTaichiRuntimeError,
    GsTaichiRuntimeTypeError,
    GsTaichiSyntaxError,
    GsTaichiTypeError,
    handle_exception_from_cpp,
)
from gstaichi.lang.expr import Expr
from gstaichi.lang.kernel_arguments import ArgMetadata
from gstaichi.lang.matrix import MatrixType
from gstaichi.lang.shell import _shell_pop_print
from gstaichi.lang.struct import StructType
from gstaichi.lang.util import cook_dtype, has_pytorch
from gstaichi.types import (
    ndarray_type,
    primitive_types,
    sparse_matrix_builder,
    template,
    texture_type,
)
from gstaichi.types.compound_types import CompoundType
from gstaichi.types.enums import AutodiffMode, Layout
from gstaichi.types.utils import is_signed

CompiledKernelKeyType = tuple[Callable, int, AutodiffMode]


class GsTaichiCallable:
    """
    GsTaichiCallable (together with BoundGsTaichiCallable below) is used to enable wrapping a bindable
    function with a class.

    Design requirements for GsTaichiCallable:
    - wrap/contain a reference to a class Func instance, and allow (the GsTaichiCallable) being passed around
      like a normal function pointer
    - expose attributes of the wrapped class Func, such as `_is_real_function`, `_primal`, etc
    - allow for (now limited) strong typing, and enable type checkers, such as pyright/mypy
    - currently GsTaichiCallable is a shared type used for all functions marked with @ti.func, @ti.kernel,
      and python functions (?)
    - note: the current type-checking implementation does not distinguish between different type flavors of
      GsTaichiCallable, with different values of `_is_real_function`, `_primal`, etc
    - handle not only class-less functions, but also class-instance methods (where determining the `self`
      reference is a challenge)

    Let's take the following example:

        def test_ptr_class_func():
            @ti.data_oriented
            class MyClass:
                def __init__(self):
                    self.a = ti.field(dtype=ti.f32, shape=(3))

                def add2numbers_py(self, x, y):
                    return x + y

                @ti.func
                def add2numbers_func(self, x, y):
                    return x + y

                @ti.kernel
                def func(self):
                    a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)
                    a[0] = add_py(2, 3)
                    a[1] = add_func(3, 7)

    (taken from test_ptr_assign.py).

    When the @ti.func decorator is parsed, the function `add2numbers_func` exists, but there is not yet any `self`
    - it is not possible for the method to be bound to a `self` instance
    - however, the @ti.func annotation runs the kernel_impl.py::func function --- it is at this point
      that GsTaichi's original code creates a class Func instance (that wraps the add2numbers_func),
      and immediately we create a GsTaichiCallable instance that wraps the Func instance.
    - effectively, we have two layers of wrapping: GsTaichiCallable -> Func -> function pointer
      (actual function definition)
    - later on, when we call self.add2numbers_py, here:

          a, add_py, add_func = ti.static(self.a, self.add2numbers_py, self.add2numbers_func)

      ... we want to call the bound method, `self.add2numbers_py`.
    - an actual python function reference, created by doing somevar = MyClass.add2numbers, can automatically
      bind to self when called from self in this way (however, add2numbers_py is actually a class
      Func instance, wrapping a python function reference -- now also all wrapped by a GsTaichiCallable
      instance -- returned by the kernel_impl.py::func function, run by @ti.func)
    - however, in order to be able to add strongly typed attributes to the wrapped python function, we need
      to wrap the wrapped python function in a class
    - the wrapped python function, wrapped in a GsTaichiCallable class (which is callable, and will
      execute the underlying double-wrapped python function), will NOT automatically bind
    - when we invoke GsTaichiCallable, the wrapped function is invoked. The wrapped function is unbound, and
      so `self` is not automatically passed in as an argument, and things break

    To address this we need to use the `__get__` method in our function wrapper, i.e. GsTaichiCallable,
    and have the `__get__` method return the `BoundGsTaichiCallable` object. The `__get__` method handles
    running the binding for us, and effectively binds the `BoundGsTaichiCallable` object to the `self` object,
    by passing in the instance as an argument into `BoundGsTaichiCallable.__init__`.

    `BoundGsTaichiCallable` can then be used as a normal bound func - even though it's just an object instance -
    using its `__call__` method. Effectively, at the time of actually invoking the underlying python
    function, we have 3 layers of wrapper instances:
    BoundGsTaichiCallable -> GsTaichiCallable -> Func -> python function reference/definition
    """

    def __init__(self, fn: Callable, wrapper: Callable) -> None:
        self.fn: Callable = fn
        self.wrapper: Callable = wrapper
        self._is_real_function: bool = False
        self._is_gstaichi_function: bool = False
        self._is_wrapped_kernel: bool = False
        self._is_classkernel: bool = False
        self._primal: Kernel | None = None
        self._adjoint: Kernel | None = None
        self.grad: Kernel | None = None
        self._is_staticmethod: bool = False
        self.is_pure = False
        functools.update_wrapper(self, fn)

    def __call__(self, *args, **kwargs):
        return self.wrapper.__call__(*args, **kwargs)

    def __get__(self, instance, owner):
        if instance is None:
            return self
        return BoundGsTaichiCallable(instance, self)


class BoundGsTaichiCallable:
    def __init__(self, instance: Any, gstaichi_callable: "GsTaichiCallable"):
        self.wrapper = gstaichi_callable.wrapper
        self.instance = instance
        self.gstaichi_callable = gstaichi_callable

    def __call__(self, *args, **kwargs):
        return self.wrapper(self.instance, *args, **kwargs)

    def __getattr__(self, k: str) -> Any:
        res = getattr(self.gstaichi_callable, k)
        return res

    def __setattr__(self, k: str, v: Any) -> None:
        # Note: these have to match the name of any attributes on this class.
        if k in ("wrapper", "instance", "gstaichi_callable"):
            object.__setattr__(self, k, v)
        else:
            setattr(self.gstaichi_callable, k, v)

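# ---------------------------------------------------------------------------
# Illustrative sketch (not part of gstaichi): the docstring above describes why
# GsTaichiCallable implements `__get__`. A minimal, standalone version of the
# same descriptor-based binding pattern, using nothing beyond the Python data
# model, would look like this:
#
#     class Wrapper:
#         def __init__(self, fn):
#             self.fn = fn
#
#         def __call__(self, *args, **kwargs):
#             return self.fn(*args, **kwargs)
#
#         def __get__(self, instance, owner):
#             # Attribute access through an instance returns a bound proxy,
#             # mirroring GsTaichiCallable.__get__ -> BoundGsTaichiCallable.
#             if instance is None:
#                 return self
#             return lambda *args, **kwargs: self.fn(instance, *args, **kwargs)
#
#     class MyClass:
#         @Wrapper
#         def add(self, x, y):
#             return x + y
#
#     assert MyClass().add(2, 3) == 5  # `self` is supplied by __get__, not by the caller
# ---------------------------------------------------------------------------
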
def func(fn: Callable, is_real_function: bool = False) -> GsTaichiCallable:
    """Marks a function as callable in GsTaichi-scope.

    This decorator transforms a Python function into a GsTaichi one. GsTaichi
    will JIT compile it into native instructions.

    Args:
        fn (Callable): The Python function to be decorated
        is_real_function (bool): Whether the function is a real function

    Returns:
        Callable: The decorated function

    Example::

        >>> @ti.func
        >>> def foo(x):
        >>>     return x + 2
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     print(foo(40))  # 42
    """
    is_classfunc = _inside_class(level_of_class_stackframe=3 + is_real_function)

    fun = Func(fn, _classfunc=is_classfunc, is_real_function=is_real_function)
    gstaichi_callable = GsTaichiCallable(fn, fun)
    gstaichi_callable._is_gstaichi_function = True
    gstaichi_callable._is_real_function = is_real_function
    return gstaichi_callable


def real_func(fn: Callable) -> GsTaichiCallable:
    return func(fn, is_real_function=True)


def pyfunc(fn: Callable) -> GsTaichiCallable:
    """Marks a function as callable in both GsTaichi and Python scopes.

    When called inside the GsTaichi scope, GsTaichi will JIT compile it into
    native instructions. Otherwise it will be invoked directly as a
    Python function.

    See also :func:`~gstaichi.lang.kernel_impl.func`.

    Args:
        fn (Callable): The Python function to be decorated

    Returns:
        Callable: The decorated function
    """
    is_classfunc = _inside_class(level_of_class_stackframe=3)
    fun = Func(fn, _classfunc=is_classfunc, _pyfunc=True)
    gstaichi_callable = GsTaichiCallable(fn, fun)
    gstaichi_callable._is_gstaichi_function = True
    gstaichi_callable._is_real_function = False
    return gstaichi_callable

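# ---------------------------------------------------------------------------
# Illustrative sketch (not part of gstaichi): `pyfunc` above produces a function
# callable from both scopes, whereas a plain `func` raises
# "GsTaichi functions cannot be called from Python-scope." when invoked directly.
# Assuming the decorators are exported at the package top level (as `ti.pyfunc`,
# `ti.kernel`) and that `ti.init()` is available, usage would look like:
#
#     import gstaichi as ti
#
#     ti.init()
#
#     @ti.pyfunc
#     def half(x):
#         return x * 0.5
#
#     @ti.kernel
#     def run() -> ti.f32:
#         return half(3.0)  # GsTaichi scope: JIT compiled
#
#     print(half(3.0))  # Python scope: runs as an ordinary Python function
#     print(run())
# ---------------------------------------------------------------------------
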
def _populate_global_vars_for_templates(
    template_slot_locations: list[int],
    argument_metas: list[ArgMetadata],
    global_vars: dict[str, Any],
    fn: Callable,
    py_args: tuple[Any, ...],
):
    """
    Inject template parameters into globals

    Globals are being abused to store the python objects associated
    with templates. We continue this approach, and in addition this function
    handles injecting expanded python variables from dataclasses.
    """
    for i in template_slot_locations:
        template_var_name = argument_metas[i].name
        global_vars[template_var_name] = py_args[i]
    parameters = inspect.signature(fn).parameters
    for i, (parameter_name, parameter) in enumerate(parameters.items()):
        if dataclasses.is_dataclass(parameter.annotation):
            _kernel_impl_dataclass.populate_global_vars_from_dataclass(
                parameter_name,
                parameter.annotation,
                py_args[i],
                global_vars=global_vars,
            )


def _get_tree_and_ctx(
    self: "Func | Kernel",
    args: tuple[Any, ...],
    excluded_parameters=(),
    is_kernel: bool = True,
    arg_features=None,
    ast_builder: "ASTBuilder | None" = None,
    is_real_function: bool = False,
    current_kernel: "Kernel | None" = None,
) -> tuple[ast.Module, ASTTransformerContext]:
    function_source_info, src = get_source_info_and_src(self.func)
    src = [textwrap.fill(line, tabsize=4, width=9999) for line in src]
    tree = ast.parse(textwrap.dedent("\n".join(src)))

    func_body = tree.body[0]
    func_body.decorator_list = []  # type: ignore , kick that can down the road...

    global_vars = _get_global_vars(self.func)

    if is_kernel or is_real_function:
        _populate_global_vars_for_templates(
            template_slot_locations=self.template_slot_locations,
            argument_metas=self.arg_metas,
            global_vars=global_vars,
            fn=self.func,
            py_args=args,
        )

    if current_kernel is not None:  # Kernel
        current_kernel.kernel_function_info = function_source_info
    if current_kernel is None:
        current_kernel = impl.get_runtime()._current_kernel
    assert current_kernel is not None
    current_kernel.visited_functions.add(function_source_info)

    return tree, ASTTransformerContext(
        excluded_parameters=excluded_parameters,
        is_kernel=is_kernel,
        func=self,
        arg_features=arg_features,
        global_vars=global_vars,
        argument_data=args,
        src=src,
        start_lineno=function_source_info.start_lineno,
        end_lineno=function_source_info.end_lineno,
        file=function_source_info.filepath,
        ast_builder=ast_builder,
        is_real_function=is_real_function,
    )


def _process_args(self: "Func | Kernel", is_func: bool, args: tuple[Any, ...], kwargs) -> tuple[Any, ...]:
    if is_func:
        self.arg_metas = _kernel_impl_dataclass.expand_func_arguments(self.arg_metas)

    fused_args: list[Any] = [arg_meta.default for arg_meta in self.arg_metas]
    len_args = len(args)

    if len_args > len(fused_args):
        arg_str = ", ".join(map(str, args))
        expected_str = ", ".join(f"{arg.name} : {arg.annotation}" for arg in self.arg_metas)
        msg = f"Too many arguments. Expected ({expected_str}), got ({arg_str})."
        raise GsTaichiSyntaxError(msg)

    for i, arg in enumerate(args):
        fused_args[i] = arg

    for key, value in kwargs.items():
        for i, arg in enumerate(self.arg_metas):
            if key == arg.name:
                if i < len_args:
                    raise GsTaichiSyntaxError(f"Multiple values for argument '{key}'.")
                fused_args[i] = value
                break
        else:
            raise GsTaichiSyntaxError(f"Unexpected argument '{key}'.")

    for i, arg in enumerate(fused_args):
        if arg is inspect.Parameter.empty:
            if self.arg_metas[i].annotation is inspect._empty:
                raise GsTaichiSyntaxError(f"Parameter `{self.arg_metas[i].name}` missing.")
            else:
                raise GsTaichiSyntaxError(
                    f"Parameter `{self.arg_metas[i].name} : {self.arg_metas[i].annotation}` missing."
                )

    return tuple(fused_args)


class Func:
    function_counter = 0

    def __init__(self, _func: Callable, _classfunc=False, _pyfunc=False, is_real_function=False) -> None:
        self.func = _func
        self.func_id = Func.function_counter
        Func.function_counter += 1
        self.compiled = {}
        self.classfunc = _classfunc
        self.pyfunc = _pyfunc
        self.is_real_function = is_real_function
        self.arg_metas: list[ArgMetadata] = []
        self.orig_arguments: list[ArgMetadata] = []
        self.return_type: tuple[Type, ...] | None = None
        self.extract_arguments()
        self.template_slot_locations: list[int] = []
        for i, arg in enumerate(self.arg_metas):
            if arg.annotation == template or isinstance(arg.annotation, template):
                self.template_slot_locations.append(i)
        self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
        self.gstaichi_functions = {}  # The |Function| class in C++
        self.has_print = False

    def __call__(self: "Func", *args, **kwargs) -> Any:
        args = _process_args(self, is_func=True, args=args, kwargs=kwargs)

        if not impl.inside_kernel():
            if not self.pyfunc:
                raise GsTaichiSyntaxError("GsTaichi functions cannot be called from Python-scope.")
            return self.func(*args)

        current_kernel = impl.get_runtime().current_kernel
        if self.is_real_function:
            if current_kernel.autodiff_mode != AutodiffMode.NONE:
                raise GsTaichiSyntaxError("Real function in gradient kernels unsupported.")
            instance_id, arg_features = self.mapper.lookup(args)
            key = _ti_core.FunctionKey(self.func.__name__, self.func_id, instance_id)
            if key.instance_id not in self.compiled:
                self.do_compile(key=key, args=args, arg_features=arg_features)
            return self.func_call_rvalue(key=key, args=args)
        tree, ctx = _get_tree_and_ctx(
            self,
            is_kernel=False,
            args=args,
            ast_builder=current_kernel.ast_builder(),
            is_real_function=self.is_real_function,
        )

        struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)

        tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
        ret = transform_tree(tree, ctx)
        if not self.is_real_function:
            if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
                raise GsTaichiSyntaxError("Function has a return type but does not have a return statement")
        return ret

    def func_call_rvalue(self, key: FunctionKey, args: tuple[Any, ...]) -> Any:
        # Skip the template args, e.g., |self|
        assert self.is_real_function
        non_template_args = []
        dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        for i, kernel_arg in enumerate(self.arg_metas):
            anno = kernel_arg.annotation
            if not isinstance(anno, template):
                if id(anno) in primitive_types.type_ids:
                    non_template_args.append(ops.cast(args[i], anno))
                elif isinstance(anno, primitive_types.RefType):
                    non_template_args.append(_ti_core.make_reference(args[i].ptr, dbg_info))
                elif isinstance(anno, ndarray_type.NdarrayType):
                    if not isinstance(args[i], AnyArray):
                        raise GsTaichiTypeError(
                            f"Expected ndarray in the kernel argument for argument {kernel_arg.name}, got {args[i]}"
                        )
                    non_template_args += _ti_core.get_external_tensor_real_func_args(args[i].ptr, dbg_info)
                else:
                    non_template_args.append(args[i])
        non_template_args = impl.make_expr_group(non_template_args)
        compiling_callable = impl.get_runtime().compiling_callable
        assert compiling_callable is not None
        func_call = compiling_callable.ast_builder().insert_func_call(
            self.gstaichi_functions[key.instance_id], non_template_args, dbg_info
        )
        if self.return_type is None:
            return None
        func_call = Expr(func_call)
        ret = []

        for i, return_type in enumerate(self.return_type):
            if id(return_type) in primitive_types.type_ids:
                ret.append(
                    Expr(
                        _ti_core.make_get_element_expr(
                            func_call.ptr, (i,), _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
                        )
                    )
                )
            elif isinstance(return_type, (StructType, MatrixType)):
                ret.append(return_type.from_gstaichi_object(func_call, (i,)))
            else:
                raise GsTaichiTypeError(f"Unsupported return type for return value {i}: {return_type}")
        if len(ret) == 1:
            return ret[0]
        return tuple(ret)

    def do_compile(self, key: FunctionKey, args: tuple[Any, ...], arg_features: tuple[Any, ...]) -> None:
        tree, ctx = _get_tree_and_ctx(
            self, is_kernel=False, args=args, arg_features=arg_features, is_real_function=self.is_real_function
        )
        fn = impl.get_runtime().prog.create_function(key)

        def func_body():
            old_callable = impl.get_runtime().compiling_callable
            impl.get_runtime()._compiling_callable = fn
            ctx.ast_builder = fn.ast_builder()
            transform_tree(tree, ctx)
            impl.get_runtime()._compiling_callable = old_callable

        self.gstaichi_functions[key.instance_id] = fn
        self.compiled[key.instance_id] = func_body
        self.gstaichi_functions[key.instance_id].set_function_body(func_body)

    def extract_arguments(self) -> None:
        sig = inspect.signature(self.func)
        if sig.return_annotation not in (inspect.Signature.empty, None):
            self.return_type = sig.return_annotation
            if (
                isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias))  # type: ignore
                and self.return_type.__origin__ is tuple  # type: ignore
            ):
                self.return_type = self.return_type.__args__  # type: ignore
            if self.return_type is None:
                return
            if not isinstance(self.return_type, (list, tuple)):
                self.return_type = (self.return_type,)
            for i, return_type in enumerate(self.return_type):
                if return_type is Ellipsis:
                    raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
        params = sig.parameters
        arg_names = params.keys()
        for i, arg_name in enumerate(arg_names):
            param = params[arg_name]
            if param.kind == inspect.Parameter.VAR_KEYWORD:
                raise GsTaichiSyntaxError(
                    "GsTaichi functions do not support variable keyword parameters (i.e., **kwargs)"
                )
            if param.kind == inspect.Parameter.VAR_POSITIONAL:
                raise GsTaichiSyntaxError(
                    "GsTaichi functions do not support variable positional parameters (i.e., *args)"
                )
            if param.kind == inspect.Parameter.KEYWORD_ONLY:
                raise GsTaichiSyntaxError("GsTaichi functions do not support keyword parameters")
            if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                raise GsTaichiSyntaxError('GsTaichi functions only support "positional or keyword" parameters')
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                if i == 0 and self.classfunc:
                    annotation = template()
                # TODO: pyfunc also need type annotation check when real function is enabled,
                # but that has to happen at runtime when we know which scope it's called from.
                elif not self.pyfunc and self.is_real_function:
                    raise GsTaichiSyntaxError(
                        f"GsTaichi function `{self.func.__name__}` parameter `{arg_name}` must be type annotated"
                    )
            else:
                if isinstance(annotation, ndarray_type.NdarrayType):
                    pass
                elif isinstance(annotation, MatrixType):
                    pass
                elif isinstance(annotation, StructType):
                    pass
                elif id(annotation) in primitive_types.type_ids:
                    pass
                elif type(annotation) == gstaichi.types.annotations.Template:
                    pass
                elif isinstance(annotation, template) or annotation == gstaichi.types.annotations.Template:
                    pass
                elif isinstance(annotation, primitive_types.RefType):
                    pass
                elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
                    pass
                else:
                    raise GsTaichiSyntaxError(
                        f"Invalid type annotation (argument {i}) of GsTaichi function: {annotation}"
                    )
            self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
            self.orig_arguments.append(ArgMetadata(annotation, param.name, param.default))


def _get_global_vars(_func: Callable) -> dict[str, Any]:
    # Discussions: https://github.com/taichi-dev/gstaichi/issues/282
    global_vars = _func.__globals__.copy()

    freevar_names = _func.__code__.co_freevars
    closure = _func.__closure__
    if closure:
        freevar_values = list(map(lambda x: x.cell_contents, closure))
        for name, value in zip(freevar_names, freevar_values):
            global_vars[name] = value

    return global_vars

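# ---------------------------------------------------------------------------
# Illustrative sketch (not part of gstaichi): `_get_global_vars` above merges a
# function's closure cells into the snapshot of its globals. The CPython
# attributes it relies on behave like this:
#
#     def outer():
#         scale = 2.0
#
#         def inner(x):
#             return x * scale
#
#         return inner
#
#     f = outer()
#     print(f.__code__.co_freevars)          # ('scale',)
#     print(f.__closure__[0].cell_contents)  # 2.0
#     # _get_global_vars(f) therefore exposes `scale` to the AST transformer
#     # alongside everything already present in f.__globals__.
# ---------------------------------------------------------------------------
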
@dataclasses.dataclass
class SrcLlCacheObservations:
    cache_key_generated: bool = False
    cache_validated: bool = False
    cache_loaded: bool = False
    cache_stored: bool = False

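# ---------------------------------------------------------------------------
# Illustrative sketch (not part of gstaichi): Kernel.extract_arguments below only
# accepts annotated, positional-or-keyword parameters with no default values.
# Assuming the usual top-level exports (`ti.kernel`, `ti.types.ndarray`, `ti.f32`,
# `ti.template`), a kernel signature that passes those checks would look like:
#
#     import gstaichi as ti
#
#     @ti.kernel
#     def scale(values: ti.types.ndarray(), factor: ti.f32, n: ti.template()):
#         for i in range(n):
#             values[i] = values[i] * factor
#
#     # Rejected by the checks below:
#     #   def k(x): ...               -> parameters must be type annotated
#     #   def k(x: ti.f32 = 1.0): ... -> default values are not supported
#     #   def k(*args): ...           -> *args / **kwargs are not supported
# ---------------------------------------------------------------------------
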
|
576
|
+
class Kernel:
|
577
|
+
counter = 0
|
578
|
+
|
579
|
+
def __init__(self, _func: Callable, autodiff_mode: AutodiffMode, _classkernel=False) -> None:
|
580
|
+
self.func = _func
|
581
|
+
self.kernel_counter = Kernel.counter
|
582
|
+
Kernel.counter += 1
|
583
|
+
assert autodiff_mode in (
|
584
|
+
AutodiffMode.NONE,
|
585
|
+
AutodiffMode.VALIDATION,
|
586
|
+
AutodiffMode.FORWARD,
|
587
|
+
AutodiffMode.REVERSE,
|
588
|
+
)
|
589
|
+
self.autodiff_mode = autodiff_mode
|
590
|
+
self.grad: "Kernel | None" = None
|
591
|
+
self.arg_metas: list[ArgMetadata] = []
|
592
|
+
self.return_type = None
|
593
|
+
self.classkernel = _classkernel
|
594
|
+
self.extract_arguments()
|
595
|
+
self.template_slot_locations = []
|
596
|
+
for i, arg in enumerate(self.arg_metas):
|
597
|
+
if arg.annotation == template or isinstance(arg.annotation, template):
|
598
|
+
self.template_slot_locations.append(i)
|
599
|
+
self.mapper = TemplateMapper(self.arg_metas, self.template_slot_locations)
|
600
|
+
impl.get_runtime().kernels.append(self)
|
601
|
+
self.reset()
|
602
|
+
self.kernel_cpp = None
|
603
|
+
# A materialized kernel is a KernelCxx object which may or may not have
|
604
|
+
# been compiled. It generally has been converted at least as far as AST
|
605
|
+
# and front-end IR, but not necessarily any further.
|
606
|
+
self.materialized_kernels: dict[CompiledKernelKeyType, KernelCxx] = {}
|
607
|
+
self.has_print = False
|
608
|
+
self.gstaichi_callable: GsTaichiCallable | None = None
|
609
|
+
self.visited_functions: set[FunctionSourceInfo] = set()
|
610
|
+
self.kernel_function_info: FunctionSourceInfo | None = None
|
611
|
+
self.compiled_kernel_data_by_key: dict[CompiledKernelKeyType, CompiledKernelData] = {}
|
612
|
+
|
613
|
+
self.src_ll_cache_observations: SrcLlCacheObservations = SrcLlCacheObservations()
|
614
|
+
|
615
|
+
def ast_builder(self) -> ASTBuilder:
|
616
|
+
assert self.kernel_cpp is not None
|
617
|
+
return self.kernel_cpp.ast_builder()
|
618
|
+
|
619
|
+
def reset(self) -> None:
|
620
|
+
self.runtime = impl.get_runtime()
|
621
|
+
self.materialized_kernels = {}
|
622
|
+
|
623
|
+
def extract_arguments(self) -> None:
|
624
|
+
sig = inspect.signature(self.func)
|
625
|
+
if sig.return_annotation not in (inspect._empty, None):
|
626
|
+
self.return_type = sig.return_annotation
|
627
|
+
if (
|
628
|
+
isinstance(self.return_type, (types.GenericAlias, typing._GenericAlias)) # type: ignore
|
629
|
+
and self.return_type.__origin__ is tuple
|
630
|
+
):
|
631
|
+
self.return_type = self.return_type.__args__
|
632
|
+
if not isinstance(self.return_type, (list, tuple)):
|
633
|
+
self.return_type = (self.return_type,)
|
634
|
+
for return_type in self.return_type:
|
635
|
+
if return_type is Ellipsis:
|
636
|
+
raise GsTaichiSyntaxError("Ellipsis is not supported in return type annotations")
|
637
|
+
params = dict(sig.parameters)
|
638
|
+
arg_names = params.keys()
|
639
|
+
for i, arg_name in enumerate(arg_names):
|
640
|
+
param = params[arg_name]
|
641
|
+
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
642
|
+
raise GsTaichiSyntaxError(
|
643
|
+
"GsTaichi kernels do not support variable keyword parameters (i.e., **kwargs)"
|
644
|
+
)
|
645
|
+
if param.kind == inspect.Parameter.VAR_POSITIONAL:
|
646
|
+
raise GsTaichiSyntaxError(
|
647
|
+
"GsTaichi kernels do not support variable positional parameters (i.e., *args)"
|
648
|
+
)
|
649
|
+
if param.default is not inspect.Parameter.empty:
|
650
|
+
raise GsTaichiSyntaxError("GsTaichi kernels do not support default values for arguments")
|
651
|
+
if param.kind == inspect.Parameter.KEYWORD_ONLY:
|
652
|
+
raise GsTaichiSyntaxError("GsTaichi kernels do not support keyword parameters")
|
653
|
+
if param.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
|
654
|
+
raise GsTaichiSyntaxError('GsTaichi kernels only support "positional or keyword" parameters')
|
655
|
+
annotation = param.annotation
|
656
|
+
if param.annotation is inspect.Parameter.empty:
|
657
|
+
if i == 0 and self.classkernel: # The |self| parameter
|
658
|
+
annotation = template()
|
659
|
+
else:
|
660
|
+
raise GsTaichiSyntaxError("GsTaichi kernels parameters must be type annotated")
|
661
|
+
else:
|
662
|
+
if isinstance(
|
663
|
+
annotation,
|
664
|
+
(
|
665
|
+
template,
|
666
|
+
ndarray_type.NdarrayType,
|
667
|
+
texture_type.TextureType,
|
668
|
+
texture_type.RWTextureType,
|
669
|
+
),
|
670
|
+
):
|
671
|
+
pass
|
672
|
+
elif annotation is ndarray_type.NdarrayType:
|
673
|
+
# convert from ti.types.NDArray into ti.types.NDArray()
|
674
|
+
annotation = annotation()
|
675
|
+
elif id(annotation) in primitive_types.type_ids:
|
676
|
+
pass
|
677
|
+
elif isinstance(annotation, sparse_matrix_builder):
|
678
|
+
pass
|
679
|
+
elif isinstance(annotation, MatrixType):
|
680
|
+
pass
|
681
|
+
elif isinstance(annotation, StructType):
|
682
|
+
pass
|
683
|
+
elif annotation == template:
|
684
|
+
pass
|
685
|
+
elif isinstance(annotation, type) and dataclasses.is_dataclass(annotation):
|
686
|
+
pass
|
687
|
+
else:
|
688
|
+
raise GsTaichiSyntaxError(f"Invalid type annotation (argument {i}) of Taichi kernel: {annotation}")
|
689
|
+
self.arg_metas.append(ArgMetadata(annotation, param.name, param.default))
|
690
|
+
|
691
|
+
def materialize(self, key: CompiledKernelKeyType | None, args: tuple[Any, ...], arg_features=None):
|
692
|
+
if key is None:
|
693
|
+
key = (self.func, 0, self.autodiff_mode)
|
694
|
+
self.runtime.materialize()
|
695
|
+
self.fast_checksum = None
|
696
|
+
|
697
|
+
if key in self.materialized_kernels:
|
698
|
+
return
|
699
|
+
|
700
|
+
if self.runtime.src_ll_cache and self.gstaichi_callable and self.gstaichi_callable.is_pure:
|
701
|
+
kernel_source_info, _src = get_source_info_and_src(self.func)
|
702
|
+
self.fast_checksum = src_hasher.create_cache_key(kernel_source_info, args)
|
703
|
+
if self.fast_checksum:
|
704
|
+
self.src_ll_cache_observations.cache_key_generated = True
|
705
|
+
if self.fast_checksum and src_hasher.validate_cache_key(self.fast_checksum):
|
706
|
+
self.src_ll_cache_observations.cache_validated = True
|
707
|
+
prog = impl.get_runtime().prog
|
708
|
+
self.compiled_kernel_data_by_key[key] = prog.load_fast_cache(
|
709
|
+
self.fast_checksum,
|
710
|
+
self.func.__name__,
|
711
|
+
prog.config(),
|
712
|
+
prog.get_device_caps(),
|
713
|
+
)
|
714
|
+
if self.compiled_kernel_data_by_key[key]:
|
715
|
+
self.src_ll_cache_observations.cache_loaded = True
|
716
|
+
|
717
|
+
kernel_name = f"{self.func.__name__}_c{self.kernel_counter}_{key[1]}"
|
718
|
+
_logging.trace(f"Materializing kernel {kernel_name} in {self.autodiff_mode}...")
|
719
|
+
|
720
|
+
tree, ctx = _get_tree_and_ctx(
|
721
|
+
self,
|
722
|
+
args=args,
|
723
|
+
excluded_parameters=self.template_slot_locations,
|
724
|
+
arg_features=arg_features,
|
725
|
+
current_kernel=self,
|
726
|
+
)
|
727
|
+
|
728
|
+
if self.autodiff_mode != AutodiffMode.NONE:
|
729
|
+
KernelSimplicityASTChecker(self.func).visit(tree)
|
730
|
+
|
731
|
+
# Do not change the name of 'gstaichi_ast_generator'
|
732
|
+
# The warning system needs this identifier to remove unnecessary messages
|
733
|
+
def gstaichi_ast_generator(kernel_cxx: KernelCxx):
|
734
|
+
nonlocal tree
|
735
|
+
if self.runtime.inside_kernel:
|
736
|
+
raise GsTaichiSyntaxError(
|
737
|
+
"Kernels cannot call other kernels. I.e., nested kernels are not allowed. "
|
738
|
+
"Please check if you have direct/indirect invocation of kernels within kernels. "
|
739
|
+
"Note that some methods provided by the GsTaichi standard library may invoke kernels, "
|
740
|
+
"and please move their invocations to Python-scope."
|
741
|
+
)
|
742
|
+
self.kernel_cpp = kernel_cxx
|
743
|
+
self.runtime.inside_kernel = True
|
744
|
+
self.runtime._current_kernel = self
|
745
|
+
assert self.runtime._compiling_callable is None
|
746
|
+
self.runtime._compiling_callable = kernel_cxx
|
747
|
+
try:
|
748
|
+
ctx.ast_builder = kernel_cxx.ast_builder()
|
749
|
+
|
750
|
+
def ast_to_dict(node: ast.AST | list | primitive_types._python_primitive_types):
|
751
|
+
if isinstance(node, ast.AST):
|
752
|
+
fields = {k: ast_to_dict(v) for k, v in ast.iter_fields(node)}
|
753
|
+
return {
|
754
|
+
"type": node.__class__.__name__,
|
755
|
+
"fields": fields,
|
756
|
+
"lineno": getattr(node, "lineno", None),
|
757
|
+
"col_offset": getattr(node, "col_offset", None),
|
758
|
+
}
|
759
|
+
if isinstance(node, list):
|
760
|
+
return [ast_to_dict(x) for x in node]
|
761
|
+
return node # Basic types (str, int, None, etc.)
|
762
|
+
|
763
|
+
if os.environ.get("TI_DUMP_AST", "") == "1":
|
764
|
+
target_dir = pathlib.Path("/tmp/ast")
|
765
|
+
target_dir.mkdir(parents=True, exist_ok=True)
|
766
|
+
|
767
|
+
start = time.time()
|
768
|
+
ast_str = ast.dump(tree, indent=2)
|
769
|
+
output_file = target_dir / f"{kernel_name}_ast.txt"
|
770
|
+
output_file.write_text(ast_str)
|
771
|
+
elapsed_txt = time.time() - start
|
772
|
+
|
773
|
+
start = time.time()
|
774
|
+
json_str = json.dumps(ast_to_dict(tree), indent=2)
|
775
|
+
output_file = target_dir / f"{kernel_name}_ast.json"
|
776
|
+
output_file.write_text(json_str)
|
777
|
+
elapsed_json = time.time() - start
|
778
|
+
|
779
|
+
output_file = target_dir / f"{kernel_name}_gen_time.json"
|
780
|
+
output_file.write_text(
|
781
|
+
json.dumps({"elapsed_txt": elapsed_txt, "elapsed_json": elapsed_json}, indent=2)
|
782
|
+
)
|
783
|
+
struct_locals = _kernel_impl_dataclass.extract_struct_locals_from_context(ctx)
|
784
|
+
tree = _kernel_impl_dataclass.unpack_ast_struct_expressions(tree, struct_locals=struct_locals)
|
785
|
+
ctx.only_parse_function_def = self.compiled_kernel_data_by_key.get(key) is not None
|
786
|
+
transform_tree(tree, ctx)
|
787
|
+
if not ctx.is_real_function:
|
788
|
+
if self.return_type and ctx.returned != ReturnStatus.ReturnedValue:
|
789
|
+
raise GsTaichiSyntaxError("Kernel has a return type but does not have a return statement")
|
790
|
+
finally:
|
791
|
+
self.runtime.inside_kernel = False
|
792
|
+
self.runtime._current_kernel = None
|
793
|
+
self.runtime._compiling_callable = None
|
794
|
+
|
795
|
+
gstaichi_kernel = impl.get_runtime().prog.create_kernel(gstaichi_ast_generator, kernel_name, self.autodiff_mode)
|
796
|
+
assert key not in self.materialized_kernels
|
797
|
+
self.materialized_kernels[key] = gstaichi_kernel
|
798
|
+
|
799
|
+
def launch_kernel(self, t_kernel: KernelCxx, compiled_kernel_data: CompiledKernelData | None, *args) -> Any:
|
800
|
+
assert len(args) == len(self.arg_metas), f"{len(self.arg_metas)} arguments needed but {len(args)} provided"
|
801
|
+
|
802
|
+
tmps = []
|
803
|
+
callbacks = []
|
804
|
+
|
805
|
+
actual_argument_slot = 0
|
806
|
+
launch_ctx = t_kernel.make_launch_context()
|
807
|
+
max_arg_num = 512
|
808
|
+
exceed_max_arg_num = False
|
809
|
+
|
810
|
+
def set_arg_ndarray(indices: tuple[int, ...], v: gstaichi.lang._ndarray.Ndarray) -> None:
|
811
|
+
v_primal = v.arr
|
812
|
+
v_grad = v.grad.arr if v.grad else None
|
813
|
+
if v_grad is None:
|
814
|
+
launch_ctx.set_arg_ndarray(indices, v_primal) # type: ignore , solvable probably, just not today
|
815
|
+
else:
|
816
|
+
launch_ctx.set_arg_ndarray_with_grad(indices, v_primal, v_grad) # type: ignore
|
817
|
+
|
818
|
+
def set_arg_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
|
819
|
+
launch_ctx.set_arg_texture(indices, v.tex)
|
820
|
+
|
821
|
+
def set_arg_rw_texture(indices: tuple[int, ...], v: gstaichi.lang._texture.Texture) -> None:
|
822
|
+
launch_ctx.set_arg_rw_texture(indices, v.tex)
|
823
|
+
|
824
|
+
def set_arg_ext_array(indices: tuple[int, ...], v: Any, needed: ndarray_type.NdarrayType) -> None:
|
825
|
+
# v is things like torch Tensor and numpy array
|
826
|
+
# Not adding type for this, since adds additional dependencies
|
827
|
+
#
|
828
|
+
# Element shapes are already specialized in GsTaichi codegen.
|
829
|
+
# The shape information for element dims are no longer needed.
|
830
|
+
# Therefore we strip the element shapes from the shape vector,
|
831
|
+
# so that it only holds "real" array shapes.
|
832
|
+
is_soa = needed.layout == Layout.SOA
|
833
|
+
array_shape = v.shape
|
834
|
+
if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
|
835
|
+
warnings.warn("Ndarray index might be out of int32 boundary but int64 indexing is not supported yet.")
|
836
|
+
if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
|
837
|
+
element_dim = 0
|
838
|
+
else:
|
839
|
+
element_dim = needed.dtype.ndim
|
840
|
+
array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
|
841
|
+
if isinstance(v, np.ndarray):
|
842
|
+
# numpy
|
843
|
+
if v.flags.c_contiguous:
|
844
|
+
launch_ctx.set_arg_external_array_with_shape(indices, int(v.ctypes.data), v.nbytes, array_shape, 0)
|
845
|
+
elif v.flags.f_contiguous:
|
846
|
+
# TODO: A better way that avoids copying is saving strides info.
|
847
|
+
tmp = np.ascontiguousarray(v)
|
848
|
+
# Purpose: DO NOT GC |tmp|!
|
849
|
+
tmps.append(tmp)
|
850
|
+
|
851
|
+
def callback(original, updated):
|
852
|
+
np.copyto(original, np.asfortranarray(updated))
|
853
|
+
|
854
|
+
callbacks.append(functools.partial(callback, v, tmp))
|
855
|
+
launch_ctx.set_arg_external_array_with_shape(
|
856
|
+
indices, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
|
857
|
+
)
|
858
|
+
else:
|
859
|
+
raise ValueError(
|
860
|
+
"Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
|
861
|
+
"before passing it into gstaichi kernel."
|
862
|
+
)
|
863
|
+
elif has_pytorch():
|
864
|
+
import torch # pylint: disable=C0415
|
865
|
+
|
866
|
+
if isinstance(v, torch.Tensor):
|
867
|
+
if not v.is_contiguous():
|
868
|
+
raise ValueError(
|
869
|
+
"Non contiguous tensors are not supported, please call tensor.contiguous() before "
|
870
|
+
"passing it into gstaichi kernel."
|
871
|
+
)
|
872
|
+
gstaichi_arch = self.runtime.prog.config().arch
|
873
|
+
|
874
|
+
def get_call_back(u, v):
|
875
|
+
def call_back():
|
876
|
+
u.copy_(v)
|
877
|
+
|
878
|
+
return call_back
|
879
|
+
|
880
|
+
# FIXME: only allocate when launching grad kernel
|
881
|
+
if v.requires_grad and v.grad is None:
|
882
|
+
v.grad = torch.zeros_like(v)
|
883
|
+
|
884
|
+
if v.requires_grad:
|
885
|
+
if not isinstance(v.grad, torch.Tensor):
|
886
|
+
raise ValueError(
|
887
|
+
f"Expecting torch.Tensor for gradient tensor, but getting {v.grad.__class__.__name__} instead"
|
888
|
+
)
|
889
|
+
if not v.grad.is_contiguous():
|
890
|
+
raise ValueError(
|
891
|
+
"Non contiguous gradient tensors are not supported, please call tensor.grad.contiguous() before passing it into gstaichi kernel."
|
892
|
+
)
|
893
|
+
|
894
|
+
tmp = v
|
895
|
+
if (str(v.device) != "cpu") and not (
|
896
|
+
str(v.device).startswith("cuda") and gstaichi_arch == _ti_core.Arch.cuda
|
897
|
+
):
|
898
|
+
# Getting a torch CUDA tensor on GsTaichi non-cuda arch:
|
899
|
+
# We just replace it with a CPU tensor and by the end of kernel execution we'll use the
|
900
|
+
# callback to copy the values back to the original CUDA tensor.
|
901
|
+
host_v = v.to(device="cpu", copy=True)
|
902
|
+
tmp = host_v
|
903
|
+
callbacks.append(get_call_back(v, host_v))
|
904
|
+
|
905
|
+
launch_ctx.set_arg_external_array_with_shape(
|
906
|
+
indices,
|
907
|
+
int(tmp.data_ptr()),
|
908
|
+
tmp.element_size() * tmp.nelement(),
|
909
|
+
array_shape,
|
910
|
+
int(v.grad.data_ptr()) if v.grad is not None else 0,
|
911
|
+
)
|
912
|
+
else:
|
913
|
+
raise GsTaichiRuntimeTypeError(
|
914
|
+
f"Argument of type {type(v)} cannot be converted into required type {needed}"
|
915
|
+
)
|
916
|
+
else:
|
917
|
+
raise GsTaichiRuntimeTypeError(f"Argument {needed} cannot be converted into required type {v}")
|
918
|
+
|
919
|
+
def set_arg_matrix(indices: tuple[int, ...], v, needed) -> None:
|
920
|
+
def cast_float(x: float | np.floating | np.integer | int) -> float:
|
921
|
+
if not isinstance(x, (int, float, np.integer, np.floating)):
|
922
|
+
raise GsTaichiRuntimeTypeError(
|
923
|
+
f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
|
924
|
+
)
|
925
|
+
return float(x)
|
926
|
+
|
927
|
+
def cast_int(x: int | np.integer) -> int:
|
928
|
+
if not isinstance(x, (int, np.integer)):
|
929
|
+
raise GsTaichiRuntimeTypeError(
|
930
|
+
f"Argument {needed.dtype} cannot be converted into required type {type(x)}"
|
931
|
+
)
|
932
|
+
return int(x)
|
933
|
+
|
934
|
+
cast_func = None
|
935
|
+
if needed.dtype in primitive_types.real_types:
|
936
|
+
cast_func = cast_float
|
937
|
+
elif needed.dtype in primitive_types.integer_types:
|
938
|
+
cast_func = cast_int
|
939
|
+
else:
|
940
|
+
raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")
|
941
|
+
|
942
|
+
if needed.ndim == 2:
|
943
|
+
v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
|
944
|
+
else:
|
945
|
+
v = [cast_func(v[i]) for i in range(needed.n)]
|
946
|
+
v = needed(*v)
|
947
|
+
needed.set_kernel_struct_args(v, launch_ctx, indices)
|
948
|
+
|
949
|
+
def set_arg_sparse_matrix_builder(indices: tuple[int, ...], v) -> None:
|
950
|
+
# Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
|
951
|
+
            launch_ctx.set_arg_uint(indices, v._get_ndarray_addr())

        set_later_list = []

        def recursive_set_args(needed_arg_type: Type, provided_arg_type: Type, v: Any, indices: tuple[int, ...]) -> int:
            """
            Returns the number of kernel args set.
            e.g. templates don't set kernel args, so returns 0;
            a single ndarray is 1 kernel arg, so returns 1;
            a struct of 3 ndarrays sets 3 kernel args, so returns 3.
            Note: len(indices) > 1 only happens with argpack (which we are removing support for).
            """
            nonlocal actual_argument_slot, exceed_max_arg_num, set_later_list
            if actual_argument_slot >= max_arg_num:
                exceed_max_arg_num = True
                return 0
            actual_argument_slot += 1
            # Note: do not use something like "needed == f32". That would be slow.
            if id(needed_arg_type) in primitive_types.real_type_ids:
                if not isinstance(v, (float, int, np.floating, np.integer)):
                    raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                launch_ctx.set_arg_float(indices, float(v))
                return 1
            if id(needed_arg_type) in primitive_types.integer_type_ids:
                if not isinstance(v, (int, np.integer)):
                    raise GsTaichiRuntimeTypeError.get(indices, needed_arg_type.to_string(), provided_arg_type)
                if is_signed(cook_dtype(needed_arg_type)):
                    launch_ctx.set_arg_int(indices, int(v))
                else:
                    launch_ctx.set_arg_uint(indices, int(v))
                return 1
            if isinstance(needed_arg_type, sparse_matrix_builder):
                set_arg_sparse_matrix_builder(indices, v)
                return 1
            if dataclasses.is_dataclass(needed_arg_type):
                assert provided_arg_type == needed_arg_type
                idx = 0
                for j, field in enumerate(dataclasses.fields(needed_arg_type)):
                    assert not isinstance(field.type, str)
                    field_value = getattr(v, field.name)
                    idx += recursive_set_args(field.type, field.type, field_value, (indices[0] + idx,))
                return idx
            if isinstance(needed_arg_type, ndarray_type.NdarrayType) and isinstance(v, gstaichi.lang._ndarray.Ndarray):
                set_arg_ndarray(indices, v)
                return 1
            if isinstance(needed_arg_type, texture_type.TextureType) and isinstance(v, gstaichi.lang._texture.Texture):
                set_arg_texture(indices, v)
                return 1
            if isinstance(needed_arg_type, texture_type.RWTextureType) and isinstance(
                v, gstaichi.lang._texture.Texture
            ):
                set_arg_rw_texture(indices, v)
                return 1
            if isinstance(needed_arg_type, ndarray_type.NdarrayType):
                set_arg_ext_array(indices, v, needed_arg_type)
                return 1
            if isinstance(needed_arg_type, MatrixType):
                set_arg_matrix(indices, v, needed_arg_type)
                return 1
            if isinstance(needed_arg_type, StructType):
                # Unclear how to make the following pass typing checks:
                # StructType implements __instancecheck__, which should be a classmethod,
                # but is currently an instance method.
                # TODO: look into this more deeply at some point
                if not isinstance(v, needed_arg_type):  # type: ignore
                    raise GsTaichiRuntimeTypeError(
                        f"Argument {provided_arg_type} cannot be converted into required type {needed_arg_type}"
                    )
                needed_arg_type.set_kernel_struct_args(v, launch_ctx, indices)
                return 1
            if needed_arg_type == template or isinstance(needed_arg_type, template):
                return 0
            raise ValueError(f"Argument type mismatch. Expecting {needed_arg_type}, got {type(v)}.")

        template_num = 0
        i_out = 0
        for i_in, val in enumerate(args):
            needed_ = self.arg_metas[i_in].annotation
            if needed_ == template or isinstance(needed_, template):
                template_num += 1
                i_out += 1
                continue
            i_out += recursive_set_args(needed_, type(val), val, (i_out - template_num,))

        for i, (set_arg_func, params) in enumerate(set_later_list):
            set_arg_func((len(args) - template_num + i,), *params)

        if exceed_max_arg_num:
            raise GsTaichiRuntimeError(
                f"The number of elements in kernel arguments is too big! Do not exceed {max_arg_num} on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
            )

        try:
            prog = impl.get_runtime().prog
            if not compiled_kernel_data:
                compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
                if self.fast_checksum:
                    src_hasher.store(self.fast_checksum, self.visited_functions)
                    prog.store_fast_cache(
                        self.fast_checksum,
                        self.kernel_cpp,
                        prog.config(),
                        prog.get_device_caps(),
                        compiled_kernel_data,
                    )
                    self.src_ll_cache_observations.cache_stored = True
            prog.launch_kernel(compiled_kernel_data, launch_ctx)
        except Exception as e:
            e = handle_exception_from_cpp(e)
            if impl.get_runtime().print_full_traceback:
                raise e
            raise e from None

        ret = None
        ret_dt = self.return_type
        has_ret = ret_dt is not None

        if has_ret or self.has_print:
            runtime_ops.sync()

        if has_ret:
            ret = []
            for i, ret_type in enumerate(ret_dt):
                ret.append(self.construct_kernel_ret(launch_ctx, ret_type, (i,)))
            if len(ret_dt) == 1:
                ret = ret[0]
        if callbacks:
            for c in callbacks:
                c()

        return ret
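The dataclass branch of recursive_set_args above flattens a Python dataclass into consecutive kernel-argument slots, one per leaf field, and returns how many slots it consumed so that sibling fields land at the right indices. A minimal, self-contained sketch of that index bookkeeping follows; the Particles dataclass and the flatten helper are illustrative stand-ins, not part of gstaichi:

import dataclasses

@dataclasses.dataclass
class Particles:
    count: int
    dt: float
    mass: float

def flatten(value: object, base_index: int = 0) -> list[tuple[int, object]]:
    """Return (slot_index, leaf_value) pairs in field order, one per kernel-arg slot."""
    pairs: list[tuple[int, object]] = []
    idx = base_index
    for field in dataclasses.fields(value):
        leaf = getattr(value, field.name)
        if dataclasses.is_dataclass(leaf):
            # A nested dataclass occupies as many slots as it has leaves.
            nested = flatten(leaf, idx)
            pairs.extend(nested)
            idx += len(nested)
        else:
            pairs.append((idx, leaf))
            idx += 1
    return pairs

print(flatten(Particles(count=128, dt=1e-3, mass=2.0)))
# [(0, 128), (1, 0.001), (2, 2.0)]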

    def construct_kernel_ret(self, launch_ctx: KernelLaunchContext, ret_type: Any, index: tuple[int, ...] = ()):
        if isinstance(ret_type, CompoundType):
            return ret_type.from_kernel_struct_ret(launch_ctx, index)
        if ret_type in primitive_types.integer_types:
            if is_signed(cook_dtype(ret_type)):
                return launch_ctx.get_struct_ret_int(index)
            return launch_ctx.get_struct_ret_uint(index)
        if ret_type in primitive_types.real_types:
            return launch_ctx.get_struct_ret_float(index)
        raise GsTaichiRuntimeTypeError(f"Invalid return type on index={index}")

    def ensure_compiled(self, *args: tuple[Any, ...]) -> tuple[Callable, int, AutodiffMode]:
        instance_id, arg_features = self.mapper.lookup(args)
        key = (self.func, instance_id, self.autodiff_mode)
        self.materialize(key=key, args=args, arg_features=arg_features)
        return key

    # For small kernels (< 3us), the performance can be pretty sensitive to overhead in __call__
    # Thus this part needs to be fast. (i.e. < 3us on a 4 GHz x64 CPU)
    @_shell_pop_print
    def __call__(self, *args, **kwargs) -> Any:
        args = _process_args(self, is_func=False, args=args, kwargs=kwargs)

        # Transform the primal kernel to forward mode grad kernel
        # then recover to primal when exiting the forward mode manager
        if self.runtime.fwd_mode_manager and not self.runtime.grad_replaced:
            # TODO: if we would like to compute 2nd-order derivatives by forward-on-reverse in a nested context manager fashion,
            # i.e., a `Tape` nested in the `FwdMode`, we can transform the kernels with `mode_original == AutodiffMode.REVERSE` only,
            # to avoid duplicate computation for 1st-order derivatives
            self.runtime.fwd_mode_manager.insert(self)

        # Both the class kernels and the plain-function kernels are unified now.
        # In both cases, |self.grad| is another Kernel instance that computes the
        # gradient. For class kernels, args[0] is always the kernel owner.

        # No need to capture grad kernels because they are already bound with their primal kernels
        if (
            self.autodiff_mode in (AutodiffMode.NONE, AutodiffMode.VALIDATION)
            and self.runtime.target_tape
            and not self.runtime.grad_replaced
        ):
            self.runtime.target_tape.insert(self, args)

        if self.autodiff_mode != AutodiffMode.NONE and impl.current_cfg().opt_level == 0:
            _logging.warn("""opt_level = 1 is enforced to enable gradient computation.""")
            impl.current_cfg().opt_level = 1
        key = self.ensure_compiled(*args)
        kernel_cpp = self.materialized_kernels[key]
        compiled_kernel_data = self.compiled_kernel_data_by_key.get(key, None)
        return self.launch_kernel(kernel_cpp, compiled_kernel_data, *args)
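Taken together, ensure_compiled and __call__ implement a compile-once-per-signature scheme: the key (func, instance_id, autodiff_mode) identifies a materialized kernel, and any compiled_kernel_data already cached for that key is reused on launch. A stripped-down sketch of the same pattern, with illustrative names rather than the gstaichi API:

from typing import Any, Callable

class KernelCache:
    """Compile once per (func, variant, mode) key; later calls reuse the artifact."""

    def __init__(self) -> None:
        self._compiled: dict[tuple[Callable, int, str], Any] = {}

    def ensure_compiled(self, func: Callable, variant: int, mode: str) -> tuple[Callable, int, str]:
        key = (func, variant, mode)
        if key not in self._compiled:
            # Stand-in for the real backend compilation step; we just record a marker.
            self._compiled[key] = f"compiled:{func.__name__}/{variant}/{mode}"
        return key

    def launch(self, func: Callable, variant: int = 0, mode: str = "NONE") -> Any:
        key = self.ensure_compiled(func, variant, mode)
        return self._compiled[key]  # a real launch would hand this to the runtime

cache = KernelCache()
def my_kernel() -> None: ...
assert cache.launch(my_kernel) == cache.launch(my_kernel)  # second call hits the cache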

# For a GsTaichi class definition like below:
#
#   @ti.data_oriented
#   class X:
#       @ti.kernel
#       def foo(self):
#           ...
#
# When ti.kernel runs, the stackframe's |code_context| of Python 3.8(+) is
# different from that of Python 3.7 and below. In 3.8+, it is 'class X:',
# whereas in <=3.7, it is '@ti.data_oriented'. More interestingly, if the class
# inherits, i.e. class X(object):, then in both versions, |code_context| is
# 'class X(object):'...
_KERNEL_CLASS_STACKFRAME_STMT_RES = [
    re.compile(r"@(\w+\.)?data_oriented"),
    re.compile(r"class "),
]


def _inside_class(level_of_class_stackframe: int) -> bool:
    try:
        maybe_class_frame = sys._getframe(level_of_class_stackframe)
        statement_list = inspect.getframeinfo(maybe_class_frame)[3]
        if statement_list is None:
            return False
        first_statement = statement_list[0].strip()
        for pat in _KERNEL_CLASS_STACKFRAME_STMT_RES:
            if pat.match(first_statement):
                return True
    except:
        pass
    return False
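_inside_class relies on the calling frame's |code_context| to decide whether a kernel is being defined inside a class body. The snippet below demonstrates that mechanism in isolation; where_defined and _caller_matches_class_stmt are illustrative names, and the printed result assumes the script is run from a file so that inspect can read its source:

import inspect
import re
import sys

_CLASS_STMT_RES = [re.compile(r"@(\w+\.)?data_oriented"), re.compile(r"class ")]

def _caller_matches_class_stmt(level: int) -> bool:
    try:
        frame = sys._getframe(level)
        statement_list = inspect.getframeinfo(frame)[3]
        if not statement_list:
            return False
        return any(pat.match(statement_list[0].strip()) for pat in _CLASS_STMT_RES)
    except Exception:
        return False

def where_defined(fn):
    # Level 3 skips this helper, where_defined, and the frame running the decoration
    # (the class body or the module), landing on the frame whose current statement
    # encloses it: the `class ...:` line when decorating inside a class body.
    fn.defined_in_class = _caller_matches_class_stmt(3)
    return fn

class Demo:
    @where_defined
    def method(self):
        pass

@where_defined
def free_function():
    pass

print(Demo.method.defined_in_class, free_function.defined_in_class)  # typically: True False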

def _kernel_impl(_func: Callable, level_of_class_stackframe: int, verbose: bool = False) -> GsTaichiCallable:
    # Can decorators determine if a function is being defined inside a class?
    # https://stackoverflow.com/a/8793684/12003165
    is_classkernel = _inside_class(level_of_class_stackframe + 1)

    if verbose:
        print(f"kernel={_func.__name__} is_classkernel={is_classkernel}")
    primal = Kernel(_func, autodiff_mode=AutodiffMode.NONE, _classkernel=is_classkernel)
    adjoint = Kernel(_func, autodiff_mode=AutodiffMode.REVERSE, _classkernel=is_classkernel)
    # Having |primal| contain |grad| makes the tape work.
    primal.grad = adjoint

    wrapped: GsTaichiCallable
    if is_classkernel:
        # For class kernels, their primal/adjoint callables are constructed
        # when the kernel is accessed via the instance inside
        # _BoundedDifferentiableMethod.
        # This is because we need to bind the kernel or |grad| to the instance
        # owning the kernel, which is not known until the kernel is accessed.
        #
        # See also: _BoundedDifferentiableMethod, data_oriented.
        @functools.wraps(_func)
        def wrapped_classkernel(*args, **kwargs):
            # If we reach here (we never should), it means the class is not decorated
            # with @ti.data_oriented; otherwise getattr would have intercepted the call.
            clsobj = type(args[0])
            assert not hasattr(clsobj, "_data_oriented")
            raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")

        wrapped = GsTaichiCallable(_func, wrapped_classkernel)
    else:

        @functools.wraps(_func)
        def wrapped_func(*args, **kwargs):
            try:
                return primal(*args, **kwargs)
            except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
                if impl.get_runtime().print_full_traceback:
                    raise e
                raise type(e)("\n" + str(e)) from None

        wrapped = GsTaichiCallable(_func, wrapped_func)
        wrapped.grad = adjoint

    wrapped._is_wrapped_kernel = True
    wrapped._is_classkernel = is_classkernel
    wrapped._primal = primal
    wrapped._adjoint = adjoint
    primal.gstaichi_callable = wrapped
    return wrapped


def kernel(fn: Callable):
    """Marks a function as a GsTaichi kernel.

    A GsTaichi kernel is a function written in Python, and gets JIT compiled by
    GsTaichi into native CPU/GPU instructions (e.g. a series of CUDA kernels).
    The top-level ``for`` loops are automatically parallelized, and distributed
    to either a CPU thread pool or massively parallel GPUs.

    The kernel's gradient kernel is generated automatically by the AutoDiff system.

    See also https://docs.taichi-lang.org/docs/syntax#kernel.

    Args:
        fn (Callable): the Python function to be decorated

    Returns:
        Callable: The decorated function

    Example::

        >>> x = ti.field(ti.i32, shape=(4, 8))
        >>>
        >>> @ti.kernel
        >>> def run():
        >>>     # Assigns all the elements of `x` in parallel.
        >>>     for i in x:
        >>>         x[i] = i
    """
    return _kernel_impl(fn, level_of_class_stackframe=3)


class _BoundedDifferentiableMethod:
    def __init__(self, kernel_owner: Any, wrapped_kernel_func: GsTaichiCallable | BoundGsTaichiCallable):
        clsobj = type(kernel_owner)
        if not getattr(clsobj, "_data_oriented", False):
            raise GsTaichiSyntaxError(f"Please decorate class {clsobj.__name__} with @ti.data_oriented")
        self._kernel_owner = kernel_owner
        self._primal = wrapped_kernel_func._primal
        self._adjoint = wrapped_kernel_func._adjoint
        self._is_staticmethod = wrapped_kernel_func._is_staticmethod
        self.__name__: str | None = None

    def __call__(self, *args, **kwargs):
        try:
            assert self._primal is not None
            if self._is_staticmethod:
                return self._primal(*args, **kwargs)
            return self._primal(self._kernel_owner, *args, **kwargs)

        except (GsTaichiCompilationError, GsTaichiRuntimeError) as e:
            if impl.get_runtime().print_full_traceback:
                raise e
            raise type(e)("\n" + str(e)) from None

    def grad(self, *args, **kwargs) -> Kernel:
        assert self._adjoint is not None
        return self._adjoint(self._kernel_owner, *args, **kwargs)


def data_oriented(cls):
    """Marks a class as GsTaichi compatible.

    To allow for modularized code, GsTaichi provides this decorator so that
    GsTaichi kernels can be defined inside a class.

    See also https://docs.taichi-lang.org/docs/odop

    Example::

        >>> @ti.data_oriented
        >>> class TiArray:
        >>>     def __init__(self, n):
        >>>         self.x = ti.field(ti.f32, shape=n)
        >>>
        >>>     @ti.kernel
        >>>     def inc(self):
        >>>         for i in self.x:
        >>>             self.x[i] += 1.0
        >>>
        >>> a = TiArray(32)
        >>> a.inc()

    Args:
        cls (Class): the class to be decorated

    Returns:
        The decorated class.
    """

    def _getattr(self, item):
        method = cls.__dict__.get(item, None)
        is_property = method.__class__ == property
        is_staticmethod = method.__class__ == staticmethod
        if is_property:
            x = method.fget
        else:
            x = super(cls, self).__getattribute__(item)
        if hasattr(x, "_is_wrapped_kernel"):
            if inspect.ismethod(x):
                wrapped = x.__func__
            else:
                wrapped = x
            assert isinstance(wrapped, (BoundGsTaichiCallable, GsTaichiCallable))
            wrapped._is_staticmethod = is_staticmethod
            if wrapped._is_classkernel:
                ret = _BoundedDifferentiableMethod(self, wrapped)
                ret.__name__ = wrapped.__name__  # type: ignore
                if is_property:
                    return ret()
                return ret
        if is_property:
            return x(self)
        return x

    cls.__getattribute__ = _getattr
    cls._data_oriented = True

    return cls


__all__ = ["data_oriented", "func", "kernel", "pyfunc", "real_func"]
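An end-to-end sketch of how the pieces in this module fit together for a class kernel and its gradient, assuming the public API mirrors upstream Taichi (ti.init, needs_grad fields, and calling .grad() on a bound kernel are assumptions, not confirmed by this file):

import gstaichi as ti

ti.init(arch=ti.cpu)  # assumed to follow the upstream Taichi init API

@ti.data_oriented
class SumOfSquares:
    def __init__(self, n: int):
        self.x = ti.field(ti.f32, shape=n, needs_grad=True)
        self.loss = ti.field(ti.f32, shape=(), needs_grad=True)

    @ti.kernel
    def forward(self):
        for i in self.x:
            self.loss[None] += self.x[i] ** 2

s = SumOfSquares(4)
for i in range(4):
    s.x[i] = float(i)

s.forward()              # dispatched through _BoundedDifferentiableMethod.__call__
s.loss.grad[None] = 1.0  # seed the adjoint of the output
s.forward.grad()         # dispatched through _BoundedDifferentiableMethod.grad
print(s.loss[None], s.x.grad.to_numpy())  # expected: 14.0 and [0. 2. 4. 6.]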