gstaichi 0.0.0__cp312-cp312-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +51 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +5 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-312-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +122 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +83 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +366 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +7 -0
- gstaichi/lang/ast/ast_transformer.py +1351 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1259 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1386 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +784 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +10 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +21 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-0.0.0.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.0.0.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-0.0.0.dist-info/METADATA +97 -0
- gstaichi-0.0.0.dist-info/RECORD +178 -0
- gstaichi-0.0.0.dist-info/WHEEL +5 -0
- gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1351 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import ast
|
4
|
+
import collections.abc
|
5
|
+
import dataclasses
|
6
|
+
import itertools
|
7
|
+
import warnings
|
8
|
+
from ast import unparse
|
9
|
+
from typing import Any, Sequence, Type
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
|
13
|
+
from gstaichi._lib import core as _ti_core
|
14
|
+
from gstaichi.lang import expr, impl, matrix, mesh
|
15
|
+
from gstaichi.lang import ops as ti_ops
|
16
|
+
from gstaichi.lang._ndrange import _Ndrange
|
17
|
+
from gstaichi.lang.ast.ast_transformer_utils import (
|
18
|
+
ASTTransformerContext,
|
19
|
+
Builder,
|
20
|
+
LoopStatus,
|
21
|
+
ReturnStatus,
|
22
|
+
get_decorator,
|
23
|
+
)
|
24
|
+
from gstaichi.lang.ast.ast_transformers.call_transformer import CallTransformer
|
25
|
+
from gstaichi.lang.ast.ast_transformers.function_def_transformer import (
|
26
|
+
FunctionDefTransformer,
|
27
|
+
)
|
28
|
+
from gstaichi.lang.exception import (
|
29
|
+
GsTaichiIndexError,
|
30
|
+
GsTaichiRuntimeTypeError,
|
31
|
+
GsTaichiSyntaxError,
|
32
|
+
GsTaichiTypeError,
|
33
|
+
handle_exception_from_cpp,
|
34
|
+
)
|
35
|
+
from gstaichi.lang.expr import Expr, make_expr_group
|
36
|
+
from gstaichi.lang.field import Field
|
37
|
+
from gstaichi.lang.matrix import Matrix, MatrixType
|
38
|
+
from gstaichi.lang.snode import append, deactivate, length
|
39
|
+
from gstaichi.lang.struct import Struct, StructType
|
40
|
+
from gstaichi.types import primitive_types
|
41
|
+
from gstaichi.types.utils import is_integral
|
42
|
+
|
43
|
+
|
44
|
+
def reshape_list(flat_list: list[Any], target_shape: Sequence[int]) -> list[Any]:
    """Nest a flat list into lists of lists following ``target_shape``.

    Trailing dimensions are applied one at a time, innermost first. Each pass
    groups consecutive items into chunks of that dimension's size. The leading
    dimension is never applied explicitly; it is implied by how many chunks
    remain.

    Args:
        flat_list: The elements, with the last dimension varying fastest.
        target_shape: The desired shape. With fewer than two dimensions there
            is nothing to nest, and ``flat_list`` itself is returned.

    Returns:
        The nested list, or ``flat_list`` unchanged for 0-D/1-D shapes.
    """
    nested = flat_list
    for chunk_size in reversed(target_shape[1:]):
        chunks: list[Any] = []
        for position, item in enumerate(nested):
            # Start a new chunk every `chunk_size` items.
            if position % chunk_size == 0:
                chunks.append([])
            chunks[-1].append(item)
        nested = chunks
    return nested
|
56
|
+
|
57
|
+
|
58
|
+
def boundary_type_cast_warning(expression: Expr) -> None:
    """Emit a ``Warning`` when a range-for bound's type is about to be cast to ``i32``.

    The warning is raised when the bound's rvalue type is not integral, or is
    one of ``i64``, ``u64`` or ``u32``. Any other integral type is silently
    accepted.

    Args:
        expression: The boundary expression whose ``ptr.get_rvalue_type()``
            is inspected.
    """
    bound_type = expression.ptr.get_rvalue_type()
    if is_integral(bound_type) and bound_type not in (
        primitive_types.i64,
        primitive_types.u64,
        primitive_types.u32,
    ):
        return
    warnings.warn(
        f"Casting range_for boundary values from {bound_type} to i32, which may cause numerical issues",
        Warning,
    )
|
69
|
+
|
70
|
+
|
71
|
+
class ASTTransformer(Builder):
|
72
|
+
@staticmethod
def build_Name(ctx: ASTTransformerContext, node: ast.Name):
    """Resolve a variable reference to the value bound to ``node.id`` in ``ctx``.

    If the resolved value is an ``Expr``, a ``DebugInfo`` built from
    ``ctx.get_pos_info(node)`` is attached to it. It is attached both as its
    ``dbg_info`` attribute and on the underlying ``ptr``.

    Returns:
        The resolved value, also stored on ``node.ptr``.
    """
    resolved = ctx.get_var_by_name(node.id)
    node.ptr = resolved
    node_is_ast_item = isinstance(node, (ast.stmt, ast.expr))
    if node_is_ast_item and isinstance(resolved, Expr):
        resolved.dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        resolved.ptr.set_dbg_info(resolved.dbg_info)
    return node.ptr
|
79
|
+
|
80
|
+
@staticmethod
def build_AnnAssign(ctx: ASTTransformerContext, node: ast.AnnAssign):
    """Build an annotated assignment ``target: annotation = value``.

    The value is built first, then the annotation. When the value is a call to
    ``impl.static``, the assignment is forwarded as static, and
    ``build_assign_annotated`` rejects it.

    Returns:
        The assigned variable, also stored on ``node.ptr``.
    """
    build_stmt(ctx, node.value)
    build_stmt(ctx, node.annotation)
    value_is_static_call = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static
    assigned = ASTTransformer.build_assign_annotated(
        ctx,
        node.target,
        node.value.ptr,
        value_is_static_call,
        node.annotation.ptr,
    )
    node.ptr = assigned
    return assigned
|
91
|
+
|
92
|
+
@staticmethod
def build_assign_annotated(
    ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool, annotation: Type
):
    """Build an annotated assignment like this: ``target: annotation = value``.

    The first assignment to a local name declares a new variable holding
    ``value`` cast to the annotated type. Otherwise the existing target is
    built, its rvalue type must equal the annotation, and ``value`` is
    assigned into it.

    Args:
        ctx: The builder context.
        target: The assignment target. For a plain ``ast.Name``, ``target.id``
            holds the variable name as a string.
        value: The already-built value to assign.
        is_static_assign: Whether the value came from ``ti.static(...)``.
        annotation: The already-built type the target should have.

    Returns:
        The variable that received the value.

    Raises:
        GsTaichiSyntaxError: If ``target`` is a kernel argument, if the
            assignment is static, or if an existing target's type differs from
            the annotation.
    """
    target_is_name = isinstance(target, ast.Name)
    if target_is_name and target.id in ctx.kernel_args:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )
    declared_type = impl.expr_init(annotation)
    if is_static_assign:
        raise GsTaichiSyntaxError("Static assign cannot be used on annotated assignment")

    if target_is_name and not ctx.is_var_declared(target.id):
        # First binding of this name: cast to the annotation, materialize, register.
        new_var = impl.expr_init(ti_ops.cast(value, declared_type))
        ctx.create_variable(target.id, new_var)
        return new_var

    existing = build_stmt(ctx, target)
    if existing.ptr.get_rvalue_type() != declared_type:
        # NOTE(review): this message mentions "static assign", but this branch
        # handles a typed re-assignment of an existing target.
        raise GsTaichiSyntaxError("Static assign cannot have type overloading")
    existing._assign(value)
    return existing
|
125
|
+
|
126
|
+
@staticmethod
def build_Assign(ctx: ASTTransformerContext, node: ast.Assign) -> None:
    """Build ``t1 = t2 = ... = value``, including chained assignments.

    The right-hand side is built once. Unless it is a ``ti.static(...)`` call,
    it is also materialized once with ``impl.expr_init``. The same object is
    then unpacked into every target in order; sharing it is what makes
    chained assignments work
    (ref https://github.com/taichi-dev/gstaichi/issues/2659).
    """
    build_stmt(ctx, node.value)
    static_rhs = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static
    if static_rhs:
        rhs = node.value.ptr
    else:
        rhs = impl.expr_init(node.value.ptr)
    for each_target in node.targets:
        ASTTransformer.build_assign_unpack(ctx, each_target, rhs, static_rhs)
    return None
|
139
|
+
|
140
|
+
@staticmethod
def build_assign_unpack(ctx: ASTTransformerContext, node_target: list | ast.Tuple, values, is_static_assign: bool):
    """Build a destructuring assignment like ``(target1, target2) = (value1, value2)``.

    A non-tuple target is delegated directly to ``build_assign_basic``. For a
    tuple target, ``values`` is first normalized into a sequence:
    single-column ``matrix.Matrix`` objects contribute their ``entries``, and
    tensor or struct ``Expr`` values are expanded with
    ``ctx.ast_builder.expand_exprs``, with a one-element expansion unwrapped.
    Each target then receives the value at the same position.

    Args:
        ctx: The builder context.
        node_target: The assignment target. For an ``ast.Tuple``,
            ``node_target.elts`` holds the element targets.
        values: The already-built right-hand side.
        is_static_assign: Whether this is a static assignment.

    Raises:
        ValueError: If a matrix with more than one column is unpacked.
        GsTaichiSyntaxError: If ``values`` is not unpackable, or its length
            differs from the number of targets.
    """
    if not isinstance(node_target, ast.Tuple):
        return ASTTransformer.build_assign_basic(ctx, node_target, values, is_static_assign)
    element_targets = node_target.elts

    def expand_single_or_many(expression):
        # Expand a tensor/struct Expr into its components; unwrap a lone component.
        components = ctx.ast_builder.expand_exprs([expression.ptr])
        return components[0] if len(components) == 1 else components

    if isinstance(values, matrix.Matrix):
        if not values.m == 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = values.entries

    # Unpack: a, b, c = ti.Vector([1., 2., 3.])
    if isinstance(values, impl.Expr) and values.ptr.is_tensor():
        if len(values.get_shape()) > 1:
            raise ValueError("Matrices with more than one columns cannot be unpacked")
        values = expand_single_or_many(values)

    if isinstance(values, impl.Expr) and values.ptr.is_struct():
        values = expand_single_or_many(values)

    if not isinstance(values, collections.abc.Sequence):
        raise GsTaichiSyntaxError(f"Cannot unpack type: {type(values)}")
    if len(values) != len(element_targets):
        raise GsTaichiSyntaxError("The number of targets is not equal to value length")

    for element_target, element_value in zip(element_targets, values):
        ASTTransformer.build_assign_basic(ctx, element_target, element_value, is_static_assign)
    return None
|
185
|
+
|
186
|
+
@staticmethod
def build_assign_basic(ctx: ASTTransformerContext, target: ast.Name, value, is_static_assign: bool):
    """Build basic assignment like this: ``target = value``.

    There are three cases:
    - Static assignments bind ``value`` to the name as-is, without
      ``impl.expr_init``.
    - A first assignment to a local name declares a new variable holding
      ``impl.expr_init(value)``.
    - Otherwise the target is built and ``value`` is assigned into it via
      ``_assign``.

    Args:
        ctx: The builder context.
        target: The assignment target. For an ``ast.Name``, ``target.id``
            holds the name as a string; other targets (e.g. subscripts) are
            treated as non-local.
        value: The already-built value.
        is_static_assign: Whether this is a static assignment.

    Returns:
        The variable that now holds the value.

    Raises:
        GsTaichiSyntaxError: If ``target`` is a kernel argument, if a static
            assignment targets a non-name, or if the built target has no
            ``_assign``. In the last case the original ``AttributeError`` is
            chained as the cause.
    """
    is_local = isinstance(target, ast.Name)
    if is_local and target.id in ctx.kernel_args:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            f"If you want to change its value, please create a new variable."
        )
    if is_static_assign:
        if not is_local:
            raise GsTaichiSyntaxError("Static assign cannot be used on elements in arrays")
        # Static values are bound directly, without expr_init.
        ctx.create_variable(target.id, value)
        var = value
    elif is_local and not ctx.is_var_declared(target.id):
        var = impl.expr_init(value)
        ctx.create_variable(target.id, var)
    else:
        var = build_stmt(ctx, target)
        try:
            var._assign(value)
        except AttributeError as err:
            # Chain the cause so an AttributeError raised inside `_assign` stays visible.
            raise GsTaichiSyntaxError(
                f"Variable '{unparse(target).strip()}' cannot be assigned. Maybe it is not a GsTaichi object?"
            ) from err
    return var
|
220
|
+
|
221
|
+
@staticmethod
def build_NamedExpr(ctx: ASTTransformerContext, node: ast.NamedExpr):
    """Build a walrus expression ``target := value`` via ``build_assign_basic``.

    Returns:
        The bound variable, also stored on ``node.ptr``.
    """
    build_stmt(ctx, node.value)
    static_rhs = isinstance(node.value, ast.Call) and node.value.func.ptr is impl.static
    bound = ASTTransformer.build_assign_basic(ctx, node.target, node.value.ptr, static_rhs)
    node.ptr = bound
    return bound
|
227
|
+
|
228
|
+
@staticmethod
def is_tuple(node):
    """Return whether an already-built slice ``node`` holds several indices.

    This is true for a literal ``ast.Tuple``, for a legacy ``ast.Index``
    whose built value is a tuple, or for any node whose built ``ptr`` is a
    tuple.
    """
    if isinstance(node, ast.Tuple):
        return True
    wraps_tuple = isinstance(node, ast.Index) and isinstance(node.value.ptr, tuple)
    return wraps_tuple or isinstance(node.ptr, tuple)
|
237
|
+
|
238
|
+
@staticmethod
def build_Subscript(ctx: ASTTransformerContext, node: ast.Subscript):
    """Build ``value[indices]`` through ``impl.subscript``.

    A single (non-tuple) index is replaced on ``node.slice.ptr`` by a
    one-element list, so the indices can always be star-expanded.

    Returns:
        The subscript result, also stored on ``node.ptr``.
    """
    build_stmt(ctx, node.value)
    build_stmt(ctx, node.slice)
    slice_node = node.slice
    if not ASTTransformer.is_tuple(slice_node):
        slice_node.ptr = [slice_node.ptr]
    node.ptr = impl.subscript(ctx.ast_builder, node.value.ptr, *slice_node.ptr)
    return node.ptr
|
246
|
+
|
247
|
+
@staticmethod
def build_Slice(ctx: ASTTransformerContext, node: ast.Slice):
    """Build ``lower:upper:step`` into a Python ``slice`` of the built parts.

    Omitted parts become ``None``.

    Returns:
        The ``slice`` object, also stored on ``node.ptr``.
    """
    parts = (node.lower, node.upper, node.step)
    for part in parts:
        if part is not None:
            build_stmt(ctx, part)
    node.ptr = slice(*[part.ptr if part else None for part in parts])
    return node.ptr
|
262
|
+
|
263
|
+
@staticmethod
def build_ExtSlice(ctx: ASTTransformerContext, node: ast.ExtSlice):
    """Build a legacy extended slice into a tuple of its built dimensions.

    Returns:
        The tuple, also stored on ``node.ptr``.
    """
    build_stmts(ctx, node.dims)
    built_dims = [dim.ptr for dim in node.dims]
    node.ptr = tuple(built_dims)
    return node.ptr
|
268
|
+
|
269
|
+
@staticmethod
def build_Tuple(ctx: ASTTransformerContext, node: ast.Tuple):
    """Build a tuple display into a Python ``tuple`` of the built elements.

    Returns:
        The tuple, also stored on ``node.ptr``.
    """
    build_stmts(ctx, node.elts)
    built_elements = [element.ptr for element in node.elts]
    node.ptr = tuple(built_elements)
    return node.ptr
|
274
|
+
|
275
|
+
@staticmethod
def build_List(ctx: ASTTransformerContext, node: ast.List):
    """Build a list display into a Python ``list`` of the built elements.

    Returns:
        The list, also stored on ``node.ptr``.
    """
    build_stmts(ctx, node.elts)
    node.ptr = list(element.ptr for element in node.elts)
    return node.ptr
|
280
|
+
|
281
|
+
@staticmethod
def build_Dict(ctx: ASTTransformerContext, node: ast.Dict):
    """Build a dict display ``{k: v, **other}`` into a Python ``dict``.

    Entries are processed left to right. A ``None`` key marks ``**other``
    unpacking, whose built mapping is merged in. For each ``k: v`` pair the
    key is built before the value, matching Python's evaluation order for
    dict displays (and ``process_dictcomp``).

    Returns:
        The dict, also stored on ``node.ptr``.
    """
    dic = {}
    for key, value in zip(node.keys, node.values):
        if key is None:
            dic.update(build_stmt(ctx, value))
            continue
        # Build the key first. In `d[build(k)] = build(v)`, Python evaluates
        # the right-hand side first, which would emit the value before the key.
        built_key = build_stmt(ctx, key)
        dic[built_key] = build_stmt(ctx, value)
    node.ptr = dic
    return node.ptr
|
291
|
+
|
292
|
+
@staticmethod
def process_listcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Build one element of a list comprehension and append it to ``result``."""
    element = build_stmt(ctx, node.elt)
    result.append(element)
|
295
|
+
|
296
|
+
@staticmethod
def process_dictcomp(ctx: ASTTransformerContext, node, result) -> None:
    """Build one ``key: value`` entry of a dict comprehension into ``result``.

    The key is built before the value.
    """
    built_key = build_stmt(ctx, node.key)
    result[built_key] = build_stmt(ctx, node.value)
|
301
|
+
|
302
|
+
@staticmethod
def process_generators(ctx: ASTTransformerContext, node: ast.GeneratorExp, now_comp, func, result):
    """Statically unroll generator clause ``now_comp`` of a comprehension.

    Once every ``for`` clause has been consumed, ``func(ctx, node, result)``
    builds a single element. Otherwise, the steps are:
    - The clause's iterable is built in static scope. A tensor ``Expr`` is
      first expanded into nested lists of ``Expr`` shaped like the tensor.
    - For each item, inside a fresh variable scope, the clause target is
      statically bound.
    - The clause's ``if`` conditions are built in static scope.
    - ``process_ifs`` continues the recursion.

    Args:
        ctx: The builder context.
        node: The comprehension node. It is called with ``ast.ListComp`` and
            ``ast.DictComp`` despite the annotation.
        now_comp: Index of the generator clause being processed.
        func: Callback that builds one element into ``result``.
        result: The list or dict being filled.
    """
    generators = node.generators
    if now_comp >= len(generators):
        return func(ctx, node, result)
    clause = generators[now_comp]
    with ctx.static_scope_guard():
        iterable = build_stmt(ctx, clause.iter)

    if isinstance(iterable, impl.Expr) and iterable.ptr.is_tensor():
        tensor_shape = iterable.ptr.get_shape()
        scalars = [Expr(component) for component in ctx.ast_builder.expand_exprs([iterable.ptr])]
        iterable = reshape_list(scalars, tensor_shape)

    for item in iterable:
        with ctx.variable_scope_guard():
            ASTTransformer.build_assign_unpack(ctx, clause.target, item, True)
            with ctx.static_scope_guard():
                build_stmts(ctx, clause.ifs)
            ASTTransformer.process_ifs(ctx, node, now_comp, 0, func, result)
    return None
|
321
|
+
|
322
|
+
@staticmethod
def process_ifs(ctx: ASTTransformerContext, node: ast.If, now_comp, now_if, func, result):
    """Apply the remaining ``if`` filters of generator clause ``now_comp``.

    Filters from index ``now_if`` onward are tested in order using their
    already-built values, stopping at the first falsy one. If all of them
    pass, the next generator clause is processed.

    Args:
        ctx: The builder context.
        node: The comprehension node. The annotation says ``ast.If``, but the
            node is the comprehension that owns ``generators``.
        now_comp: Index of the generator clause whose filters are applied.
        now_if: Index of the first filter still to check.
        func: Callback that builds one element into ``result``.
        result: The list or dict being filled.

    Returns:
        When no filters remain at ``now_if``, the value returned by
        ``process_generators``; otherwise ``None``.
    """
    pending_filters = node.generators[now_comp].ifs[now_if:]
    if not all(condition.ptr for condition in pending_filters):
        return None
    outcome = ASTTransformer.process_generators(ctx, node, now_comp + 1, func, result)
    return None if pending_filters else outcome
|
331
|
+
|
332
|
+
@staticmethod
def build_ListComp(ctx: ASTTransformerContext, node: ast.ListComp):
    """Statically unroll a list comprehension into a Python ``list``.

    Returns:
        The collected elements, also stored on ``node.ptr``.
    """
    collected = []
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_listcomp, collected)
    node.ptr = collected
    return collected
|
338
|
+
|
339
|
+
@staticmethod
def build_DictComp(ctx: ASTTransformerContext, node: ast.DictComp):
    """Statically unroll a dict comprehension into a Python ``dict``.

    Returns:
        The collected mapping, also stored on ``node.ptr``.
    """
    collected = {}
    ASTTransformer.process_generators(ctx, node, 0, ASTTransformer.process_dictcomp, collected)
    node.ptr = collected
    return collected
|
345
|
+
|
346
|
+
@staticmethod
def build_Index(ctx: ASTTransformerContext, node: ast.Index):
    """Build a legacy ``ast.Index`` wrapper by building its inner value.

    Returns:
        The built inner value, also stored on ``node.ptr``.
    """
    inner = build_stmt(ctx, node.value)
    node.ptr = inner
    return inner
|
350
|
+
|
351
|
+
@staticmethod
def build_Constant(ctx: ASTTransformerContext, node: ast.Constant):
    """Build a literal by returning its Python value unchanged.

    Returns:
        ``node.value``, also stored on ``node.ptr``.
    """
    literal = node.value
    node.ptr = literal
    return literal
|
355
|
+
|
356
|
+
@staticmethod
def build_Num(ctx: ASTTransformerContext, node: "ast.Num"):
    """Build a legacy numeric literal by returning its value.

    ``ast.Num`` is a deprecated alias of ``ast.Constant``: accessing it warns
    on Python 3.12 and it no longer exists in 3.14. The annotation is
    therefore quoted, so it is not evaluated when the class is defined. The
    payload is read from the canonical ``value`` field instead of the
    deprecated ``n`` alias.

    Returns:
        The literal value, also stored on ``node.ptr``.
    """
    node.ptr = node.value
    return node.ptr
|
360
|
+
|
361
|
+
@staticmethod
def build_Str(ctx: ASTTransformerContext, node: "ast.Str"):
    """Build a legacy string literal by returning its value.

    ``ast.Str`` is a deprecated alias of ``ast.Constant``: accessing it warns
    on Python 3.12 and it no longer exists in 3.14. The annotation is
    therefore quoted, so it is not evaluated when the class is defined. The
    payload is read from the canonical ``value`` field instead of the
    deprecated ``s`` alias.

    Returns:
        The string, also stored on ``node.ptr``.
    """
    node.ptr = node.value
    return node.ptr
|
365
|
+
|
366
|
+
@staticmethod
def build_Bytes(ctx: ASTTransformerContext, node: "ast.Bytes"):
    """Build a legacy bytes literal by returning its value.

    ``ast.Bytes`` is a deprecated alias of ``ast.Constant``: accessing it
    warns on Python 3.12 and it no longer exists in 3.14. The annotation is
    therefore quoted, so it is not evaluated when the class is defined. The
    payload is read from the canonical ``value`` field instead of the
    deprecated ``s`` alias.

    Returns:
        The bytes object, also stored on ``node.ptr``.
    """
    node.ptr = node.value
    return node.ptr
|
370
|
+
|
371
|
+
@staticmethod
def build_NameConstant(ctx: ASTTransformerContext, node: "ast.NameConstant"):
    """Build a legacy ``True``/``False``/``None`` literal by returning its value.

    ``ast.NameConstant`` is a deprecated alias of ``ast.Constant``: accessing
    it warns on Python 3.12 and it no longer exists in 3.14. The annotation
    is therefore quoted, so it is not evaluated when the class is defined.

    Returns:
        The constant, also stored on ``node.ptr``.
    """
    node.ptr = node.value
    return node.ptr
|
375
|
+
|
376
|
+
@staticmethod
def build_keyword(ctx: ASTTransformerContext, node: ast.keyword):
    """Build a call keyword argument.

    Returns:
        For ``**mapping`` (``node.arg is None``), the built mapping itself.
        Otherwise a one-entry dict ``{name: value}``. The result is also
        stored on ``node.ptr``.
    """
    build_stmt(ctx, node.value)
    built_value = node.value.ptr
    node.ptr = built_value if node.arg is None else {node.arg: built_value}
    return node.ptr
|
384
|
+
|
385
|
+
@staticmethod
def build_Starred(ctx: ASTTransformerContext, node: ast.Starred):
    """Build ``*value`` by building the starred operand.

    Returns:
        The built operand, also stored on ``node.ptr``.
    """
    operand = build_stmt(ctx, node.value)
    node.ptr = operand
    return operand
|
389
|
+
|
390
|
+
@staticmethod
def build_FormattedValue(ctx: ASTTransformerContext, node: ast.FormattedValue):
    """Build one ``{value[:spec]}`` replacement field of an f-string.

    Returns:
        Without a format spec, the built value, which is also stored on
        ``node.ptr``. With a spec, a marker list
        ``["__ti_fmt_value__", value, spec]`` that can be told apart from an
        ordinary list.
    """
    node.ptr = build_stmt(ctx, node.value)
    spec = node.format_spec
    if spec is None or len(spec.values) == 0:
        return node.ptr
    values = spec.values
    # Only a single literal spec (e.g. `{x:.2f}`) is supported; nested
    # replacement fields inside the spec are not.
    assert len(values) == 1
    # Read the canonical `Constant.value`; the `.s` alias is deprecated
    # (warns on Python 3.12) and removed in 3.14.
    format_str = values[0].value
    assert isinstance(format_str, str)
    # distinguished from normal list
    return ["__ti_fmt_value__", node.ptr, format_str]
|
401
|
+
|
402
|
+
@staticmethod
def build_JoinedStr(ctx: ASTTransformerContext, node: ast.JoinedStr):
    """Build an f-string by passing a ``{}`` template and its arguments to ``impl.ti_format``.

    Literal parts are appended to the template verbatim. Each replacement
    field adds a ``{}`` placeholder plus one built argument.

    Returns:
        The ``impl.ti_format`` result, also stored on ``node.ptr``.

    Raises:
        GsTaichiSyntaxError: If a part is neither a literal nor a replacement
            field.
    """
    template = ""
    format_args = []
    for part in node.values:
        if isinstance(part, ast.FormattedValue):
            template += "{}"
            format_args.append(build_stmt(ctx, part))
        elif isinstance(part, ast.Constant):
            # Every legacy `ast.Str` is an `ast.Constant`, so no separate
            # `ast.Str` branch is needed. That alias warns on 3.12 and is
            # removed in 3.14.
            # NOTE(review): literal "{"/"}" from `{{`/`}}` are appended
            # unescaped; confirm how ti_format treats them.
            template += part.value
        else:
            raise GsTaichiSyntaxError("Invalid value for fstring.")

    node.ptr = impl.ti_format(template, *format_args)
    return node.ptr
|
420
|
+
|
421
|
+
@staticmethod
def build_Call(ctx: ASTTransformerContext, node: ast.Call) -> Any | None:
    """Build a call by delegating to ``CallTransformer.build_Call``.

    This module's ``build_stmt``/``build_stmts`` are passed along so the
    transformer can build sub-expressions.
    """
    call_result = CallTransformer.build_Call(ctx, node, build_stmt, build_stmts)
    return call_result
|
424
|
+
|
425
|
+
@staticmethod
def build_FunctionDef(ctx: ASTTransformerContext, node: ast.FunctionDef) -> None:
    """Build a function definition by delegating to ``FunctionDefTransformer``.

    This module's ``build_stmts`` is passed along to build the body.
    """
    FunctionDefTransformer.build_FunctionDef(ctx, node, build_stmts)
    return None
|
428
|
+
|
429
|
+
@staticmethod
def build_Return(ctx: ASTTransformerContext, node: ast.Return) -> None:
    """Build a ``return`` statement.

    For kernels and real functions, each returned value is paired with its
    entry in ``ctx.func.return_type``, then type/shape-checked and cast. The
    results are flattened into expressions and emitted with
    ``ctx.ast_builder.create_kernel_exprgroup_return``.

    For other (inlined) functions, the value is stored on ``ctx.return_data``,
    with primitive-typed entries cast to their annotated type.

    When not in a real function, ``ctx.returned`` records whether a value or
    nothing was returned.

    Raises:
        GsTaichiSyntaxError: For a return inside non-static control flow of a
            non-real function, a valued return without a return annotation
            (kernels/real functions), or an unsupported return type.
        GsTaichiRuntimeTypeError: If a returned value does not match the
            annotated type, ndim or shape.
    """
    if not ctx.is_real_function:
        if ctx.is_in_non_static_control_flow():
            raise GsTaichiSyntaxError("Return inside non-static if/for is not supported")
    if node.value is not None:
        build_stmt(ctx, node.value)
    if node.value is None or node.value.ptr is None:
        if not ctx.is_real_function:
            ctx.returned = ReturnStatus.ReturnedVoid
        return None
    if ctx.is_kernel or ctx.is_real_function:
        # TODO: check if it's at the end of a kernel, throw GsTaichiSyntaxError if not
        if ctx.func.return_type is None:
            raise GsTaichiSyntaxError(
                f'A {"kernel" if ctx.is_kernel else "function"} '
                "with a return value must be annotated "
                "with a return type, e.g. def func() -> ti.f32"
            )
        return_exprs = []
        # With a single return type, the whole returned value is its one entry.
        if len(ctx.func.return_type) == 1:
            node.value.ptr = [node.value.ptr]
        assert len(ctx.func.return_type) == len(node.value.ptr)
        for return_type, ptr in zip(ctx.func.return_type, node.value.ptr):
            if id(return_type) in primitive_types.type_ids:
                if isinstance(ptr, Expr):
                    if ptr.is_tensor() or ptr.is_struct() or ptr.element_type() not in primitive_types.all_types:
                        raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                elif not isinstance(ptr, (float, int, np.floating, np.integer)):
                    raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                return_exprs += [ti_ops.cast(expr.Expr(ptr), return_type).ptr]
            elif isinstance(return_type, MatrixType):
                values = ptr
                if isinstance(values, Matrix):
                    # Compare against this entry's type: `ctx.func.return_type` is
                    # the whole sequence of return types and has no `ndim`.
                    if values.ndim != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={values.ndim}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    # NOTE(review): this flattens `to_list()` (chain.from_iterable)
                    # only when ndim == 1 and iterates rows otherwise. If `to_list()`
                    # is already flat for vectors, the branches look inverted -- confirm.
                    values = (
                        itertools.chain.from_iterable(values.to_list())
                        if values.ndim == 1
                        else iter(values.to_list())
                    )
                elif isinstance(values, Expr):
                    if not values.is_tensor():
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.to_string(), ptr)
                    elif (
                        return_type.dtype in primitive_types.real_types
                        and values.element_type() not in primitive_types.all_types
                    ):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif (
                        return_type.dtype in primitive_types.integer_types
                        and values.element_type() not in primitive_types.integer_types
                    ):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), values.element_type())
                    elif len(values.get_shape()) != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={len(values.get_shape())}."
                        )
                    elif return_type.get_shape() != values.get_shape():
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={values.get_shape()}."
                        )
                    values = [values]
                else:
                    np_array = np.array(values)
                    dt, shape, ndim = np_array.dtype, np_array.shape, np_array.ndim
                    if return_type.dtype in primitive_types.real_types and dt not in (
                        float,
                        int,
                        np.floating,
                        np.integer,
                    ):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif return_type.dtype in primitive_types.integer_types and dt not in (int, np.integer):
                        raise GsTaichiRuntimeTypeError.get_ret(return_type.dtype.to_string(), dt)
                    elif ndim != return_type.ndim:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix ndim mismatch, expecting={return_type.ndim}, got={ndim}."
                        )
                    elif return_type.get_shape() != shape:
                        raise GsTaichiRuntimeTypeError(
                            f"Return matrix shape mismatch, expecting={return_type.get_shape()}, got={shape}."
                        )
                    values = [values]
                return_exprs += [ti_ops.cast(exp, return_type.dtype) for exp in values]
            elif isinstance(return_type, StructType):
                if not isinstance(ptr, Struct) or not isinstance(ptr, return_type):
                    raise GsTaichiRuntimeTypeError.get_ret(str(return_type), ptr)
                values = ptr
                assert isinstance(values, Struct)
                return_exprs += expr._get_flattened_ptrs(values)
            else:
                raise GsTaichiSyntaxError("The return type is not supported now!")
        ctx.ast_builder.create_kernel_exprgroup_return(
            expr.make_expr_group(return_exprs), _ti_core.DebugInfo(ctx.get_pos_info(node))
        )
    else:
        ctx.return_data = node.value.ptr
        if ctx.func.return_type is not None:
            if len(ctx.func.return_type) == 1:
                ctx.return_data = [ctx.return_data]
            # `return a, b` builds a tuple, which does not support item
            # assignment. Cast on a list copy, then restore the tuple type.
            was_tuple = isinstance(ctx.return_data, tuple)
            if was_tuple:
                ctx.return_data = list(ctx.return_data)
            for i, return_type in enumerate(ctx.func.return_type):
                if id(return_type) in primitive_types.type_ids:
                    ctx.return_data[i] = ti_ops.cast(ctx.return_data[i], return_type)
            if was_tuple:
                ctx.return_data = tuple(ctx.return_data)
            if len(ctx.func.return_type) == 1:
                ctx.return_data = ctx.return_data[0]
    if not ctx.is_real_function:
        ctx.returned = ReturnStatus.ReturnedValue
    return None
|
544
|
+
|
545
|
+
@staticmethod
def build_Module(ctx: ASTTransformerContext, node: ast.Module) -> None:
    """Build every top-level statement of ``node`` inside a single variable scope.

    ``build_stmts`` is deliberately not used here: it inserts ``del`` statements
    at the end, which would delete the parameters passed into the module.
    """
    with ctx.variable_scope_guard():
        for statement in node.body:
            build_stmt(ctx, statement)
    return None
|
553
|
+
|
554
|
+
@staticmethod
def build_attribute_if_is_dynamic_snode_method(ctx: ASTTransformerContext, node) -> bool:
    """Try to interpret ``node`` as ``append`` / ``deactivate`` / ``length`` on a dynamic SNode.

    On success ``node.ptr`` is set to a callable that performs the operation and
    True is returned. Otherwise ``node`` is left untouched and False is returned.
    The attribute qualifies only when the accessed object is a ``Field`` whose
    parent SNode type is ``dynamic`` and the field has exactly one more active
    index than the indices supplied in the (optional) subscript.
    """
    if node.attr not in ("append", "deactivate", "length"):
        return False

    if isinstance(node.value, ast.Subscript):
        snode_field = node.value.value.ptr
        indices = node.value.slice.ptr
    else:
        # `x.append` form: no explicit indices.
        snode_field = node.value.ptr
        indices = []

    if not isinstance(snode_field, Field):
        return False
    if not snode_field.parent().ptr.type == _ti_core.SNodeType.dynamic:
        return False

    field_dim = snode_field.snode.ptr.num_active_indices()
    index_dim = make_expr_group(*indices).size()
    # The dynamic axis itself is the one index the caller does not supply.
    if field_dim != index_dim + 1:
        return False

    # The callables resolve the parent SNode lazily, at call time.
    if node.attr == "append":
        node.ptr = lambda val: append(snode_field.parent(), indices, val)
    elif node.attr == "deactivate":
        node.ptr = lambda: deactivate(snode_field.parent(), indices)
    else:
        node.ptr = lambda: length(snode_field.parent(), indices)
    return True
|
582
|
+
|
583
|
+
@staticmethod
def build_Attribute(ctx: ASTTransformerContext, node: ast.Attribute):
    """Build an attribute access ``value.attr`` and return the resulting object.

    Handles, in order: dynamic SNode methods (``append``/``deactivate``/``length``),
    vector swizzles and tensor-op methods on ``Expr`` values, field lookups on
    dataclasses, and ordinary Python ``getattr``.
    """
    # Dynamic SNode methods are accepted in two forms:
    #
    # 1. x[i, j].append, where the field's dimension (3 here) is one more than the number of
    #    indices (2 here). The AST (simplified) is
    #    Attribute(value=Subscript(value=x, slice=[i, j]), attr="append"). Building node.value
    #    (the Subscript x[i, j]) builds x and [i, j] and raises GsTaichiIndexError because the
    #    index count does not match the field dimension. On that error we try
    #    build_attribute_if_is_dynamic_snode_method, and re-raise if it is not such a method.
    #
    # 2. x.append, for a 1-D field (equivalent to x[()].append). The AST (simplified) is
    #    Attribute(value=x, attr="append"), and building node.value does not raise. So we also
    #    try build_attribute_if_is_dynamic_snode_method on the success path, and otherwise
    #    continue processing it as a normal attribute node.
    try:
        build_stmt(ctx, node.value)
    except Exception as e:
        e = handle_exception_from_cpp(e)
        if isinstance(e, GsTaichiIndexError):
            node.value.ptr = None
            if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
                return node.ptr
        raise e

    if ASTTransformer.build_attribute_if_is_dynamic_snode_method(ctx, node):
        return node.ptr

    if isinstance(node.value.ptr, Expr) and not hasattr(node.value.ptr, node.attr):
        if node.attr in Matrix._swizzle_to_keygroup:
            # Swizzle such as `v.x` or `v.xyz`: validate it, then subscript the tensor.
            keygroup = Matrix._swizzle_to_keygroup[node.attr]
            Matrix._keygroup_to_checker[keygroup](node.value.ptr, node.attr)
            attr_len = len(node.attr)
            if attr_len == 1:
                node.ptr = Expr(
                    impl.get_runtime()
                    .compiling_callable.ast_builder()
                    .expr_subscript(
                        node.value.ptr.ptr,
                        make_expr_group(keygroup.index(node.attr)),
                        _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
            else:
                node.ptr = Expr(
                    _ti_core.subscript_with_multiple_indices(
                        node.value.ptr.ptr,
                        [make_expr_group(keygroup.index(ch)) for ch in node.attr],
                        (attr_len,),
                        _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
                    )
                )
        else:
            from gstaichi.lang import (  # pylint: disable=C0415
                matrix_ops as tensor_ops,
            )

            # Method-style tensor op (e.g. `v.norm()`): resolve the function and
            # remember the receiver so the call can pass it as the first argument.
            node.ptr = getattr(tensor_ops, node.attr)
            setattr(node, "caller", node.value.ptr)
    elif dataclasses.is_dataclass(node.value.ptr):
        # Resolve the field that `node.attr` actually names. The previous code
        # returned the first field's type whatever the attribute was.
        matching_types = [
            dc_field.type for dc_field in dataclasses.fields(node.value.ptr) if dc_field.name == node.attr
        ]
        node.ptr = matching_types[0] if matching_types else getattr(node.value.ptr, node.attr)
    else:
        node.ptr = getattr(node.value.ptr, node.attr)
    return node.ptr
|
654
|
+
|
655
|
+
@staticmethod
def build_BinOp(ctx: ASTTransformerContext, node: ast.BinOp):
    """Build both operands, then apply the Python binary operator of ``node``.

    ``@`` is routed to ``matrix_ops.matmul``. A ``TypeError`` raised by the
    operation is re-raised as ``GsTaichiTypeError`` with the same message.
    """
    build_stmt(ctx, node.left)
    build_stmt(ctx, node.right)
    # pylint: disable-msg=C0415
    from gstaichi.lang.matrix_ops import matmul

    binary_ops = {
        ast.Add: lambda a, b: a + b,
        ast.Sub: lambda a, b: a - b,
        ast.Mult: lambda a, b: a * b,
        ast.Div: lambda a, b: a / b,
        ast.FloorDiv: lambda a, b: a // b,
        ast.Mod: lambda a, b: a % b,
        ast.Pow: lambda a, b: a**b,
        ast.LShift: lambda a, b: a << b,
        ast.RShift: lambda a, b: a >> b,
        ast.BitOr: lambda a, b: a | b,
        ast.BitXor: lambda a, b: a ^ b,
        ast.BitAnd: lambda a, b: a & b,
        ast.MatMult: matmul,
    }
    apply_op = binary_ops.get(type(node.op))
    lhs, rhs = node.left.ptr, node.right.ptr
    try:
        node.ptr = apply_op(lhs, rhs)
    except TypeError as e:
        raise GsTaichiTypeError(str(e)) from None
    return node.ptr
|
682
|
+
|
683
|
+
@staticmethod
def build_AugAssign(ctx: ASTTransformerContext, node: ast.AugAssign):
    """Build ``target op= value`` via the target's ``_augassign``.

    Raises:
        GsTaichiSyntaxError: if the target is a bare kernel argument name,
            because kernel arguments are immutable inside the kernel.
    """
    build_stmt(ctx, node.target)
    build_stmt(ctx, node.value)
    target = node.target
    if isinstance(target, ast.Name) and target.id in ctx.kernel_args:
        raise GsTaichiSyntaxError(
            f'Kernel argument "{target.id}" is immutable in the kernel. '
            "If you want to change its value, please create a new variable."
        )
    # The operator is passed by its AST class name, e.g. "Add" or "Mult".
    op_name = type(node.op).__name__
    node.ptr = target.ptr._augassign(node.value.ptr, op_name)
    return node.ptr
|
694
|
+
|
695
|
+
@staticmethod
def build_UnaryOp(ctx: ASTTransformerContext, node: ast.UnaryOp):
    """Build the operand, then apply ``+x``, ``-x``, ``not x`` or ``~x``.

    ``not`` is mapped to ``ti_ops.logical_not``; the others use Python operators.
    """
    build_stmt(ctx, node.operand)
    unary_ops = {
        ast.UAdd: lambda v: v,
        ast.USub: lambda v: -v,
        ast.Not: ti_ops.logical_not,
        ast.Invert: lambda v: ~v,
    }
    apply_op = unary_ops.get(type(node.op))
    node.ptr = apply_op(node.operand.ptr)
    return node.ptr
|
706
|
+
|
707
|
+
@staticmethod
|
708
|
+
def build_bool_op(op):
|
709
|
+
def inner(operands):
|
710
|
+
if len(operands) == 1:
|
711
|
+
return operands[0].ptr
|
712
|
+
return op(operands[0].ptr, inner(operands[1:]))
|
713
|
+
|
714
|
+
return inner
|
715
|
+
|
716
|
+
@staticmethod
|
717
|
+
def build_static_and(operands):
|
718
|
+
for operand in operands:
|
719
|
+
if not operand.ptr:
|
720
|
+
return operand.ptr
|
721
|
+
return operands[-1].ptr
|
722
|
+
|
723
|
+
@staticmethod
|
724
|
+
def build_static_or(operands):
|
725
|
+
for operand in operands:
|
726
|
+
if operand.ptr:
|
727
|
+
return operand.ptr
|
728
|
+
return operands[-1].ptr
|
729
|
+
|
730
|
+
@staticmethod
def build_BoolOp(ctx: ASTTransformerContext, node: ast.BoolOp):
    """Build ``a and b ...`` / ``a or b ...``.

    Inside a static scope, Python ``and``/``or`` semantics are applied to the
    built values. Otherwise the operands are folded with ``ti_ops.logical_and``/
    ``logical_or`` when the runtime enables ``short_circuit_operators``, and with
    ``ti_ops.bit_and``/``bit_or`` when it does not.
    """
    build_stmts(ctx, node.values)
    if ctx.is_in_static_scope():
        and_builder = ASTTransformer.build_static_and
        or_builder = ASTTransformer.build_static_or
    elif impl.get_runtime().short_circuit_operators:
        and_builder = ASTTransformer.build_bool_op(ti_ops.logical_and)
        or_builder = ASTTransformer.build_bool_op(ti_ops.logical_or)
    else:
        and_builder = ASTTransformer.build_bool_op(ti_ops.bit_and)
        or_builder = ASTTransformer.build_bool_op(ti_ops.bit_or)
    combine = {ast.And: and_builder, ast.Or: or_builder}.get(type(node.op))
    node.ptr = combine(node.values)
    return node.ptr
|
751
|
+
|
752
|
+
@staticmethod
def build_Compare(ctx: ASTTransformerContext, node: ast.Compare):
    """Build a (possibly chained) comparison such as ``a < b <= c``.

    Each adjacent pair is compared and the results are combined with
    ``ti_ops.logical_and``. ``in``/``not in`` are only allowed inside a static
    scope; ``is``/``is not`` are never allowed. A non-Python-bool result is cast
    to ``u1``.

    Raises:
        GsTaichiSyntaxError: for unsupported comparison operators.
    """
    build_stmt(ctx, node.left)
    build_stmts(ctx, node.comparators)
    comparisons = {
        ast.Eq: lambda a, b: a == b,
        ast.NotEq: lambda a, b: a != b,
        ast.Lt: lambda a, b: a < b,
        ast.LtE: lambda a, b: a <= b,
        ast.Gt: lambda a, b: a > b,
        ast.GtE: lambda a, b: a >= b,
    }
    membership_ops = {
        ast.In: lambda a, b: a in b,
        ast.NotIn: lambda a, b: a not in b,
    }
    if ctx.is_in_static_scope():
        comparisons = {**comparisons, **membership_ops}

    operands = [node.left.ptr, *(comparator.ptr for comparator in node.comparators)]
    result = True
    for index, cmp_op in enumerate(node.ops):
        if isinstance(cmp_op, (ast.Is, ast.IsNot)):
            op_text = "is" if isinstance(cmp_op, ast.Is) else "is not"
            raise GsTaichiSyntaxError(f'Operator "{op_text}" in GsTaichi scope is not supported.')
        compare = comparisons.get(type(cmp_op))
        if compare is None:
            op_name = type(cmp_op).__name__
            if type(cmp_op) in membership_ops:
                raise GsTaichiSyntaxError(f'"{op_name}" is only supported inside `ti.static`.')
            raise GsTaichiSyntaxError(f'"{op_name}" is not supported in GsTaichi kernels.')
        result = ti_ops.logical_and(result, compare(operands[index], operands[index + 1]))

    if not isinstance(result, (bool, np.bool_)):
        result = ti_ops.cast(result, primitive_types.u1)
    node.ptr = result
    return node.ptr
|
790
|
+
|
791
|
+
@staticmethod
|
792
|
+
def get_for_loop_targets(node: ast.Name | ast.Tuple | Any) -> list:
|
793
|
+
"""
|
794
|
+
Returns the list of indices of the for loop |node|.
|
795
|
+
See also: https://docs.python.org/3/library/ast.html#ast.For
|
796
|
+
"""
|
797
|
+
if isinstance(node.target, ast.Name):
|
798
|
+
return [node.target.id]
|
799
|
+
assert isinstance(node.target, ast.Tuple)
|
800
|
+
return [name.id for name in node.target.elts]
|
801
|
+
|
802
|
+
@staticmethod
def build_static_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Unroll a static for loop at compile time.

    The body is built once per iteration value, each time in a fresh variable
    scope. After each pass the loop status is checked: ``Break`` stops the
    unrolling and ``Continue`` is reset to ``Normal``. A ``SyntaxWarning`` is
    issued once when the number of unrolled iterations exceeds the runtime's
    ``unrolling_limit`` (a falsy limit disables the warning).

    Args:
        ctx: The transformer context.
        node: The ``for`` statement being unrolled.
        is_grouped: True when the iterable is ``ti.static(ti.grouped(...))``. The
            inner argument must then be a ``ti.ndrange`` and the loop must have
            exactly one target, which receives each grouped value.

    Raises:
        GsTaichiSyntaxError: for a grouped static loop over a non-ndrange, or with
            more than one loop target.
    """
    unroll_limit = impl.get_runtime().unrolling_limit
    warned = False

    def note_iteration(count: int) -> None:
        # Warn (once) when unrolling goes past the configured limit.
        nonlocal warned
        if not warned and unroll_limit and count > unroll_limit:
            warned = True
            warnings.warn_explicit(
                f"You are unrolling more than\n{unroll_limit} iterations, so the compile time may be extremely long.\n"
                "You can use a non-static for loop if you want to decrease the compile time.\n"
                "You can disable this warning by setting ti.init(unrolling_limit=0).",
                SyntaxWarning,
                ctx.file,
                node.lineno + ctx.lineno_offset,
                module="gstaichi",
            )

    def run_body(bindings) -> bool:
        # Build the body once with the given (name, value) bindings; True means stop unrolling.
        with ctx.variable_scope_guard():
            for name, value in bindings:
                ctx.create_variable(name, value)
            build_stmts(ctx, node.body)
            status = ctx.loop_status()
            if status == LoopStatus.Break:
                return True
            if status == LoopStatus.Continue:
                ctx.set_loop_status(LoopStatus.Normal)
        return False

    if is_grouped:
        assert len(node.iter.args[0].args) == 1
        ndrange_arg = build_stmt(ctx, node.iter.args[0].args[0])
        if not isinstance(ndrange_arg, _Ndrange):
            raise GsTaichiSyntaxError("Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'.")
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        target = targets[0]
        for count, value in enumerate(impl.grouped(ndrange_arg), start=1):
            note_iteration(count)
            if run_body([(target, value)]):
                break
    else:
        build_stmt(ctx, node.iter)
        targets = ASTTransformer.get_for_loop_targets(node)
        for count, target_values in enumerate(node.iter.ptr, start=1):
            # A single target (or a non-sequence value) binds the whole value.
            if not isinstance(target_values, collections.abc.Sequence) or len(targets) == 1:
                target_values = [target_values]
            note_iteration(count)
            if run_body(zip(targets, target_values)):
                break
    return None
|
874
|
+
|
875
|
+
@staticmethod
def build_range_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Lower ``for i in range(end)`` / ``range(begin, end)`` to a frontend range-for.

    Both bounds are cast to ``i32``. ``boundary_type_cast_warning`` is called on
    each user-supplied bound, and an omitted ``begin`` defaults to 0.

    Raises:
        GsTaichiSyntaxError: if ``range`` is not called with 1 or 2 arguments.
    """
    with ctx.variable_scope_guard():
        loop_name = node.target.id
        ctx.check_loop_var(loop_name)
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)
        range_args = node.iter.args
        if len(range_args) not in [1, 2]:
            raise GsTaichiSyntaxError(f"Range should have 1 or 2 arguments, found {len(node.iter.args)}")

        if len(range_args) == 2:
            begin_expr = expr.Expr(build_stmt(ctx, range_args[0]))
            end_expr = expr.Expr(build_stmt(ctx, range_args[1]))
            # Warn about implicit dtype conversion of the loop bounds.
            boundary_type_cast_warning(begin_expr)
            boundary_type_cast_warning(end_expr)
            begin = ti_ops.cast(begin_expr, primitive_types.i32)
        else:
            end_expr = expr.Expr(build_stmt(ctx, range_args[0]))
            # Warn about implicit dtype conversion of the loop bound.
            boundary_type_cast_warning(end_expr)
            begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        end = ti_ops.cast(end_expr, primitive_types.i32)

        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
|
909
|
+
|
910
|
+
@staticmethod
def build_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Lower ``for i, j, ... in ti.ndrange(...)`` to one flat frontend range-for.

    The loop runs from 0 to ``acc_dimensions[0]``. Inside the body, the flat
    index is decomposed into one value per target by dividing by
    ``acc_dimensions[axis + 1]``. Each axis's lower bound (``bounds[axis][0]``)
    is then added to that value.

    Raises:
        GsTaichiSyntaxError: if the number of loop targets differs from the
            number of ndrange dimensions.
    """
    with ctx.variable_scope_guard():
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        flat_index = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(flat_index.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)
        remainder = impl.expr_init(flat_index)
        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != len(ndrange_var.dimensions):
            raise GsTaichiSyntaxError(
                "Ndrange for loop with number of the loop variables not equal to "
                "the dimension of the ndrange is not supported. "
                "Please check if the number of arguments of ti.ndrange() is equal to "
                "the number of the loop variables."
            )
        for axis, target in enumerate(targets):
            has_inner_axes = axis + 1 < len(targets)
            if has_inner_axes:
                axis_offset = impl.expr_init(remainder // ndrange_var.acc_dimensions[axis + 1])
            else:
                axis_offset = impl.expr_init(remainder)
            lower_bound = impl.subscript(
                ctx.ast_builder,
                impl.subscript(ctx.ast_builder, ndrange_var.bounds, axis),
                0,
            )
            ctx.create_variable(target, impl.expr_init(axis_offset + lower_bound))
            if has_inner_axes:
                # Strip this axis's contribution before decoding the next one.
                remainder._assign(remainder - axis_offset * ndrange_var.acc_dimensions[axis + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
|
952
|
+
|
953
|
+
@staticmethod
def build_grouped_ndrange_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Lower ``for I in ti.grouped(ti.ndrange(...))`` to one flat frontend range-for.

    The single target ``I`` is an ``i32`` matrix with one entry per ndrange
    dimension. Each iteration fills it by decomposing the flat index with
    ``acc_dimensions`` and adding ``bounds[axis][0]``.

    Raises:
        GsTaichiSyntaxError: if the loop has more than one target.
    """
    with ctx.variable_scope_guard():
        ndrange_var = impl.expr_init(build_stmt(ctx, node.iter.args[0]))
        ndrange_begin = ti_ops.cast(expr.Expr(0), primitive_types.i32)
        ndrange_end = ti_ops.cast(
            expr.Expr(impl.subscript(ctx.ast_builder, ndrange_var.acc_dimensions, 0)),
            primitive_types.i32,
        )
        flat_index = expr.Expr(ctx.ast_builder.make_id_expr(""))
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(flat_index.ptr, ndrange_begin.ptr, ndrange_end.ptr, for_di)

        targets = ASTTransformer.get_for_loop_targets(node)
        if len(targets) != 1:
            raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
        num_axes = len(ndrange_var.dimensions)
        index_vector = impl.expr_init(matrix.make_matrix([0] * num_axes, dt=primitive_types.i32))
        ctx.create_variable(targets[0], index_vector)

        remainder = impl.expr_init(flat_index)
        for axis in range(num_axes):
            has_inner_axes = axis + 1 < num_axes
            axis_offset = remainder // ndrange_var.acc_dimensions[axis + 1] if has_inner_axes else remainder
            slot = impl.subscript(ctx.ast_builder, index_vector, axis)
            slot._assign(axis_offset + ndrange_var.bounds[axis][0])
            if has_inner_axes:
                # Strip this axis's contribution before decoding the next one.
                remainder._assign(remainder - axis_offset * ndrange_var.acc_dimensions[axis + 1])
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()
    return None
|
986
|
+
|
987
|
+
@staticmethod
def build_struct_for(ctx: ASTTransformerContext, node: ast.For, is_grouped: bool) -> None:
    """Lower a struct for loop over a field-like iterable.

    Handles ``for i, j in x`` (one index variable per target) and
    ``for I in ti.grouped(x)`` (a single ``i32`` matrix target whose size is
    ``len(loop_var.shape)``).

    Raises:
        GsTaichiSyntaxError: for a grouped loop with more than one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    for target in targets:
        ctx.check_loop_var(target)

    with ctx.variable_scope_guard():
        if is_grouped:
            if len(targets) != 1:
                raise GsTaichiSyntaxError(f"Group for should have 1 loop target, found {len(targets)}")
            grouped_name = targets[0]
            loop_var = build_stmt(ctx, node.iter)
            loop_indices = expr.make_var_list(size=len(loop_var.shape), ast_builder=ctx.ast_builder)
            impl.begin_frontend_struct_for(ctx.ast_builder, expr.make_expr_group(loop_indices), loop_var)
            ctx.create_variable(grouped_name, matrix.make_matrix(loop_indices, dt=primitive_types.i32))
        else:
            index_vars = []
            for name in targets:
                index_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
                index_vars.append(index_var)
                ctx.create_variable(name, index_var)
            impl.begin_frontend_struct_for(ctx.ast_builder, expr.make_expr_group(*index_vars), node.iter.ptr)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_struct_for()
    return None
|
1020
|
+
|
1021
|
+
@staticmethod
def build_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Lower a for loop over a ``mesh.MeshElementField`` to a frontend mesh-for.

    The single loop target is bound to a ``MeshElementFieldProxy``. ``ctx.mesh``
    is set to the iterated mesh while the body is built, then reset to None.

    Raises:
        GsTaichiSyntaxError: if the loop has more than one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # The message was missing its f-prefix and printed "{len(targets)}" literally.
        raise GsTaichiSyntaxError(f"Mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr._type, var.ptr)
        ctx.create_variable(target, mesh_idx)
        ctx.ast_builder.begin_frontend_mesh_for(
            mesh_idx.ptr,
            ctx.mesh.mesh_ptr,
            node.iter.ptr._type,
            _ti_core.DebugInfo(impl.get_runtime().get_current_src_info()),
        )
        build_stmts(ctx, node.body)
        ctx.mesh = None
        ctx.ast_builder.end_frontend_mesh_for()
    return None
|
1044
|
+
|
1045
|
+
@staticmethod
def build_nested_mesh_for(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Lower a for loop over a ``mesh.MeshRelationAccessProxy`` (nested mesh for).

    A range-for runs over ``0 .. size`` (cast to ``i32``). Its counter is bound
    to ``<target>_index__``. Each counter value is mapped through
    ``get_relation_access`` to a related element, which is exposed as the loop
    target via a ``MeshElementFieldProxy``.

    Raises:
        GsTaichiSyntaxError: if the loop has more than one target.
    """
    targets = ASTTransformer.get_for_loop_targets(node)
    if len(targets) != 1:
        # The message was missing its f-prefix and printed "{len(targets)}" literally.
        raise GsTaichiSyntaxError(f"Nested-mesh for should have 1 loop target, found {len(targets)}")
    target = targets[0]

    with ctx.variable_scope_guard():
        ctx.mesh = node.iter.ptr.mesh
        assert isinstance(ctx.mesh, impl.MeshInstance)
        loop_name = node.target.id + "_index__"
        loop_var = expr.Expr(ctx.ast_builder.make_id_expr(""))
        ctx.create_variable(loop_name, loop_var)
        begin = expr.Expr(0)
        end = ti_ops.cast(node.iter.ptr.size, primitive_types.i32)
        for_di = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_range_for(loop_var.ptr, begin.ptr, end.ptr, for_di)
        entry_expr = _ti_core.get_relation_access(
            ctx.mesh.mesh_ptr,
            node.iter.ptr.from_index.ptr,
            node.iter.ptr.to_element_type,
            loop_var.ptr,
        )
        entry_expr.type_check(impl.get_runtime().prog.config())
        mesh_idx = mesh.MeshElementFieldProxy(ctx.mesh, node.iter.ptr.to_element_type, entry_expr)
        ctx.create_variable(target, mesh_idx)
        build_stmts(ctx, node.body)
        ctx.ast_builder.end_frontend_range_for()

    return None
|
1075
|
+
|
1076
|
+
@staticmethod
def build_For(ctx: ASTTransformerContext, node: ast.For) -> None:
    """Dispatch a ``for`` statement to the matching loop builder.

    The choice depends on the decorator of the iterable (from ``get_decorator``)
    and, for a single-argument call, on the decorator of that argument:

    * ``"static"``  -> ``build_static_for`` (inside a static loop scope);
    * ``"ndrange"`` -> ``build_ndrange_for``;
    * ``"grouped"`` -> ``build_grouped_ndrange_for`` when the argument is an
      ndrange, else ``build_struct_for(is_grouped=True)``;
    * a plain ``range(...)`` call -> ``build_range_for``;
    * a static inline ``if`` expression -> re-dispatch on the selected branch;
    * otherwise the iterable is built. A ``MeshElementField`` gives a mesh for,
      a ``MeshRelationAccessProxy`` gives a nested mesh for, and anything else
      gives a struct for.

    Raises:
        GsTaichiSyntaxError: for ``for ... else``, invalid decorator nesting, or a
            non-static inline ``if`` used as the iterable.
    """
    if node.orelse:
        raise GsTaichiSyntaxError("'else' clause for 'for' not supported in GsTaichi kernels")
    decorator = get_decorator(ctx, node.iter)
    inner_decorator = ""
    if decorator != "" and len(node.iter.args) == 1:
        inner_decorator = get_decorator(ctx, node.iter.args[0])

    if decorator == "static":
        if inner_decorator == "static":
            raise GsTaichiSyntaxError("'ti.static' cannot be nested")
        with ctx.loop_scope_guard(is_static=True):
            return ASTTransformer.build_static_for(ctx, node, inner_decorator == "grouped")

    with ctx.loop_scope_guard():
        if decorator == "ndrange":
            if inner_decorator != "":
                # Closing quote around 'ti.ndrange' was missing.
                raise GsTaichiSyntaxError("No decorator is allowed inside 'ti.ndrange'")
            return ASTTransformer.build_ndrange_for(ctx, node)

        if decorator == "grouped":
            if inner_decorator == "static":
                raise GsTaichiSyntaxError("'ti.static' is not allowed inside 'ti.grouped'")
            if inner_decorator == "grouped":
                raise GsTaichiSyntaxError("'ti.grouped' cannot be nested")
            if inner_decorator == "ndrange":
                return ASTTransformer.build_grouped_ndrange_for(ctx, node)
            return ASTTransformer.build_struct_for(ctx, node, is_grouped=True)

        iterable = node.iter
        is_range_call = (
            isinstance(iterable, ast.Call) and isinstance(iterable.func, ast.Name) and iterable.func.id == "range"
        )
        if is_range_call:
            return ASTTransformer.build_range_for(ctx, node)

        if isinstance(iterable, ast.IfExp):
            # Handle inline if expression as the top level iterator expression, e.g.:
            #
            #   for i in range(foo) if ti.static(some_flag) else ti.static(range(bar))
            #
            # Empirically, this appears to generalize to:
            # - being an inner loop
            # - either side can be static or not, as long as the if expression itself is static
            if get_decorator(ctx, iterable.test) != "static":
                raise GsTaichiSyntaxError(
                    "Using non static inlined if statement as for-loop iterable is not currently supported."
                )
            build_stmt(ctx, iterable.test)
            selected_iter = iterable.body if iterable.test.ptr else iterable.orelse
            redirected = ast.For(
                target=node.target,
                iter=selected_iter,
                body=node.body,
                orelse=None,
                type_comment=getattr(node, "type_comment", None),
                lineno=node.lineno,
                end_lineno=node.end_lineno,
                col_offset=node.col_offset,
                end_col_offset=node.end_col_offset,
            )
            return ASTTransformer.build_For(ctx, redirected)

        build_stmt(ctx, iterable)
        if isinstance(node.iter.ptr, mesh.MeshElementField):
            if not _ti_core.is_extension_supported(impl.default_cfg().arch, _ti_core.Extension.mesh):
                raise Exception(
                    "Backend " + str(impl.default_cfg().arch) + " doesn't support MeshGsTaichi extension"
                )
            return ASTTransformer.build_mesh_for(ctx, node)
        if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy):
            return ASTTransformer.build_nested_mesh_for(ctx, node)
        # Struct for
        return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)
|
1150
|
+
|
1151
|
+
@staticmethod
def build_While(ctx: ASTTransformerContext, node: ast.While) -> None:
    """Lower ``while cond: body`` to a frontend while loop.

    The loop itself is emitted with a constant ``1`` condition. The condition is
    built at the top of each iteration, and a ``break`` is inserted in the false
    branch of an ``if`` on it; the body follows.

    Raises:
        GsTaichiSyntaxError: for ``while ... else``.
    """
    if node.orelse:
        raise GsTaichiSyntaxError("'else' clause for 'while' not supported in GsTaichi kernels")

    with ctx.loop_scope_guard():
        dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        ctx.ast_builder.begin_frontend_while(expr.Expr(1, dtype=primitive_types.i32).ptr, dbg_info)
        loop_condition = build_stmt(ctx, node.test)
        # if (cond) {} else { break; }
        impl.begin_frontend_if(ctx.ast_builder, loop_condition, dbg_info)
        ctx.ast_builder.begin_frontend_if_true()
        ctx.ast_builder.pop_scope()
        ctx.ast_builder.begin_frontend_if_false()
        ctx.ast_builder.insert_break_stmt(dbg_info)
        ctx.ast_builder.pop_scope()
        build_stmts(ctx, node.body)
        ctx.ast_builder.pop_scope()
    return None
|
1169
|
+
|
1170
|
+
@staticmethod
def build_If(ctx: ASTTransformerContext, node: ast.If) -> ast.If | None:
    """Build an ``if`` statement.

    For a static test, only the branch selected by the test's Python truth value
    is built, and ``node`` is returned. Otherwise a frontend if/else is emitted
    with both branches built inside ``non_static_if_guard``, and None is
    returned.
    """
    build_stmt(ctx, node.test)

    if get_decorator(ctx, node.test) == "static":
        chosen_branch = node.body if node.test.ptr else node.orelse
        build_stmts(ctx, chosen_branch)
        return node

    with ctx.non_static_if_guard(node):
        dbg_info = _ti_core.DebugInfo(ctx.get_pos_info(node))
        impl.begin_frontend_if(ctx.ast_builder, node.test.ptr, dbg_info)
        # True branch.
        ctx.ast_builder.begin_frontend_if_true()
        build_stmts(ctx, node.body)
        ctx.ast_builder.pop_scope()
        # False branch.
        ctx.ast_builder.begin_frontend_if_false()
        build_stmts(ctx, node.orelse)
        ctx.ast_builder.pop_scope()
    return None
|
1192
|
+
|
1193
|
+
@staticmethod
def build_Expr(ctx: ASTTransformerContext, node: ast.Expr) -> None:
    """Build an expression statement; the resulting value is discarded."""
    build_stmt(ctx, node.value)
|
1197
|
+
|
1198
|
+
@staticmethod
def build_IfExp(ctx: ASTTransformerContext, node: ast.IfExp):
    """Build a conditional expression ``body if test else orelse``.

    * A tensor-valued test is rejected (use ``ti.select``).
    * If either branch is a tensor expression, ``ti_ops.select`` is used.
    * A static test selects the branch by its Python truth value.
    * Otherwise ``ti_ops.ifte`` is emitted.

    Raises:
        GsTaichiSyntaxError: if the test is a tensor expression.
    """
    build_stmt(ctx, node.test)
    build_stmt(ctx, node.body)
    build_stmt(ctx, node.orelse)

    def is_tensor_expr(value) -> bool:
        return isinstance(value, expr.Expr) and value.is_tensor()

    if is_tensor_expr(node.test.ptr):
        raise GsTaichiSyntaxError(
            "Using conditional expression for element-wise select operation on "
            "GsTaichi vectors/matrices is deprecated and removed starting from GsTaichi v1.5.0. "
            'Please use "ti.select" instead.'
        )
    if is_tensor_expr(node.body.ptr) or is_tensor_expr(node.orelse.ptr):
        node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, node.orelse.ptr)
        return node.ptr

    if get_decorator(ctx, node.test) == "static":
        # NOTE(review): both branches were already built above, so the selected
        # branch is built a second time here and the unselected one is built at
        # all. Confirm that this is intended for static conditionals.
        if node.test.ptr:
            node.ptr = build_stmt(ctx, node.body)
        else:
            node.ptr = build_stmt(ctx, node.orelse)
        return node.ptr

    node.ptr = ti_ops.ifte(node.test.ptr, node.body.ptr, node.orelse.ptr)
    return node.ptr
|
1233
|
+
|
1234
|
+
@staticmethod
|
1235
|
+
def _is_string_mod_args(msg) -> bool:
|
1236
|
+
# 1. str % (a, b, c, ...)
|
1237
|
+
# 2. str % single_item
|
1238
|
+
# Note that |msg.right| may not be a tuple.
|
1239
|
+
if not isinstance(msg, ast.BinOp):
|
1240
|
+
return False
|
1241
|
+
if not isinstance(msg.op, ast.Mod):
|
1242
|
+
return False
|
1243
|
+
if isinstance(msg.left, ast.Str):
|
1244
|
+
return True
|
1245
|
+
if isinstance(msg.left, ast.Constant) and isinstance(msg.left.value, str):
|
1246
|
+
return True
|
1247
|
+
return False
|
1248
|
+
|
1249
|
+
@staticmethod
def _handle_string_mod_args(ctx: ASTTransformerContext, node):
    """Split ``"fmt" % args`` into the built format value and a list of arg expression pointers.

    A non-sequence right operand is treated as a single argument. Each argument
    is wrapped in ``expr.Expr`` and its ``.ptr`` is returned.
    """
    msg = build_stmt(ctx, node.left)
    raw_args = build_stmt(ctx, node.right)
    if not isinstance(raw_args, collections.abc.Sequence):
        raw_args = (raw_args,)
    return msg, [expr.Expr(arg).ptr for arg in raw_args]
|
1257
|
+
|
1258
|
+
@staticmethod
def ti_format_list_to_assert_msg(raw) -> tuple[str, list]:
    """Convert a ``__ti_format__`` list into a printf-style message plus its arguments.

    String entries are copied verbatim. Expression entries become ``%f`` (real
    types) or ``%d`` (integer types) and are collected as arguments.

    Raises:
        GsTaichiSyntaxError: for an entry of another kind, or an expression of
            another data type.
    """
    # TODO: ignore formats here for now
    entries, _ = impl.ti_format_list_to_content_entries([raw])
    pieces = []
    args = []
    for entry in entries:
        if isinstance(entry, str):
            pieces.append(entry)
            continue
        if not isinstance(entry, _ti_core.ExprCxx):
            raise GsTaichiSyntaxError(f"Unsupported type: {type(entry)}")
        ty = entry.get_rvalue_type()
        if ty in primitive_types.real_types:
            pieces.append("%f")
        elif ty in primitive_types.integer_types:
            pieces.append("%d")
        else:
            raise GsTaichiSyntaxError(f"Unsupported data type: {type(ty)}")
        args.append(entry)
    return "".join(pieces), args
|
1279
|
+
|
1280
|
+
@staticmethod
def build_Assert(ctx: ASTTransformerContext, node: ast.Assert) -> None:
    """Lower a Python ``assert`` statement into a call to ``impl.ti_assert``.

    The message text and its runtime arguments are chosen as follows:

    * no message: the source text of the condition (via ``unparse``);
    * ``"<fmt>" % args``: the format plus the built ``args``
      (see ``_handle_string_mod_args``);
    * an ``ast.Constant``: ``str()`` of the built constant;
    * an expression that builds to a non-empty sequence starting with
      ``"__ti_format__"``: converted by ``ti_format_list_to_assert_msg``
      (presumably what formatted strings lower to; TODO confirm).

    Raises:
        GsTaichiSyntaxError: if the message matches none of the forms above.
    """
    extra_args = []
    if node.msg is None:
        # No user message: report the failing condition's source text instead.
        msg = unparse(node.test)
    elif ASTTransformer._is_string_mod_args(node.msg):
        msg, extra_args = ASTTransformer._handle_string_mod_args(ctx, node.msg)
    else:
        msg = build_stmt(ctx, node.msg)
        # ``ast.Str`` is deliberately not checked here. Since Python 3.8 every
        # string literal is an ``ast.Constant`` (handled below), and merely
        # accessing ``ast.Str`` warns on 3.12 and raises on 3.14.
        if isinstance(node.msg, ast.Constant):
            msg = str(msg)
        elif isinstance(msg, collections.abc.Sequence) and len(msg) > 0 and msg[0] == "__ti_format__":
            msg, extra_args = ASTTransformer.ti_format_list_to_assert_msg(msg)
        else:
            raise GsTaichiSyntaxError(f"assert info must be constant or formatted string, not {type(msg)}")
    test = build_stmt(ctx, node.test)
    impl.ti_assert(test, msg.strip(), extra_args, _ti_core.DebugInfo(ctx.get_pos_info(node)))
    return None
|
1301
|
+
|
1302
|
+
@staticmethod
def build_Break(ctx: ASTTransformerContext, node: ast.Break) -> None:
    """Lower a ``break`` statement.

    Outside a static ``for``, a break statement is inserted through
    ``ctx.ast_builder``. Inside a static ``for``, no IR is emitted; instead the
    current loop status is set to ``LoopStatus.Break``.

    Raises:
        GsTaichiSyntaxError: if the break sits in a static ``for`` but inside a
            non-static ``if``.
    """
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_break_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None

    enclosing_if = ctx.current_loop_scope().nearest_non_static_if
    if enclosing_if:
        # The error is reported at the condition of the offending runtime `if`.
        error = ctx.get_pos_info(enclosing_if.test)
        error += (
            "You are trying to `break` a static `for` loop, "
            "but the `break` statement is inside a non-static `if`. "
        )
        raise GsTaichiSyntaxError(error)

    ctx.set_loop_status(LoopStatus.Break)
    return None
|
1317
|
+
|
1318
|
+
@staticmethod
def build_Continue(ctx: ASTTransformerContext, node: ast.Continue) -> None:
    """Lower a ``continue`` statement.

    Outside a static ``for``, a continue statement is inserted through
    ``ctx.ast_builder``. Inside a static ``for``, no IR is emitted; instead the
    current loop status is set to ``LoopStatus.Continue``.

    Raises:
        GsTaichiSyntaxError: if the continue sits in a static ``for`` but inside
            a non-static ``if``.
    """
    if not ctx.is_in_static_for():
        ctx.ast_builder.insert_continue_stmt(_ti_core.DebugInfo(ctx.get_pos_info(node)))
        return None

    enclosing_if = ctx.current_loop_scope().nearest_non_static_if
    if enclosing_if:
        # The error is reported at the condition of the offending runtime `if`.
        error = ctx.get_pos_info(enclosing_if.test)
        error += (
            "You are trying to `continue` a static `for` loop, "
            "but the `continue` statement is inside a non-static `if`. "
        )
        raise GsTaichiSyntaxError(error)

    ctx.set_loop_status(LoopStatus.Continue)
    return None
|
1333
|
+
|
1334
|
+
@staticmethod
def build_Pass(ctx: ASTTransformerContext, node: ast.Pass) -> None:
    """Lower a ``pass`` statement; nothing is emitted and nothing is returned."""
    # Intentionally a no-op: neither the context nor the node is used.
    return
|
1337
|
+
|
1338
|
+
|
1339
|
+
# Module-level ASTTransformer instance, invoked throughout this module as
# ``build_stmt(ctx, node)`` to lower a single AST node. It presumably dispatches
# to the matching ``build_<NodeName>`` method; TODO confirm against
# ASTTransformer.__call__.
build_stmt = ASTTransformer()
|
1340
|
+
|
1341
|
+
|
1342
|
+
def build_stmts(ctx: ASTTransformerContext, stmts: list[ast.stmt]):
    """Build a sequence of statements inside a fresh variable scope.

    Building stops early, before the next statement, once
    ``ctx.returned`` is no longer ``ReturnStatus.NoReturn`` or the loop status
    is no longer ``LoopStatus.Normal``. The remaining statements are skipped.

    Args:
        ctx: The transformer context.
        stmts: The statements to build, in order.

    Returns:
        The ``stmts`` list that was passed in, unchanged.
    """
    # TODO: Should we just make this part of ASTTransformer? Then it would be
    # easier to pass around (just pass the ASTTransformer object around).
    with ctx.variable_scope_guard():
        for statement in stmts:
            # Same short-circuit order as before: loop_status() is only queried
            # when nothing has returned yet.
            halted = ctx.returned != ReturnStatus.NoReturn or ctx.loop_status() != LoopStatus.Normal
            if halted:
                break
            build_stmt(ctx, statement)
    return stmts
|