gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,304 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import dataclasses
|
5
|
+
from typing import Any, Callable
|
6
|
+
|
7
|
+
from gstaichi._lib.core.gstaichi_python import (
|
8
|
+
BoundaryMode,
|
9
|
+
DataTypeCxx,
|
10
|
+
)
|
11
|
+
from gstaichi.lang import (
|
12
|
+
_ndarray,
|
13
|
+
any_array,
|
14
|
+
expr,
|
15
|
+
impl,
|
16
|
+
kernel_arguments,
|
17
|
+
matrix,
|
18
|
+
)
|
19
|
+
from gstaichi.lang import ops as ti_ops
|
20
|
+
from gstaichi.lang._dataclass_util import create_flat_name
|
21
|
+
from gstaichi.lang.ast.ast_transformer_utils import (
|
22
|
+
ASTTransformerContext,
|
23
|
+
)
|
24
|
+
from gstaichi.lang.exception import (
|
25
|
+
GsTaichiSyntaxError,
|
26
|
+
)
|
27
|
+
from gstaichi.lang.matrix import MatrixType
|
28
|
+
from gstaichi.lang.struct import StructType
|
29
|
+
from gstaichi.lang.util import to_gstaichi_type
|
30
|
+
from gstaichi.types import annotations, ndarray_type, primitive_types, texture_type
|
31
|
+
|
32
|
+
|
33
|
+
class FunctionDefTransformer:
    """Handles the ``ast.FunctionDef`` node of a ``ti.kernel`` / ``ti.func``.

    ``build_FunctionDef`` is the entry point: it declares the kernel (or real
    function) parameters, or binds the ``ti.func`` call arguments, as variables
    in the transformer context, and then builds the function body.
    """

    @staticmethod
    def _decl_and_create_variable(
        ctx: ASTTransformerContext,
        annotation: Any,
        name: str,
        this_arg_features: tuple[tuple[Any, ...], ...] | None,
        prefix_name: str,
    ) -> tuple[bool, Any]:
        """Declare one kernel parameter according to its annotation.

        Args:
            ctx: transformer context. Names whose annotation is not a
                ``primitive_types.RefType`` are appended to ``ctx.kernel_args``.
            annotation: the parameter's type annotation.
            name: the parameter's (possibly flattened) name.
            this_arg_features: per-argument features. They are used for
                sparse-matrix, ndarray and texture parameters, and may be None
                otherwise.
            prefix_name: joined to ``name`` with ``"_"`` to form the declared
                name for sparse-matrix, ndarray and texture parameters.

        Returns:
            ``(True, value)`` when ``value`` can be bound directly. Templates
            are looked up in ``ctx.global_vars``; matrix, struct and scalar
            parameters are declared immediately.

            ``(False, (decl_func, decl_args))`` when the caller must call
            ``decl_func(*decl_args)`` to obtain the value.
        """
        full_name = prefix_name + "_" + name
        if not isinstance(annotation, primitive_types.RefType):
            ctx.kernel_args.append(name)
        if annotation == annotations.template or isinstance(annotation, annotations.template):
            assert ctx.global_vars is not None
            return True, ctx.global_vars[name]
        if isinstance(annotation, annotations.sparse_matrix_builder):
            return False, (
                kernel_arguments.decl_sparse_matrix,
                (
                    to_gstaichi_type(this_arg_features),
                    full_name,
                ),
            )
        if isinstance(annotation, ndarray_type.NdarrayType):
            assert this_arg_features is not None
            raw_element_type: DataTypeCxx
            ndim: int
            needs_grad: bool
            boundary: BoundaryMode
            raw_element_type, ndim, needs_grad, boundary = this_arg_features
            return False, (
                kernel_arguments.decl_ndarray_arg,
                (
                    to_gstaichi_type(raw_element_type),
                    ndim,
                    full_name,
                    needs_grad,
                    boundary,
                ),
            )
        if isinstance(annotation, texture_type.TextureType):
            assert this_arg_features is not None
            return False, (kernel_arguments.decl_texture_arg, (this_arg_features[0], full_name))
        if isinstance(annotation, texture_type.RWTextureType):
            assert this_arg_features is not None
            return False, (
                kernel_arguments.decl_rw_texture_arg,
                (this_arg_features[0], this_arg_features[1], this_arg_features[2], full_name),
            )
        if isinstance(annotation, MatrixType):
            return True, kernel_arguments.decl_matrix_arg(annotation, name)
        if isinstance(annotation, StructType):
            return True, kernel_arguments.decl_struct_arg(annotation, name)
        return True, kernel_arguments.decl_scalar_arg(annotation, name)

    @staticmethod
    def _bind_declared(ctx: ASTTransformerContext, name: str, result: bool, obj: Any) -> None:
        """Bind ``name`` in ``ctx`` to a value returned by ``_decl_and_create_variable``.

        When ``result`` is False, ``obj`` is a ``(decl_func, decl_args)`` pair.
        The declaration is performed here before binding.
        """
        if not result:
            decl_type_func, type_args = obj
            obj = decl_type_func(*type_args)
        ctx.create_variable(name, obj)

    @staticmethod
    def _transform_kernel_arg(
        ctx: ASTTransformerContext,
        argument_name: str,
        argument_type: Any,
        this_arg_features: tuple[Any, ...],
    ) -> None:
        """Declare one kernel parameter and bind it in ``ctx``.

        For a dataclass parameter, the parameter name is bound to the dataclass
        type itself. Each field is then declared under a flattened name from
        ``create_flat_name``, recursing into nested dataclasses.
        ``this_arg_features`` is indexed per field in that case.
        """
        if dataclasses.is_dataclass(argument_type):
            ctx.create_variable(argument_name, argument_type)
            for field_idx, field in enumerate(dataclasses.fields(argument_type)):
                flat_name = create_flat_name(argument_name, field.name)
                field_features = this_arg_features[field_idx]
                if dataclasses.is_dataclass(field.type):
                    # Nested dataclass: recurse with this field's own features.
                    FunctionDefTransformer._transform_kernel_arg(ctx, flat_name, field.type, field_features)
                else:
                    result, obj = FunctionDefTransformer._decl_and_create_variable(
                        ctx,
                        field.type,
                        flat_name,
                        field_features,
                        "",
                    )
                    FunctionDefTransformer._bind_declared(ctx, flat_name, result, obj)
        else:
            result, obj = FunctionDefTransformer._decl_and_create_variable(
                ctx,
                argument_type,
                argument_name,
                this_arg_features if ctx.arg_features is not None else None,
                "",
            )
            FunctionDefTransformer._bind_declared(ctx, argument_name, result, obj)

    @staticmethod
    def _transform_as_kernel(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Declare return values and parameters for a kernel (or real function).

        Return types from ``ctx.func.return_type`` are declared unless the
        return annotation is an ``ast.Constant``. Then rets are finalized, each
        parameter is declared, params are finalized, and the original Python
        args are removed from ``node``.
        """
        assert ctx.func is not None
        assert ctx.arg_features is not None
        if node.returns is not None and not isinstance(node.returns, ast.Constant):
            assert ctx.func.return_type is not None
            for return_type in ctx.func.return_type:
                kernel_arguments.decl_ret(return_type)
        compiling_callable = impl.get_runtime().compiling_callable
        assert compiling_callable is not None
        compiling_callable.finalize_rets()

        for i in range(len(args.args)):
            arg_meta = ctx.func.arg_metas[i]
            FunctionDefTransformer._transform_kernel_arg(
                ctx,
                arg_meta.name,
                arg_meta.annotation,
                ctx.arg_features[i] if ctx.arg_features is not None else (),
            )

        compiling_callable.finalize_params()
        # remove original args
        node.args.args = []

    @staticmethod
    def _is_ndarray_data(data: Any) -> bool:
        """Return True if ``data`` is one of the ndarray-like values passed by reference."""
        return isinstance(
            data,
            (_ndarray.ScalarNdarray, matrix.VectorNdarray, matrix.MatrixNdarray, any_array.AnyArray),
        )

    @staticmethod
    def _transform_func_arg(
        ctx: ASTTransformerContext,
        argument_name: str,
        argument_type: Any,
        data: Any,
    ) -> None:
        """Bind one ``ti.func`` call argument ``data`` as ``argument_name`` in ``ctx``.

        How each kind of argument is bound:

        * Templates and ndarrays are bound directly, i.e. by reference.
        * Dataclass fields are bound under flattened names.
        * Matrices, primitives and everything else are copied via
          ``impl.expr_init_func``.

        Raises:
            GsTaichiSyntaxError: if an argument or dataclass field has an
                unrecognized type, or a Matrix argument's shape does not match
                its annotation.
        """
        # Template arguments are passed by reference.
        if isinstance(argument_type, annotations.template):
            ctx.create_variable(argument_name, data)
            return None

        if dataclasses.is_dataclass(argument_type):
            for field in dataclasses.fields(argument_type):
                flat_name = create_flat_name(argument_name, field.name)
                data_child = getattr(data, field.name)
                if FunctionDefTransformer._is_ndarray_data(data_child):
                    field.type.check_matched(data_child.get_type(), field.name)
                    ctx.create_variable(flat_name, data_child)
                elif dataclasses.is_dataclass(data_child):
                    FunctionDefTransformer._transform_func_arg(ctx, flat_name, field.type, data_child)
                else:
                    raise GsTaichiSyntaxError(
                        f"Argument {field.name} of type {argument_type} {field.type} is not recognized."
                    )
            return None

        # Ndarray arguments are passed by reference.
        if isinstance(argument_type, ndarray_type.NdarrayType):
            if not FunctionDefTransformer._is_ndarray_data(data):
                raise GsTaichiSyntaxError(f"Argument {argument_name} of type {argument_type} is not recognized.")
            argument_type.check_matched(data.get_type(), argument_name)
            ctx.create_variable(argument_name, data)
            return None

        # Matrix arguments are passed by value.
        if isinstance(argument_type, MatrixType):
            # "data" is expected to be an Expr here,
            # so we simply call "impl.expr_init_func(data)" to perform:
            #
            # TensorType* t = alloca()
            # assign(t, data)
            #
            # We created local variable "t" - a copy of the passed-in argument "data"
            if not isinstance(data, expr.Expr) or not data.ptr.is_tensor():
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix, but got {type(data)}."
                )

            element_shape = data.ptr.get_rvalue_type().shape()
            if len(element_shape) != argument_type.ndim:
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with ndim {argument_type.ndim}, but got {len(element_shape)}."
                )

            assert argument_type.ndim > 0
            if element_shape[0] != argument_type.n:
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with n {argument_type.n}, but got {element_shape[0]}."
                )

            # m is the second dimension, so report element_shape[1] (not [0]).
            if argument_type.ndim == 2 and element_shape[1] != argument_type.m:
                raise GsTaichiSyntaxError(
                    f"Argument {argument_name} of type {argument_type} is expected to be a Matrix with m {argument_type.m}, but got {element_shape[1]}."
                )

            ctx.create_variable(argument_name, impl.expr_init_func(data))
            return None

        if id(argument_type) in primitive_types.type_ids:
            ctx.create_variable(argument_name, impl.expr_init_func(ti_ops.cast(data, argument_type)))
            return None
        # Create a copy for non-template arguments,
        # so that they are passed by value.
        ctx.create_variable(argument_name, impl.expr_init_func(data))
        return None

    @staticmethod
    def _transform_as_func(ctx: ASTTransformerContext, node: ast.FunctionDef, args: ast.arguments) -> None:
        """Bind the call arguments of an inlined ``ti.func`` in ``ctx``.

        After the arguments are bound, each original argument whose annotation
        is a dataclass has its name bound to that dataclass type.
        """
        # pylint: disable=import-outside-toplevel
        from gstaichi.lang.kernel_impl import Func

        assert isinstance(ctx.func, Func)
        assert ctx.argument_data is not None
        for data_i, data in enumerate(ctx.argument_data):
            argument = ctx.func.arg_metas[data_i]
            FunctionDefTransformer._transform_func_arg(ctx, argument.name, argument.annotation, data)

        # deal with dataclasses
        for v in ctx.func.orig_arguments:
            if dataclasses.is_dataclass(v.annotation):
                ctx.create_variable(v.name, v.annotation)

    @staticmethod
    def build_FunctionDef(
        ctx: ASTTransformerContext,
        node: ast.FunctionDef,
        build_stmts: Callable[[ASTTransformerContext, list[ast.stmt]], None],
    ) -> None:
        """Transform the (single) function definition of a kernel or func.

        Raises:
            GsTaichiSyntaxError: if a function definition was already visited
                in this context, i.e. a nested ``def``.
        """
        if ctx.visited_funcdef:
            raise GsTaichiSyntaxError(
                f"Function definition is not allowed in 'ti.{'kernel' if ctx.is_kernel else 'func'}'."
            )
        ctx.visited_funcdef = True

        args = node.args
        assert args.vararg is None
        assert args.kwonlyargs == []
        assert args.kw_defaults == []
        assert args.kwarg is None

        if ctx.is_kernel:  # ti.kernel
            FunctionDefTransformer._transform_as_kernel(ctx, node, args)

        if ctx.only_parse_function_def:
            return None

        if not ctx.is_kernel:  # ti.func
            assert ctx.argument_data is not None
            assert ctx.func is not None
            if ctx.is_real_function:
                FunctionDefTransformer._transform_as_kernel(ctx, node, args)
            else:
                FunctionDefTransformer._transform_as_func(ctx, node, args)

        with ctx.variable_scope_guard():
            build_stmts(ctx, node.body)

        return None
|
@@ -0,0 +1,106 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import ast
|
4
|
+
|
5
|
+
from gstaichi.lang._wrap_inspect import getsourcefile, getsourcelines
|
6
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError
|
7
|
+
|
8
|
+
|
9
|
+
class KernelSimplicityASTChecker(ast.NodeVisitor):
    """AST visitor that tracks per-scope rules on where statements/for-loops may appear.

    NOTE(review): in the code visible here, nothing calls
    ``ScopeGuard.mark_no_more_stmt`` (only the commented-out ``visit_For`` logic
    did). So ``allows_more_stmt`` stays True and ``generic_visit`` never raises.
    """

    class ScopeGuard:
        """Rules for one scope; as a context manager it pushes/pops itself on the checker's stack."""

        def __init__(self, checker):
            self.c = checker
            self._allows_for_loop = True
            self._allows_more_stmt = True

        @property
        def allows_for_loop(self):
            return self._allows_for_loop

        @property
        def allows_more_stmt(self):
            return self._allows_more_stmt

        def mark_no_more_for_loop(self):
            self._allows_for_loop = False

        def mark_no_more_stmt(self):
            # Forbidding further statements also forbids further for-loops.
            self._allows_for_loop = False
            self._allows_more_stmt = False

        def __enter__(self):
            self.c._scope_guards.append(self)

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.c._scope_guards.pop()

    def __init__(self, func):
        super().__init__()
        self._func_file = getsourcefile(func)
        self._func_lineno = getsourcelines(func)[1]
        self._func_name = func.__name__
        # Stack of active ScopeGuard objects; empty means "top level".
        self._scope_guards = []

    def new_scope(self):
        return KernelSimplicityASTChecker.ScopeGuard(self)

    @property
    def current_scope(self):
        return self._scope_guards[-1]

    @property
    def top_level(self):
        return len(self._scope_guards) == 0

    def get_error_location(self, node):
        """Return a "file=... kernel=... line=..." string for ``node``.

        The line is absolute: node line numbers are relative to the function's
        source, which starts at ``self._func_lineno``.
        """
        # -1 because ast's lineno is 1-based.
        lineno = self._func_lineno + node.lineno - 1
        return f"file={self._func_file} kernel={self._func_name} line={lineno}"

    @staticmethod
    def should_check(node):
        """Return True for statements other than function/class definitions."""
        if not isinstance(node, ast.stmt):
            return False
        # TODO(#536): Frontend pass should help make sure |func| is a valid AST for
        # GsTaichi.
        return not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef))

    def generic_visit(self, node):
        """Check ``node`` against the current scope's rules, then visit its children.

        Raises:
            GsTaichiSyntaxError: if the current scope allows no more statements.
        """
        if not self.should_check(node):
            super().generic_visit(node)
            return

        if not (self.top_level or self.current_scope.allows_more_stmt):
            raise GsTaichiSyntaxError(f"No more statements allowed, at {self.get_error_location(node)}")
        pushed_scope = self.top_level
        if pushed_scope:
            self._scope_guards.append(self.new_scope())
        try:
            # Marking here before the visit has the effect of disallow for-loops in
            # nested blocks. E.g. if |node| is a IfStmt, then the checker would disallow
            # for-loops inside it.
            self.current_scope.mark_no_more_for_loop()
            super().generic_visit(node)
        finally:
            # Pop even if visiting the children raised, so the scope stack is not
            # left with a stale entry.
            if pushed_scope:
                self._scope_guards.pop()

    @staticmethod
    def visit_for(node):
        # NOTE(review): ast.NodeVisitor dispatches to "visit_" + the node class name,
        # i.e. ``visit_For``. This lowercase ``visit_for`` is therefore never
        # called by ``visit()``, and ``ast.For`` nodes go through
        # ``generic_visit``.
        # TODO: since autodiff is enhanced, AST checker rules should be relaxed. This part should be updated.
        # original code is #def visit_For(self, node) without #@staticmethod before fix pylint R0201
        return
        # is_static = (isinstance(node.iter, ast.Call)
        #              and isinstance(node.iter.func, ast.Attribute)
        #              and isinstance(node.iter.func.value, ast.Name)
        #              and node.iter.func.value.id == 'ti'
        #              and node.iter.func.attr == 'static')
        # if not (self.top_level or self.current_scope.allows_for_loop
        #         or is_static):
        #     raise GsTaichiSyntaxError(
        #         f'No more for loops allowed, at {self.get_error_location(node)}'
        #     )
        # with self.new_scope():
        #     super().generic_visit(node)
        #
        # if not (self.top_level or is_static):
        #     self.current_scope.mark_no_more_stmt()
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
"""Provides helpers to resolve AST nodes."""
|
4
|
+
|
5
|
+
import ast
|
6
|
+
|
7
|
+
|
8
|
+
class ASTResolver:
    """Static helpers for matching AST name/attribute nodes against Python objects."""

    @staticmethod
    def resolve_to(node, wanted, scope):
        """Tell whether the dotted symbol ``node`` refers to the object ``wanted``.

        Only plain names and dotted attribute chains (``a.b.c.foo``) are
        understood. Anything whose chain does not bottom out in a bare name,
        such as ``(a + b).foo`` or ``x[i].foo``, is reported as not matching.

        Args:
            node (Union[ast.Attribute, ast.Name]): the AST node to resolve.
            wanted (Any): the object the symbol is expected to be.
            scope (Dict[str, Any]): mapping from names to objects, e.g. the
                result of ``globals()``.

        Returns:
            bool: True iff the lookup succeeds and yields ``wanted`` itself (identity).
        """
        if isinstance(node, ast.Name):
            return scope.get(node.id) is wanted
        if not isinstance(node, ast.Attribute):
            return False

        # Collect the dotted names from the outermost attribute inwards.
        names = []
        cursor = node
        while isinstance(cursor, ast.Attribute):
            names.append(cursor.attr)
            cursor = cursor.value
        if not isinstance(cursor, ast.Name):
            # The chain starts with something other than a bare name, e.g.
            # x[i].attr (ast.Subscript) or (a + b).attr (ast.BinOp).
            return False
        names.append(cursor.id)
        names.reverse()

        # Dict-like levels are indexed; every other level uses getattr.
        resolved = scope
        for name in names:
            try:
                resolved = resolved[name] if isinstance(resolved, dict) else getattr(resolved, name)
            except (KeyError, AttributeError):
                return False
        return resolved is wanted
|
@@ -0,0 +1,9 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi.lang.ast.ast_transformer import ASTTransformer
|
4
|
+
from gstaichi.lang.ast.ast_transformer_utils import ASTTransformerContext
|
5
|
+
|
6
|
+
|
7
|
+
def transform_tree(tree, ctx: ASTTransformerContext):
    """Apply a fresh ``ASTTransformer`` to ``tree`` within ``ctx``.

    Returns:
        ``ctx.return_data`` as it stands once the transformer has run.
    """
    transformer = ASTTransformer()
    transformer(ctx, tree)
    return ctx.return_data
|